支持原生的多轮对话

This commit is contained in:
Vinlic 2024-04-28 10:16:29 +08:00
parent a695921e73
commit 295e69e6cb
3 changed files with 57 additions and 28 deletions

View File

@ -278,6 +278,9 @@ Authorization: Bearer [refresh_token]
{ {
// 如果使用智能体请填写智能体ID到此处，否则可以乱填 // 如果使用智能体请填写智能体ID到此处，否则可以乱填
"model": "glm4", "model": "glm4",
// 目前多轮对话基于消息合并实现，某些场景可能导致能力下降，且token最高为4096
// 如果您想获得原生的多轮对话体验，可以传入首轮消息获得的id来接续上下文
// "conversation_id": "65f6c28546bae1f0fbb532de",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -292,6 +295,7 @@ Authorization: Bearer [refresh_token]
响应数据: 响应数据:
```json ```json
{ {
// 该id即conversation_id，你可以传入到下一轮对话来接续上下文
"id": "65f6c28546bae1f0fbb532de", "id": "65f6c28546bae1f0fbb532de",
"model": "glm4", "model": "glm4",
"object": "chat.completion", "object": "chat.completion",

View File

@ -170,6 +170,7 @@ async function createCompletion(
messages: any[], messages: any[],
refreshToken: string, refreshToken: string,
assistantId = DEFAULT_ASSISTANT_ID, assistantId = DEFAULT_ASSISTANT_ID,
refConvId = '',
retryCount = 0 retryCount = 0
) { ) {
return (async () => { return (async () => {
@ -184,13 +185,14 @@ async function createCompletion(
: []; : [];
// 请求流 // 请求流
console.log(refConvId)
const token = await acquireToken(refreshToken); const token = await acquireToken(refreshToken);
const result = await axios.post( const result = await axios.post(
"https://chatglm.cn/chatglm/backend-api/assistant/stream", "https://chatglm.cn/chatglm/backend-api/assistant/stream",
{ {
assistant_id: assistantId, assistant_id: assistantId,
conversation_id: "", conversation_id: refConvId,
messages: messagesPrepare(messages, refs), messages: messagesPrepare(messages, refs, !!refConvId),
meta_data: { meta_data: {
channel: "", channel: "",
draft_id: "", draft_id: "",
@ -232,7 +234,7 @@ async function createCompletion(
// 异步移除会话 // 异步移除会话
removeConversation(answer.id, refreshToken, assistantId).catch((err) => removeConversation(answer.id, refreshToken, assistantId).catch((err) =>
console.error(err) !refConvId && console.error(err)
); );
return answer; return answer;
@ -246,6 +248,7 @@ async function createCompletion(
messages, messages,
refreshToken, refreshToken,
assistantId, assistantId,
refConvId,
retryCount + 1 retryCount + 1
); );
})(); })();
@ -266,6 +269,7 @@ async function createCompletionStream(
messages: any[], messages: any[],
refreshToken: string, refreshToken: string,
assistantId = DEFAULT_ASSISTANT_ID, assistantId = DEFAULT_ASSISTANT_ID,
refConvId = '',
retryCount = 0 retryCount = 0
) { ) {
return (async () => { return (async () => {
@ -285,8 +289,8 @@ async function createCompletionStream(
`https://chatglm.cn/chatglm/backend-api/assistant/stream`, `https://chatglm.cn/chatglm/backend-api/assistant/stream`,
{ {
assistant_id: assistantId, assistant_id: assistantId,
conversation_id: "", conversation_id: refConvId,
messages: messagesPrepare(messages, refs), messages: messagesPrepare(messages, refs, !!refConvId),
meta_data: { meta_data: {
channel: "", channel: "",
draft_id: "", draft_id: "",
@ -349,7 +353,7 @@ async function createCompletionStream(
); );
// 流传输结束后异步移除会话 // 流传输结束后异步移除会话
removeConversation(convId, refreshToken, assistantId).catch((err) => removeConversation(convId, refreshToken, assistantId).catch((err) =>
console.error(err) !refConvId && console.error(err)
); );
}); });
})().catch((err) => { })().catch((err) => {
@ -362,6 +366,7 @@ async function createCompletionStream(
messages, messages,
refreshToken, refreshToken,
assistantId, assistantId,
refConvId,
retryCount + 1 retryCount + 1
); );
})(); })();
@ -488,8 +493,10 @@ function extractRefFileUrls(messages: any[]) {
* *
* *
* @param messages gpt系列消息格式 * @param messages gpt系列消息格式
* @param refs
* @param isRefConv
*/ */
function messagesPrepare(messages: any[], refs: any[]) { function messagesPrepare(messages: any[], refs: any[], isRefConv = false) {
// 检查最新消息是否含有"type": "image_url"或"type": "file",如果有则注入消息 // 检查最新消息是否含有"type": "image_url"或"type": "file",如果有则注入消息
let latestMessage = messages[messages.length - 1]; let latestMessage = messages[messages.length - 1];
let hasFileOrImage = let hasFileOrImage =
@ -514,27 +521,46 @@ function messagesPrepare(messages: any[], refs: any[]) {
// logger.info("注入提升尾部消息注意力system prompt"); // logger.info("注入提升尾部消息注意力system prompt");
} }
const content = ( let content;
messages.reduce((content, message) => { if(isRefConv || messages.length < 2) {
const role = message.role content = messages.reduce((content, message) => {
.replace("system", "<|sytstem|>")
.replace("assistant", "<|assistant|>")
.replace("user", "<|user|>");
if (_.isArray(message.content)) { if (_.isArray(message.content)) {
return ( return (
message.content.reduce((_content, v) => { message.content.reduce((_content, v) => {
if (!_.isObject(v) || v["type"] != "text") return _content; if (!_.isObject(v) || v["type"] != "text") return _content;
return _content + (`${role}\n` + v["text"] || "") + "\n"; return _content + (v["text"] || "") + "\n";
}, content) }, content)
); );
} }
return (content += `${role}\n${message.content}\n`); return content + `${message.content}\n`;
}, "") + "<|assistant|>\n" }, "");
) logger.info("\n透传内容\n" + content);
// 移除MD图像URL避免幻觉 }
.replace(/\!\[.+\]\(.+\)/g, "") else {
// 移除临时路径避免在新会话引发幻觉 content = (
.replace(/\/mnt\/data\/.+/g, ""); messages.reduce((content, message) => {
const role = message.role
.replace("system", "<|sytstem|>")
.replace("assistant", "<|assistant|>")
.replace("user", "<|user|>");
if (_.isArray(message.content)) {
return (
message.content.reduce((_content, v) => {
if (!_.isObject(v) || v["type"] != "text") return _content;
return _content + (`${role}\n` + v["text"] || "") + "\n";
}, content)
);
}
return (content += `${role}\n${message.content}\n`);
}, "") + "<|assistant|>\n"
)
// 移除MD图像URL避免幻觉
.replace(/\!\[.+\]\(.+\)/g, "")
// 移除临时路径避免在新会话引发幻觉
.replace(/\/mnt\/data\/.+/g, "");
logger.info("\n对话合并\n" + content);
}
const fileRefs = refs.filter((ref) => !ref.width && !ref.height); const fileRefs = refs.filter((ref) => !ref.width && !ref.height);
const imageRefs = refs const imageRefs = refs
.filter((ref) => ref.width || ref.height) .filter((ref) => ref.width || ref.height)
@ -542,8 +568,6 @@ function messagesPrepare(messages: any[], refs: any[]) {
ref.image_url = ref.file_url; ref.image_url = ref.file_url;
return ref; return ref;
}); });
content
logger.info("\n对话合并\n" + content);
return [ return [
{ {
role: "user", role: "user",

View File

@ -13,22 +13,23 @@ export default {
'/completions': async (request: Request) => { '/completions': async (request: Request) => {
request request
.validate('body.conversation_id', v => _.isUndefined(v) || _.isString(v))
.validate('body.messages', _.isArray) .validate('body.messages', _.isArray)
.validate('headers.authorization', _.isString) .validate('headers.authorization', _.isString)
// refresh_token切分 // refresh_token切分
const tokens = chat.tokenSplit(request.headers.authorization); const tokens = chat.tokenSplit(request.headers.authorization);
// 随机挑选一个refresh_token // 随机挑选一个refresh_token
const token = _.sample(tokens); const token = _.sample(tokens);
const messages = request.body.messages; const { model, conversation_id: convId, messages, stream } = request.body;
const assistantId = /^[a-z0-9]{24,}$/.test(request.body.model) ? request.body.model : undefined const assistantId = /^[a-z0-9]{24,}$/.test(model) ? model : undefined
if (request.body.stream) { if (stream) {
const stream = await chat.createCompletionStream(request.body.messages, token, assistantId); const stream = await chat.createCompletionStream(messages, token, assistantId, convId);
return new Response(stream, { return new Response(stream, {
type: "text/event-stream" type: "text/event-stream"
}); });
} }
else else
return await chat.createCompletion(messages, token, assistantId); return await chat.createCompletion(messages, token, assistantId, convId);
} }
} }