Fix issue THUDM#618: 使用 tools 时无法 stream 流式输出的问题 (streaming output did not work when tools were supplied)

This commit is contained in:
chenglj 2024-10-30 16:22:50 +08:00
parent 6bf9f85f70
commit 9a84c470b9
1 changed files with 40 additions and 5 deletions

View File

@ -19,7 +19,7 @@ from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MAX_MODEL_LENGTH = 8192
MAX_MODEL_LENGTH = 8192
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params):
system_fingerprint = generate_id('fp_', 9)
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
delta_text = ""
delta_confirming_texts = []
confirm_tool_state = 'un_confirm' if tools else 'none'
# When tools are supplied, the longest prefix needed to decide whether a tool is being
# called is the longest tool name plus 3 characters: a possible leading "\n" and the
# trailing "\n{". `key=len` picks the longest name; passing `len` positionally would
# compare the set against the builtin and raise TypeError.
max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text += decoded_unicode[len(output):]
if confirm_tool_state == 'un_confirm':
delta_confirming_texts.append(decoded_unicode[len(output):])
output = decoded_unicode
lines = output.strip().split("\n")
# 检查是否为工具
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。
if not is_function_call and len(lines) >= 2:
if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
first_line = lines[0].strip()
if first_line in tools:
is_function_call = True
function_name = first_line
delta_text = lines[1]
# Assign (not compare) so the state actually records that a tool call was confirmed.
confirm_tool_state = 'confirmed'
else:
# Assign (not compare): the first line is not a known tool, so stop waiting for a tool call.
confirm_tool_state = 'none'
# With tools supplied, once the model has produced more text than any tool-call prefix
# could occupy, it is certain that no tool is being called. The state must be assigned
# here (a bare `==` comparison is a no-op and leaves the stream stuck in 'un_confirm').
if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text):
    confirm_tool_state = 'none'
# 工具调用返回
if is_function_call:
if not has_send_first_chunk:
@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params):
yield chunk.model_dump_json(exclude_unset=True)
# 用户请求了 Function Call 但是框架还没确定是否为Function Call
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
elif confirm_tool_state == 'un_confirm':
continue
# 常规返回
@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params):
yield chunk.model_dump_json(exclude_unset=True)
has_send_first_chunk = True
for text in delta_confirming_texts:
message = DeltaMessage(
content=text,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
delta_confirming_texts = []
delta_text = ""
message = DeltaMessage(
content=delta_text,
role="assistant",
@ -613,7 +648,7 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
finish_reason = 'stop'
message = DeltaMessage(
content=delta_text,