Fix issue THUDM#618: streaming output does not work when tools are used (使用tools时无法stream流式输出的问题)
This commit is contained in:
parent
6bf9f85f70
commit
9a84c470b9
|
@ -19,7 +19,7 @@ from sse_starlette.sse import EventSourceResponse
|
|||
|
||||
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
|
||||
|
||||
MAX_MODEL_LENGTH = 8192
|
||||
MAX_MODEL_LENGTH = 8192
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params):
|
|||
system_fingerprint = generate_id('fp_', 9)
|
||||
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
|
||||
delta_text = ""
|
||||
delta_confirming_texts = []
|
||||
confirm_tool_state = 'un_confirm' if tools else 'none'
|
||||
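# When tools are supplied, the maximum number of characters needed to decide whether a tool is being called = length of the longest tool name + 3 (a possible leading "\n" plus the trailing "\n{").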
|
||||
max_confirm_tool_length = len(max(tools, len)) + 3 if tools else 0
|
||||
async for new_response in generate_stream_glm4(gen_params):
|
||||
decoded_unicode = new_response["text"]
|
||||
delta_text += decoded_unicode[len(output):]
|
||||
if confirm_tool_state == 'un_confirm':
|
||||
delta_confirming_texts.append(decoded_unicode[len(output):])
|
||||
|
||||
output = decoded_unicode
|
||||
lines = output.strip().split("\n")
|
||||
|
||||
# 检查是否为工具
|
||||
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
|
||||
# TODO: If you need more thorough handling, refine the logic here.
|
||||
|
||||
if not is_function_call and len(lines) >= 2:
|
||||
if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
|
||||
first_line = lines[0].strip()
|
||||
if first_line in tools:
|
||||
is_function_call = True
|
||||
function_name = first_line
|
||||
delta_text = lines[1]
|
||||
confirm_tool_state == 'confirmed'
|
||||
else:
|
||||
confirm_tool_state == 'none'
|
||||
|
||||
# 当传入tools时,经过大模型输出几轮后,已经可以确认不需要调用工具了
|
||||
if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text):
|
||||
confirm_tool_state == 'none'
|
||||
# 工具调用返回
|
||||
if is_function_call:
|
||||
if not has_send_first_chunk:
|
||||
|
@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params):
|
|||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
# 用户请求了 Function Call 但是框架还没确定是否为Function Call
|
||||
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
|
||||
elif confirm_tool_state == 'un_confirm':
|
||||
continue
|
||||
|
||||
# 常规返回
|
||||
|
@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params):
|
|||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
has_send_first_chunk = True
|
||||
|
||||
for text in delta_confirming_texts:
|
||||
message = DeltaMessage(
|
||||
content=text,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
delta_confirming_texts = []
|
||||
delta_text = ""
|
||||
|
||||
message = DeltaMessage(
|
||||
content=delta_text,
|
||||
role="assistant",
|
||||
|
@ -613,7 +648,7 @@ async def predict_stream(model_id, gen_params):
|
|||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
|
||||
finish_reason = 'stop'
|
||||
message = DeltaMessage(
|
||||
content=delta_text,
|
||||
|
|
Loading…
Reference in New Issue