From cab6b2633ebb6be59dc47df951c78b1474f0aaae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=BB=E6=AC=A2?= Date: Sat, 9 Dec 2023 11:52:11 +0800 Subject: [PATCH] =?UTF-8?q?:art:=20=E4=BC=98=E5=8C=96=20AI=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=AE=B0=E5=BD=95=E5=87=A0=E5=8F=A5=E4=B8=8A?= =?UTF-8?q?=E4=B8=8B=E6=96=87=EF=BC=8C=E5=B9=B6=E4=BC=98=E5=8C=96=E4=BA=86?= =?UTF-8?q?=20at=20=E6=B6=88=E6=81=AF=E5=8C=B9=E9=85=8D=E7=9A=84=E6=AD=A3?= =?UTF-8?q?=E5=88=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/at_message.go | 54 ++++++++++++++++++++++++++++++++++++++++--- service/message.go | 1 + 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/handler/at_message.go b/handler/at_message.go index 4a0d394..c28db25 100644 --- a/handler/at_message.go +++ b/handler/at_message.go @@ -3,15 +3,19 @@ package handler import ( "context" "fmt" + "github.com/duke-git/lancet/v2/slice" "github.com/sashabaranov/go-openai" "go-wechat/client" + "go-wechat/common/current" "go-wechat/config" "go-wechat/entity" "go-wechat/model" + "go-wechat/types" "go-wechat/utils" "log" "regexp" "strings" + "time" ) // handleAtMessage @@ -30,7 +34,7 @@ func handleAtMessage(m model.Message) { } // 预处理一下发送的消息,用正则去掉@机器人的内容 - re := regexp.MustCompile(`@([^ ]+)`) + re := regexp.MustCompile(`@([^ | ]+)`) matches := re.FindStringSubmatch(m.Content) if len(matches) > 0 { // 过滤掉第一个匹配到的 @@ -46,6 +50,38 @@ func handleAtMessage(m model.Message) { Content: config.Conf.Ai.Personality, }) } + + // 查询发信人前面几条文字信息,组装进来 + var oldMessages []entity.Message + client.MySQL.Model(&entity.Message{}). + Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)"). + Where("from_user = ? AND group_user = ? AND display_full_content LIKE ?", m.FromUser, m.GroupUser, "%在群聊中@了你"). + Or("to_user = ? AND group_user = ?", m.FromUser, m.GroupUser). + Order("create_at desc"). + Limit(4).Find(&oldMessages) + // 翻转数组 + slice.Reverse(oldMessages) + // 循环填充消息 + for _, message := range oldMessages { + // 剔除@机器人的内容 + msgStr := message.Content + matches = re.FindStringSubmatch(msgStr) + if len(matches) > 0 { + // 过滤掉第一个匹配到的 + msgStr = strings.Replace(msgStr, matches[0], "", 1) + } + // 填充消息 + role := openai.ChatMessageRoleUser + if message.ToUser != current.GetRobotInfo().WxId { + // 如果收信人不是机器人,表示这条消息是 AI 发的 + role = openai.ChatMessageRoleAssistant + } + messages = append(messages, openai.ChatCompletionMessage{ + Role: role, + Content: msgStr, + }) + } + // 填充用户消息 messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, @@ -63,8 +99,8 @@ func handleAtMessage(m model.Message) { if config.Conf.Ai.BaseUrl != "" { conf.BaseURL = fmt.Sprintf("%s/v1", config.Conf.Ai.BaseUrl) } - client := openai.NewClientWithConfig(conf) - resp, err := client.CreateChatCompletion( + ai := openai.NewClientWithConfig(conf) + resp, err := ai.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ Model: chatModel, @@ -78,6 +114,18 @@ func handleAtMessage(m model.Message) { return } + // 保存一下AI 返回的消息,消息 Id 使用传入 Id 的负数 + var replyMessage entity.Message + replyMessage.MsgId = -m.MsgId + replyMessage.CreateTime = int(time.Now().Local().Unix()) + replyMessage.CreateAt = time.Now().Local() + replyMessage.Content = resp.Choices[0].Message.Content + replyMessage.FromUser = current.GetRobotInfo().WxId // 发信人是机器人 + replyMessage.GroupUser = m.GroupUser // 群成员 + replyMessage.ToUser = m.FromUser // 收信人是发信人 + replyMessage.Type = types.MsgTypeText + client.MySQL.Create(&replyMessage) // 保存入库 + // 发送消息 utils.SendMessage(m.FromUser, m.GroupUser, "\n"+resp.Choices[0].Message.Content, 0) } diff --git a/service/message.go b/service/message.go index 245a2ec..f76a55d 100644 --- a/service/message.go +++ b/service/message.go @@ -18,6 +18,7 @@ func SaveMessage(msg entity.Message) { return } if count > 0 { + //log.Printf("消息已存在,消息Id: %d", msg.MsgId) return } err = client.MySQL.Create(&msg).Error