diff --git a/common/current/robot.go b/common/current/robot.go index b68591df..dae5fa69 100644 --- a/common/current/robot.go +++ b/common/current/robot.go @@ -1,19 +1,44 @@ package current -import "go-wechat/model" +import ( + "go-wechat/model" + plugin "go-wechat/plugin" +) -var robotInfo model.RobotUserInfo +// robotInfo +// @description: 机器人信息 +type robotInfo struct { + info model.RobotUserInfo + MessageHandler plugin.MessageHandler // 启用的插件 +} + +// 当前接入的机器人信息 +var ri robotInfo // SetRobotInfo // @description: 设置机器人信息 // @param info func SetRobotInfo(info model.RobotUserInfo) { - robotInfo = info + ri.info = info } // GetRobotInfo // @description: 获取机器人信息 // @return model.RobotUserInfo func GetRobotInfo() model.RobotUserInfo { - return robotInfo + return ri.info +} + +// GetRobotMessageHandler +// @description: 获取机器人插件信息 +// @return robotInfo +func GetRobotMessageHandler() plugin.MessageHandler { + return ri.MessageHandler +} + +// SetRobotMessageHandler +// @description: 设置机器人插件信息 +// @param handler +func SetRobotMessageHandler(handler plugin.MessageHandler) { + ri.MessageHandler = handler } diff --git a/initialization/plugin.go b/initialization/plugin.go new file mode 100644 index 00000000..2efd4011 --- /dev/null +++ b/initialization/plugin.go @@ -0,0 +1,38 @@ +package initialization + +import ( + "go-wechat/common/current" + "go-wechat/model" + plugin "go-wechat/plugin" + "go-wechat/plugin/plugins" +) + +// Plugin +// @description: 初始化插件 +func Plugin() { + // 定义一个处理器 + dispatcher := plugin.NewMessageMatchDispatcher() + // 设置为异步处理 + dispatcher.SetAsync(true) + + // 注册插件 + + // 保存消息进数据库 + dispatcher.RegisterHandler(func(*model.Message) bool { + return true + }, plugins.SaveToDb) + + // AI消息插件 + dispatcher.RegisterHandler(func(m *model.Message) bool { + // 群内@或者私聊文字消息 + return m.IsAt() || m.IsPrivateText() + }, plugins.AI) + + // 欢迎新成员 + dispatcher.RegisterHandler(func(m *model.Message) bool { + return m.IsNewUserJoin() + }, plugins.WelcomeNew) + + // 注册消息处理器 + current.SetRobotMessageHandler(plugin.DispatchMessage(dispatcher)) +} diff --git a/main.go b/main.go index c128d4e3..51cfda90 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( func init() { initialization.InitConfig() // 初始化配置 initialization.InitWechatRobotInfo() // 初始化机器人信息 + initialization.Plugin() // 注册插件 tasks.InitTasks() // 初始化定时任务 } diff --git a/model/message.go b/model/message.go index 43531dfa..1d5cf27f 100644 --- a/model/message.go +++ b/model/message.go @@ -20,6 +20,7 @@ type Message struct { Signature string `json:"signature"` ToUser string `json:"toUser"` Type types.MessageType `json:"type"` + Raw string `json:"raw"` } // systemMsgDataXml @@ -89,3 +90,12 @@ func (m Message) IsNewUserJoin() bool { func (m Message) IsAt() bool { return strings.HasSuffix(m.DisplayFullContent, "在群聊中@了你") } + +// IsPrivateText +// @description: 是否是私聊消息 +// @receiver m +// @return bool +func (m Message) IsPrivateText() bool { + // 发信人不以@chatroom结尾且消息类型为文本 + return !strings.HasSuffix(m.FromUser, "chatroom") && m.Type == types.MsgTypeText +} diff --git a/plugin/plugin.go b/plugin/plugin.go new file mode 100644 index 00000000..1ae565d5 --- /dev/null +++ b/plugin/plugin.go @@ -0,0 +1,143 @@ +package plugin + +import ( + "go-wechat/model" +) + +// MessageHandler 消息处理函数 +type MessageHandler func(msg *model.Message) + +// MessageDispatcher 消息分发处理接口 +// 跟 DispatchMessage 结合封装成 MessageHandler +type MessageDispatcher interface { + Dispatch(msg *model.Message) +} + +// DispatchMessage 跟 MessageDispatcher 结合封装成 MessageHandler +func DispatchMessage(dispatcher MessageDispatcher) func(msg *model.Message) { + return func(msg *model.Message) { dispatcher.Dispatch(msg) } +} + +// MessageDispatcher impl + +// MessageContextHandler 消息处理函数 +type MessageContextHandler func(ctx *MessageContext) + +type MessageContextHandlerGroup []MessageContextHandler + +// MessageContext 消息处理上下文对象 +type MessageContext struct { + index int + abortIndex int + messageHandlers MessageContextHandlerGroup + *model.Message +} + +// Next 主动调用下一个消息处理函数(或开始调用) +func (c *MessageContext) Next() { + c.index++ + for c.index <= len(c.messageHandlers) { + if c.IsAbort() { + return + } + handle := c.messageHandlers[c.index-1] + handle(c) + c.index++ + } +} + +// IsAbort 判断是否被中断 +func (c *MessageContext) IsAbort() bool { + return c.abortIndex > 0 +} + +// Abort 中断当前消息处理, 不会调用下一个消息处理函数, 但是不会中断当前的处理函数 +func (c *MessageContext) Abort() { + c.abortIndex = c.index +} + +// AbortHandler 获取当前中断的消息处理函数 +func (c *MessageContext) AbortHandler() MessageContextHandler { + if c.abortIndex > 0 { + return c.messageHandlers[c.abortIndex-1] + } + return nil +} + +// MatchFunc 消息匹配函数,返回为true则表示匹配 +type MatchFunc func(*model.Message) bool + +// MatchFuncList 将多个MatchFunc封装成一个MatchFunc +func MatchFuncList(matchFuncs ...MatchFunc) MatchFunc { + return func(message *model.Message) bool { + for _, matchFunc := range matchFuncs { + if !matchFunc(message) { + return false + } + } + return true + } +} + +type matchNode struct { + matchFunc MatchFunc + group MessageContextHandlerGroup +} + +type matchNodes []*matchNode + +// MessageMatchDispatcher impl MessageDispatcher interface +// +// dispatcher := NewMessageMatchDispatcher() +// dispatcher.OnText(func(msg *model.Message){ +// msg.ReplyText("hello") +// }) +// bot := DefaultBot() +// bot.MessageHandler = DispatchMessage(dispatcher) +type MessageMatchDispatcher struct { + async bool + matchNodes matchNodes +} + +// NewMessageMatchDispatcher Constructor +func NewMessageMatchDispatcher() *MessageMatchDispatcher { + return &MessageMatchDispatcher{} +} + +// SetAsync 设置是否异步处理 +func (m *MessageMatchDispatcher) SetAsync(async bool) { + m.async = async +} + +// Dispatch impl MessageDispatcher +// 遍历 MessageMatchDispatcher 所有的消息处理函数 +// 获取所有匹配上的函数 +// 执行处理的消息处理方法 +func (m *MessageMatchDispatcher) Dispatch(msg *model.Message) { + var group MessageContextHandlerGroup + for _, node := range m.matchNodes { + if node.matchFunc(msg) { + group = append(group, node.group...) + } + } + ctx := &MessageContext{Message: msg, messageHandlers: group} + if m.async { + go m.do(ctx) + } else { + m.do(ctx) + } +} + +func (m *MessageMatchDispatcher) do(ctx *MessageContext) { + ctx.Next() +} + +// RegisterHandler 注册消息处理函数, 根据自己的需求自定义 +// matchFunc返回true则表示处理对应的handlers +func (m *MessageMatchDispatcher) RegisterHandler(matchFunc MatchFunc, handlers ...MessageContextHandler) { + if matchFunc == nil { + panic("MatchFunc can not be nil") + } + node := &matchNode{matchFunc: matchFunc, group: handlers} + m.matchNodes = append(m.matchNodes, node) +} diff --git a/handler/at_message.go b/plugin/plugins/ai.go similarity index 70% rename from handler/at_message.go rename to plugin/plugins/ai.go index cbc96b81..59eb7864 100644 --- a/handler/at_message.go +++ b/plugin/plugins/ai.go @@ -1,4 +1,4 @@ -package handler +package plugins import ( "context" @@ -9,7 +9,7 @@ import ( "go-wechat/common/current" "go-wechat/config" "go-wechat/entity" - "go-wechat/model" + "go-wechat/plugin" "go-wechat/service" "go-wechat/types" "go-wechat/utils" @@ -19,10 +19,10 @@ import ( "time" ) -// handleAtMessage -// @description: 处理At机器人的消息 +// AI +// @description: AI消息 // @param m -func handleAtMessage(m model.Message) { +func AI(m *plugin.MessageContext) { if !config.Conf.Ai.Enable { return } @@ -54,13 +54,14 @@ func handleAtMessage(m model.Message) { // 查询发信人前面几条文字信息,组装进来 var oldMessages []entity.Message - client.MySQL.Model(&entity.Message{}). - Where("msg_id != ?", m.MsgId). - Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)"). - Where("from_user = ? AND group_user = ? AND display_full_content LIKE ?", m.FromUser, m.GroupUser, "%在群聊中@了你"). - Or("to_user = ? AND group_user = ?", m.FromUser, m.GroupUser). - Order("create_at desc"). - Limit(4).Find(&oldMessages) + if m.GroupUser == "" { + // 私聊 + oldMessages = getUserPrivateMessages(m.FromUser) + } else { + // 群聊 + oldMessages = getGroupUserMessages(m.MsgId, m.FromUser, m.GroupUser) + } + // 翻转数组 slice.Reverse(oldMessages) // 循环填充消息 @@ -74,7 +75,7 @@ func handleAtMessage(m model.Message) { } // 填充消息 role := openai.ChatMessageRoleUser - if message.ToUser != current.GetRobotInfo().WxId { + if message.FromUser == current.GetRobotInfo().WxId { // 如果收信人不是机器人,表示这条消息是 AI 发的 role = openai.ChatMessageRoleAssistant } @@ -129,5 +130,43 @@ func handleAtMessage(m model.Message) { service.SaveMessage(replyMessage) // 保存消息 // 发送消息 - utils.SendMessage(m.FromUser, m.GroupUser, "\n"+resp.Choices[0].Message.Content, 0) + replyMsg := resp.Choices[0].Message.Content + if m.GroupUser != "" { + replyMsg = "\n" + resp.Choices[0].Message.Content + } + utils.SendMessage(m.FromUser, m.GroupUser, replyMsg, 0) +} + +// getGroupUserMessages +// @description: 获取群成员消息 +// @return records +func getGroupUserMessages(msgId int64, groupId, groupUserId string) (records []entity.Message) { + subQuery := client.MySQL. + Where("from_user = ? AND group_user = ? AND display_full_content LIKE ?", groupId, groupUserId, "%在群聊中@了你"). + Or("to_user = ? AND group_user = ?", groupId, groupUserId) + + client.MySQL.Model(&entity.Message{}). + Where("msg_id != ?", msgId). + Where("type = ?", types.MsgTypeText). + Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)"). + Where(subQuery). + Order("create_at desc"). + Limit(4).Find(&records) + return +} + +// getUserPrivateMessages +// @description: 获取用户私聊消息 +// @return records +func getUserPrivateMessages(userId string) (records []entity.Message) { + subQuery := client.MySQL. + Where("from_user = ?", userId).Or("to_user = ?", userId) + + client.MySQL.Model(&entity.Message{}). + Where("type = ?", types.MsgTypeText). + Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)"). + Where(subQuery). + Order("create_at desc"). + Limit(4).Find(&records) + return } diff --git a/plugin/plugins/save2db.go b/plugin/plugins/save2db.go new file mode 100644 index 00000000..18d294b7 --- /dev/null +++ b/plugin/plugins/save2db.go @@ -0,0 +1,27 @@ +package plugins + +import ( + "go-wechat/entity" + "go-wechat/plugin" + "go-wechat/service" + "time" +) + +// SaveToDb +// @description: 保存消息到数据库 +// @param m +func SaveToDb(m *plugin.MessageContext) { + var ent entity.Message + ent.MsgId = m.MsgId + ent.CreateTime = m.CreateTime + ent.CreateAt = time.Unix(int64(m.CreateTime), 0) + ent.Content = m.Content + ent.FromUser = m.FromUser + ent.GroupUser = m.GroupUser + ent.ToUser = m.ToUser + ent.Type = m.Type + ent.DisplayFullContent = m.DisplayFullContent + ent.Raw = m.Raw + // 保存入库 + service.SaveMessage(ent) +} diff --git a/handler/sys_message.go b/plugin/plugins/welconenew.go similarity index 87% rename from handler/sys_message.go rename to plugin/plugins/welconenew.go index 425a87ef..2a89d1c1 100644 --- a/handler/sys_message.go +++ b/plugin/plugins/welconenew.go @@ -1,17 +1,17 @@ -package handler +package plugins import ( "go-wechat/client" "go-wechat/config" "go-wechat/entity" - "go-wechat/model" + "go-wechat/plugin" "go-wechat/utils" ) -// handleNewUserJoin +// WelcomeNew // @description: 欢迎新成员 // @param m -func handleNewUserJoin(m model.Message) { +func WelcomeNew(m *plugin.MessageContext) { // 判断是否开启迎新 var count int64 client.MySQL.Model(&entity.Friend{}).Where("enable_welcome IS TRUE").Where("wxid = ?", m.FromUser).Count(&count) diff --git a/plugins/plugin.go b/plugins/plugin.go deleted file mode 100644 index 5eb7c32f..00000000 --- a/plugins/plugin.go +++ /dev/null @@ -1,10 +0,0 @@ -package plugins - -// Message -// @description: 插件消息 -type Message struct { - GroupId string // 消息来源群Id - UserId string // 消息来源用户Id - Message string // 消息内容 - IsBreak bool // 是否中断消息传递 -} diff --git a/tcpserver/handle.go b/tcpserver/handle.go index a38943d5..98f2338f 100644 --- a/tcpserver/handle.go +++ b/tcpserver/handle.go @@ -3,7 +3,6 @@ package tcpserver import ( "bytes" "go-wechat/config" - "go-wechat/handler" "io" "log" "net" @@ -24,7 +23,7 @@ func process(conn net.Conn) { log.Printf("[%s]返回数据失败,错误信息: %v", conn.RemoteAddr(), err) } log.Printf("[%s]数据长度: %d", conn.RemoteAddr(), buf.Len()) - go handler.Parse(conn.RemoteAddr(), buf.Bytes()) + go parse(conn.RemoteAddr(), buf.Bytes()) // 转发到其他地方去 if len(config.Conf.Wechat.Forward) > 0 { diff --git a/handler/parse.go b/tcpserver/parse.go similarity index 52% rename from handler/parse.go rename to tcpserver/parse.go index c031eaad..981d5436 100644 --- a/handler/parse.go +++ b/tcpserver/parse.go @@ -1,28 +1,28 @@ -package handler +package tcpserver import ( "encoding/json" - "go-wechat/entity" + "go-wechat/common/current" "go-wechat/model" - "go-wechat/service" "go-wechat/types" - "go-wechat/utils" "log" "net" "strings" - "time" ) -// Parse +// parse // @description: 解析消息 // @param msg -func Parse(remoteAddr net.Addr, msg []byte) { +func parse(remoteAddr net.Addr, msg []byte) { var m model.Message if err := json.Unmarshal(msg, &m); err != nil { log.Printf("[%s]消息解析失败: %v", remoteAddr, err) log.Printf("[%s]消息内容: %d -> %v", remoteAddr, len(msg), string(msg)) return } + // 记录原始数据 + m.Raw = string(msg) + // 提取出群成员信息 // Sys类型的消息正文不包含微信 Id,所以不需要处理 if m.IsGroup() && m.Type != types.MsgTypeSys { @@ -38,33 +38,9 @@ func Parse(remoteAddr net.Addr, msg []byte) { } log.Printf("%s\n消息来源: %s\n群成员: %s\n消息类型: %v\n消息内容: %s", remoteAddr, m.FromUser, m.GroupUser, m.Type, m.Content) - // 异步处理消息 - go func() { - if m.IsNewUserJoin() { - log.Printf("%s -> 开始迎新 -> %s", m.FromUser, m.Content) - // 欢迎新成员 - go handleNewUserJoin(m) - } else if m.IsAt() { - // @机器人的消息 - go handleAtMessage(m) - } else if !strings.Contains(m.FromUser, "@") && m.Type == types.MsgTypeText { - // 私聊消息处理 - utils.SendMessage(m.FromUser, "", "暂未开启私聊AI", 0) - } - }() + // 插件不为空,开始执行 + if p := current.GetRobotMessageHandler(); p != nil { + p(&m) + } - // 转换为结构体之后入库 - var ent entity.Message - ent.MsgId = m.MsgId - ent.CreateTime = m.CreateTime - ent.CreateAt = time.Unix(int64(m.CreateTime), 0) - ent.Content = m.Content - ent.FromUser = m.FromUser - ent.GroupUser = m.GroupUser - ent.ToUser = m.ToUser - ent.Type = m.Type - ent.DisplayFullContent = m.DisplayFullContent - ent.Raw = string(msg) - - go service.SaveMessage(ent) }