122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
import torch
|
||
from typing import TypedDict, Literal, List, Optional, Tuple, Iterator
|
||
|
||
|
||
#### data types #########
# The data type definitions below match the CharacterGLM API, but they differ
# from what the `chat` method in modeling_chatglm.py uses.
# Reference: https://open.bigmodel.cn/dev/api#characterglm
|
||
# Allowed values for the `role` field of a message.
RoleType = Literal["user", "assistant"]


class Msg(TypedDict):
    """A single dialogue message in the CharacterGLM message format."""

    role: RoleType  # which side produced the message
    content: str  # text of the message
|
||
|
||
|
||
class SessionMeta(TypedDict):
    """Session-level settings describing the two participants of a dialogue."""

    user_name: str  # name used for the user's lines in the prompt
    bot_name: str  # name used for the character's (bot's) lines in the prompt
    bot_info: str  # background information about the character
    user_info: Optional[str]  # background information about the user, if any
|
||
|
||
|
||
# A dialogue history: an ordered list of `Msg` entries.
HistoryType = List[Msg]
|
||
|
||
|
||
class CharacterGLMGenerationUtils:
    """Helpers that turn chat data into the text prompt used for generation."""

    @staticmethod
    def convert_chatglm_history_to_characterglm_history(
        user_query: str, history: List[Tuple[str, str]]
    ) -> "HistoryType":
        """Convert ChatGLM-style ``(query, response)`` pairs into a message list.

        Args:
            user_query: The new user message; appended as the final entry.
            history: Earlier turns as ``(query, response)`` tuples.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts: one user and one
            assistant message per turn, followed by ``user_query`` as a user
            message.
        """
        messages: "HistoryType" = []
        for index, (query, response) in enumerate(history):
            # An empty query in the very first turn is a placeholder, so no
            # user message is emitted for it; its response is still kept.
            if index != 0 or query != '':
                messages.append({"role": "user", "content": query})
            messages.append({"role": "assistant", "content": response})

        messages.append({"role": "user", "content": user_query})
        return messages

    @staticmethod
    def build_inputs(session_meta: "SessionMeta", history: "HistoryType") -> str:
        """Render the session metadata and dialogue history as one prompt string.

        The prompt has one entry per line: an opening sentence naming both
        participants, optional info lines for the bot and the user (skipped
        when the corresponding value is missing or empty), one
        ``[name]content`` line per message, and a trailing ``[bot_name]`` line
        for the model to continue.

        Args:
            session_meta: Names and optional background info of both sides.
            history: Dialogue messages; the last one must be the user's query.

        Returns:
            The prompt text, with entries joined by newlines.

        Raises:
            AssertionError: If ``history`` is empty or its last message does
                not have role ``'user'``.
        """
        bot_name = session_meta['bot_name']
        user_name = session_meta['user_name']

        lines = [f"以下是一段{bot_name}和{user_name}之间的对话。"]
        if session_meta.get("bot_info"):
            lines.append(f"关于{bot_name}的信息:{session_meta['bot_info']}")
        if session_meta.get("user_info"):
            lines.append(f"关于{user_name}的信息:{session_meta['user_info']}")

        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; AssertionError is kept so existing callers see the
        # same exception type as before.
        if not history or history[-1]['role'] != 'user':
            raise AssertionError(
                "history must be non-empty and end with a user message")

        for msg in history:
            speaker = user_name if msg['role'] == 'user' else bot_name
            lines.append(f"[{speaker}]" + msg['content'].strip())

        # Every entry has to stay on a single line of the prompt.
        lines = [line.replace('\n', ' ') for line in lines]
        lines.append(f"[{bot_name}]")
        return '\n'.join(lines)
|
||
|
||
|
||
class CharacterGLMAPI:
    """Thin wrappers around the ``zhipuai`` SDK for the CharacterGLM model."""

    @staticmethod
    def build_api_arguments(session_meta: "SessionMeta", history: "HistoryType") -> dict:
        """Return the keyword arguments shared by the ``zhipuai.model_api`` calls."""
        return dict(model="characterglm", meta=session_meta, prompt=history)

    @classmethod
    def async_invoke(cls, session_meta: "SessionMeta", history: "HistoryType"):
        """Submit a request through ``zhipuai.model_api.async_invoke``.

        Notes:
            1. Set ``zhipuai.api_key`` before calling.
            2. ``return_type='text'`` is passed; per the original note, the
               result would otherwise be a JSON string.
            3. NOTE(review): the original docstring of ``invoke`` said that
               ``zhipuai.model_api.query_async_invoke_result`` must be called
               afterwards to fetch the generated result; that remark appears
               to belong to this asynchronous variant - confirm against the
               SDK documentation.

        Reference:
            https://open.bigmodel.cn/dev/api#characterglm
        """
        import zhipuai

        return zhipuai.model_api.async_invoke(
            **cls.build_api_arguments(session_meta, history), return_type='text')

    @classmethod
    def invoke(cls, session_meta: "SessionMeta", history: "HistoryType"):
        """Send a request through ``zhipuai.model_api.invoke`` and return its result.

        Notes:
            1. Set ``zhipuai.api_key`` before calling.
            2. ``return_type='text'`` is passed; per the original note, the
               result would otherwise be a JSON string.

        Reference:
            https://open.bigmodel.cn/dev/api#characterglm
        """
        import zhipuai

        return zhipuai.model_api.invoke(
            **cls.build_api_arguments(session_meta, history), return_type='text')

    @classmethod
    def generate(cls, session_meta: "SessionMeta", history: "HistoryType") -> str:
        """Call ``invoke`` and return the content of the first choice.

        Raises:
            RuntimeError: If the result's ``'success'`` entry is falsy; the
                raw result is attached as the exception argument.
        """
        response = cls.invoke(session_meta, history)
        if response['success']:
            first_choice = response['data']['choices'][0]
            return first_choice['content']
        raise RuntimeError(response)

    @classmethod
    def stream_generate(cls, session_meta: "SessionMeta", history: "HistoryType") -> Iterator[str]:
        """Pseudo-streaming: generate the full reply, then iterate over it.

        The whole reply is produced by ``generate`` first; the returned
        iterator walks over that string.
        """
        full_reply = cls.generate(session_meta, history)
        return iter(full_reply)
|