diff --git a/handler/at_message.go b/handler/at_message.go index 4a0d3949..c28db253 100644 --- a/handler/at_message.go +++ b/handler/at_message.go @@ -3,15 +3,19 @@ package handler import ( "context" "fmt" + "github.com/duke-git/lancet/v2/slice" "github.com/sashabaranov/go-openai" "go-wechat/client" + "go-wechat/common/current" "go-wechat/config" "go-wechat/entity" "go-wechat/model" + "go-wechat/types" "go-wechat/utils" "log" "regexp" "strings" + "time" ) // handleAtMessage @@ -30,7 +34,7 @@ func handleAtMessage(m model.Message) { } // 预处理一下发送的消息,用正则去掉@机器人的内容 - re := regexp.MustCompile(`@([^ ]+)`) + re := regexp.MustCompile(`@([^ | ]+)`) matches := re.FindStringSubmatch(m.Content) if len(matches) > 0 { // 过滤掉第一个匹配到的 @@ -46,6 +50,38 @@ func handleAtMessage(m model.Message) { Content: config.Conf.Ai.Personality, }) } + + // 查询发信人前面几条文字信息,组装进来 + var oldMessages []entity.Message + client.MySQL.Model(&entity.Message{}). + Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)"). + Where("from_user = ? AND group_user = ? AND display_full_content LIKE ?", m.FromUser, m.GroupUser, "%在群聊中@了你"). + Or("to_user = ? AND group_user = ?", m.FromUser, m.GroupUser). + Order("create_at desc"). + Limit(4).Find(&oldMessages) + // 翻转数组 + slice.Reverse(oldMessages) + // 循环填充消息 + for _, message := range oldMessages { + // 剔除@机器人的内容 + msgStr := message.Content + matches = re.FindStringSubmatch(msgStr) + if len(matches) > 0 { + // 过滤掉第一个匹配到的 + msgStr = strings.Replace(msgStr, matches[0], "", 1) + } + // 填充消息 + role := openai.ChatMessageRoleUser + if message.ToUser != current.GetRobotInfo().WxId { + // 如果收信人不是机器人,表示这条消息是 AI 发的 + role = openai.ChatMessageRoleAssistant + } + messages = append(messages, openai.ChatCompletionMessage{ + Role: role, + Content: msgStr, + }) + } + // 填充用户消息 messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, @@ -63,8 +99,8 @@ func handleAtMessage(m model.Message) { if config.Conf.Ai.BaseUrl != "" { conf.BaseURL = fmt.Sprintf("%s/v1", config.Conf.Ai.BaseUrl) } - client := openai.NewClientWithConfig(conf) - resp, err := client.CreateChatCompletion( + ai := openai.NewClientWithConfig(conf) + resp, err := ai.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ Model: chatModel, @@ -78,6 +114,18 @@ func handleAtMessage(m model.Message) { return } + // 保存一下AI 返回的消息,消息 Id 使用传入 Id 的负数 + var replyMessage entity.Message + replyMessage.MsgId = -m.MsgId + replyMessage.CreateTime = int(time.Now().Local().Unix()) + replyMessage.CreateAt = time.Now().Local() + replyMessage.Content = resp.Choices[0].Message.Content + replyMessage.FromUser = current.GetRobotInfo().WxId // 发信人是机器人 + replyMessage.GroupUser = m.GroupUser // 群成员 + replyMessage.ToUser = m.FromUser // 收信人是发信人 + replyMessage.Type = types.MsgTypeText + client.MySQL.Create(&replyMessage) // 保存入库 + // 发送消息 utils.SendMessage(m.FromUser, m.GroupUser, "\n"+resp.Choices[0].Message.Content, 0) } diff --git a/service/message.go b/service/message.go index 245a2ec0..f76a55d1 100644 --- a/service/message.go +++ b/service/message.go @@ -18,6 +18,7 @@ func SaveMessage(msg entity.Message) { return } if count > 0 { + //log.Printf("消息已存在,消息Id: %d", msg.MsgId) return } err = client.MySQL.Create(&msg).Error