122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
|
import torch
|
|||
|
from typing import TypedDict, Literal, List, Optional, Tuple, Iterator
|
|||
|
|
|||
|
|
|||
|
#### data types #########
|
|||
|
# 下面的数据类型定义与CharacterGLM API一致,但与modeling_chatglm.py的chat方法不一致
|
|||
|
# 参考 https://open.bigmodel.cn/dev/api#characterglm
|
|||
|
RoleType = Literal["user", "assistant"]
|
|||
|
|
|||
|
class Msg(TypedDict):
|
|||
|
role: RoleType
|
|||
|
content: str
|
|||
|
|
|||
|
|
|||
|
class SessionMeta(TypedDict):
|
|||
|
user_name: str
|
|||
|
bot_name: str
|
|||
|
bot_info: str
|
|||
|
user_info: Optional[str]
|
|||
|
|
|||
|
|
|||
|
HistoryType = List[Msg]
|
|||
|
|
|||
|
|
|||
|
class CharacterGLMGenerationUtils:
|
|||
|
@staticmethod
|
|||
|
def convert_chatglm_history_to_characterglm_history(user_query: str, history: List[Tuple[str, str]]) -> HistoryType:
|
|||
|
characterglm_history: HistoryType = []
|
|||
|
for i, (query, response) in enumerate(history):
|
|||
|
if i == 0 and query == '':
|
|||
|
# first empty query is an placeholder
|
|||
|
pass
|
|||
|
else:
|
|||
|
characterglm_history.append({
|
|||
|
"role": "user",
|
|||
|
"content": query
|
|||
|
})
|
|||
|
characterglm_history.append({
|
|||
|
"role": "assistant",
|
|||
|
"content": response
|
|||
|
})
|
|||
|
|
|||
|
characterglm_history.append({
|
|||
|
"role": "user",
|
|||
|
"content": user_query
|
|||
|
})
|
|||
|
return characterglm_history
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def build_inputs(session_meta: SessionMeta, history: HistoryType) -> str:
|
|||
|
"""
|
|||
|
注意:这里假设history最后一条消息是用户query
|
|||
|
"""
|
|||
|
texts = []
|
|||
|
texts.append(
|
|||
|
f"以下是一段{session_meta['bot_name']}和{session_meta['user_name']}之间的对话。")
|
|||
|
if session_meta.get("bot_info"):
|
|||
|
texts.append(f"关于{session_meta['bot_name']}的信息:{session_meta['bot_info']}")
|
|||
|
if session_meta.get("user_info"):
|
|||
|
texts.append(
|
|||
|
f"关于{session_meta['user_name']}的信息:{session_meta['user_info']}")
|
|||
|
|
|||
|
assert history and history[-1]['role'] == 'user'
|
|||
|
for msg in history:
|
|||
|
name = session_meta['user_name'] if msg['role'] == 'user' else session_meta['bot_name']
|
|||
|
texts.append(f"[{name}]" + msg['content'].strip())
|
|||
|
|
|||
|
texts = [text.replace('\n', ' ') for text in texts]
|
|||
|
texts.append(f"[{session_meta['bot_name']}]")
|
|||
|
return '\n'.join(texts)
|
|||
|
|
|||
|
|
|||
|
class CharacterGLMAPI:
|
|||
|
@staticmethod
|
|||
|
def build_api_arguments(session_meta: SessionMeta, history: HistoryType) -> dict:
|
|||
|
return {
|
|||
|
"model": "characterglm",
|
|||
|
"meta": session_meta,
|
|||
|
"prompt": history
|
|||
|
}
|
|||
|
|
|||
|
@classmethod
|
|||
|
def async_invoke(cls, session_meta: SessionMeta, history: HistoryType):
|
|||
|
"""
|
|||
|
注意:
|
|||
|
1. 先设置zhipuai.api_key
|
|||
|
2. 建议传入`return_type='text'`,否则返回结果是json字符串
|
|||
|
|
|||
|
参考:
|
|||
|
https://open.bigmodel.cn/dev/api#characterglm
|
|||
|
"""
|
|||
|
import zhipuai
|
|||
|
kwargs = cls.build_api_arguments(session_meta, history)
|
|||
|
return zhipuai.model_api.async_invoke(**kwargs, return_type='text')
|
|||
|
|
|||
|
@classmethod
|
|||
|
def invoke(cls, session_meta: SessionMeta, history: HistoryType):
|
|||
|
"""
|
|||
|
注意:
|
|||
|
1. 先设置zhipuai.api_key
|
|||
|
2. 建议传入`return_type='text'`,否则返回结果是json字符串
|
|||
|
3. 需要再次调用`zhipuai.model_api.query_async_invoke_result`才能获取生成结果
|
|||
|
|
|||
|
参考:
|
|||
|
https://open.bigmodel.cn/dev/api#characterglm
|
|||
|
"""
|
|||
|
import zhipuai
|
|||
|
kwargs = cls.build_api_arguments(session_meta, history)
|
|||
|
return zhipuai.model_api.invoke(**kwargs, return_type='text')
|
|||
|
|
|||
|
@classmethod
|
|||
|
def generate(cls, session_meta: SessionMeta, history: HistoryType) -> str:
|
|||
|
result = cls.invoke(session_meta, history)
|
|||
|
if not result['success']:
|
|||
|
raise RuntimeError(result)
|
|||
|
return result['data']['choices'][0]['content']
|
|||
|
|
|||
|
@classmethod
|
|||
|
def stream_generate(cls, session_meta: SessionMeta, history: HistoryType) -> Iterator[str]:
|
|||
|
# 伪流式生成
|
|||
|
return iter(cls.generate(session_meta, history))
|