Merge pull request #374 from ghohoj/fix_flow_output

fix: fix flow output when having tools
This commit is contained in:
zR 2024-07-22 15:27:31 +08:00 committed by GitHub
commit 7f96ee326e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 50 additions and 3 deletions

View File

@@ -446,9 +446,10 @@ async def predict_stream(model_id, gen_params):
response_id = generate_id('chatcmpl-', 29)
system_fingerprint = generate_id('fp_', 9)
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
delta_text = ""
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
delta_text += decoded_unicode[len(output):]
output = decoded_unicode
lines = output.strip().split("\n")
@@ -461,6 +462,7 @@ async def predict_stream(model_id, gen_params):
if first_line in tools:
is_function_call = True
function_name = first_line
delta_text = lines[1]
# 工具调用返回
if is_function_call:
@@ -496,6 +498,7 @@ async def predict_stream(model_id, gen_params):
has_send_first_chunk = True
function_call = {"name": None, "arguments": delta_text}
delta_text = ""
tool_call = ChatCompletionMessageToolCall(
index=0,
id=None,
@@ -557,6 +560,7 @@ async def predict_stream(model_id, gen_params):
role="assistant",
function_call=None,
)
delta_text = ""
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
@@ -592,8 +596,51 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk",
usage=None
).model_dump_json(exclude_unset=True)
yield '[DONE]'
elif delta_text != "":
message = DeltaMessage(
content="",
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
finish_reason = 'stop'
message = DeltaMessage(
content=delta_text,
role="assistant",
function_call=None,
)
delta_text = ""
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
yield '[DONE]'
else:
yield '[DONE]'
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
delta = DeltaMessage(role="assistant", content=value)