diff --git a/entity/friend.go b/entity/friend.go index 626dbe4..69d4b6c 100644 --- a/entity/friend.go +++ b/entity/friend.go @@ -13,6 +13,7 @@ type Friend struct { Pinyin string `json:"pinyin"` // 昵称拼音大写首字母 PinyinAll string `json:"pinyinAll"` // 昵称全拼 EnableAi bool `json:"enableAI" gorm:"type:tinyint(1) default 0 not null"` // 是否使用AI + AiModel string `json:"aiModel"` // AI模型 EnableChatRank bool `json:"enableChatRank" gorm:"type:tinyint(1) default 0 not null"` // 是否使用聊天排行 EnableWelcome bool `json:"enableWelcome" gorm:"type:tinyint(1) default 0 not null"` // 是否启用迎新 IsOk bool `json:"isOk" gorm:"type:tinyint(1) default 0 not null"` // 是否正常 diff --git a/plugin/plugins/ai.go b/plugin/plugins/ai.go index 59eb786..93b58ef 100644 --- a/plugin/plugins/ai.go +++ b/plugin/plugins/ai.go @@ -28,9 +28,9 @@ func AI(m *plugin.MessageContext) { } // 取出所有启用了AI的好友或群组 - var count int64 - client.MySQL.Model(&entity.Friend{}).Where("enable_ai IS TRUE").Where("wxid = ?", m.FromUser).Count(&count) - if count < 1 { + var friendInfo entity.Friend + client.MySQL.Where("wxid = ?", m.FromUser).First(&friendInfo) + if friendInfo.Wxid == "" { return } @@ -93,7 +93,9 @@ func AI(m *plugin.MessageContext) { // 配置模型 chatModel := openai.GPT3Dot5Turbo0613 - if config.Conf.Ai.Model != "" { + if friendInfo.AiModel != "" { + chatModel = friendInfo.AiModel + } else if config.Conf.Ai.Model != "" { chatModel = config.Conf.Ai.Model } diff --git a/vo/friend.go b/vo/friend.go index f8704d2..874d368 100644 --- a/vo/friend.go +++ b/vo/friend.go @@ -13,6 +13,7 @@ type FriendItem struct { PinyinAll string // 昵称全拼 Wxid string // 微信原始Id EnableAi bool // 是否使用AI + AiModel string // AI模型 EnableChatRank bool // 是否使用聊天排行 EnableWelcome bool // 是否使用迎新 EnableCommand bool // 是否启用指令