From b9ffe763c56c17be9c8c054cb67f8b926d994422 Mon Sep 17 00:00:00 2001
From: liuzhenghua-jk
Date: Wed, 12 Jun 2024 08:55:31 +0800
Subject: [PATCH] fix: tool call bug

---
 basic_demo/openai_api_server.py | 340 ++++++++++++--------------------
 1 file changed, 127 insertions(+), 213 deletions(-)

diff --git a/basic_demo/openai_api_server.py b/basic_demo/openai_api_server.py
index e758621..37a5b19 100644
--- a/basic_demo/openai_api_server.py
+++ b/basic_demo/openai_api_server.py
@@ -1,11 +1,13 @@
 import time
 from asyncio.log import logger
-import re
 import uvicorn
 import gc
 import json
-import torch
+import random
+import string
+import logging
+import torch
 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from fastapi import FastAPI, HTTPException, Response
 from fastapi.middleware.cors import CORSMiddleware
@@ -18,6 +20,7 @@ from sse_starlette.sse import EventSourceResponse
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+# max model length 128k
 MAX_MODEL_LENGTH = 8192
 
@@ -40,6 +43,11 @@ app.add_middleware(
 )
 
 
+def generate_id(prefix: str) -> str:
+    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=24))
+    return f"{prefix}-{suffix}"
+
+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
@@ -72,22 +80,23 @@ class UsageInfo(BaseModel):
 
 
 class ChatCompletionMessageToolCall(BaseModel):
-    id: str
+    id: Optional[str] = Field(default_factory=lambda: generate_id('call'))
     function: FunctionCall
-    type: Literal["function"]
+    type: Optional[Literal["function"]] = 'function'
 
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system", "tool"]
+    role: Literal["user", "assistant", "system", "function", "tool"]
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[FunctionCall] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Literal["user", "assistant", "function", "system"]] = None
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[FunctionCall] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -104,12 +113,78 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     model: str
-    id: str
+    id: str = Field(default_factory=lambda: generate_id('chatcmpl'))
     object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     usage: Optional[UsageInfo] = None
 
+    @staticmethod
+    def _convert_to_tool_calls_from_content(content: str) -> Union[List[ChatCompletionMessageToolCall], str]:
+        tool_calls = []
+        content = content.strip()
+        for response in content.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if metadata.strip():
+                parameters = eval(content.strip())
+                function_call = FunctionCall(
+                    name=metadata.strip(),
+                    arguments=json.dumps(parameters, ensure_ascii=False)
+                )
+                tool_calls.append(ChatCompletionMessageToolCall(function=function_call))
+        return tool_calls if len(tool_calls) > 0 else content
+
+    @staticmethod
+    def stream_reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False) -> str:
+        if content.startswith("\n"):
+            content = content[1:]
+        tool_calls = None
+        if use_tool:
+            parsed_tool_calls = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
+            if isinstance(parsed_tool_calls, list):
+                tool_calls = parsed_tool_calls
+                finish_reason = "tool_calls"
+                content = None
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(role="assistant", content=content, tool_calls=tool_calls),
+            finish_reason=finish_reason
+        )
+        return ChatCompletionResponse(
+            model=model_id,
+            choices=[choice_data],
+            created=int(time.time()),
+            object="chat.completion.chunk"
+        ).model_dump_json(exclude_none=True)
+
+    @staticmethod
+    def reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False,
+              usage: Optional[UsageInfo] = None) -> 'ChatCompletionResponse':
+        if content.startswith("\n"):
+            content = content[1:]
+        tool_calls = None
+        if use_tool:
+            parsed_tool_calls = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
+            if isinstance(parsed_tool_calls, list):
+                tool_calls = parsed_tool_calls
+                finish_reason = "tool_calls"
+                content = None
+        choice_data = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content=content, tool_calls=tool_calls),
+            finish_reason=finish_reason
+        )
+        return ChatCompletionResponse(
+            model=model_id,
+            choices=[choice_data],
+            created=int(time.time()),
+            object="chat.completion",
+            usage=usage
+        )
+
 
 class ChatCompletionRequest(BaseModel):
     model: str
@@ -119,7 +194,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
 
 
@@ -133,48 +208,6 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores
 
 
-def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
-    lines = output.strip().split("\n")
-    arguments_json = None
-    special_tools = ["cogview", "simple_browser"]
-
-    tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
-
-    if len(lines) >= 2 and tool_call_pattern.match(lines[0]):
-        function_name = lines[0].strip()
-        arguments = "\n".join(lines[1:]).strip()
-
-        try:
-            arguments_json = json.loads(arguments)
-            is_tool_call = True
-        except json.JSONDecodeError:
-            is_tool_call = function_name in special_tools
-
-        if is_tool_call and use_tool:
-            content = {
-                "name": function_name,
-                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
-            }
-            if function_name == "simple_browser":
-                search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
-                match = search_pattern.match(arguments)
-                if match:
-                    content["arguments"] = json.dumps({
-                        "query": match.group(1),
-                        "recency_days": int(match.group(2))
-                    }, ensure_ascii=False)
-            elif function_name == "cogview":
-                content["arguments"] = json.dumps({
-                    "prompt": arguments
-                }, ensure_ascii=False)
-
-            return content
-    return output.strip()
-
-
-
-
 @torch.inference_mode()
 async def generate_stream_glm4(params):
     messages = params["messages"]
@@ -184,7 +217,6 @@ async def generate_stream_glm4(params):
     repetition_penalty = float(params.get("repetition_penalty", 1.0))
     top_p = float(params.get("top_p", 1.0))
     max_new_tokens = int(params.get("max_tokens", 8192))
-
     messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
     params_dict = {
@@ -224,7 +256,7 @@ async def generate_stream_glm4(params):
     torch.cuda.empty_cache()
 
 
-def process_messages(messages, tools=None, tool_choice="none"):
+def process_messages(messages, tools=None, tool_choice=None):
     _messages = messages
     processed_messages = []
     msg_has_sys = False
@@ -239,7 +271,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
         ]
         return filtered_tools
 
-    if tool_choice != "none":
+    if tool_choice and tool_choice != "none":
         if isinstance(tool_choice, dict):
             tools = filter_tools(tool_choice, tools)
         if tools:
@@ -317,7 +349,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
 
     return processed_messages
 
-
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -334,8 +365,8 @@ async def list_models():
 async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")
-
-
+    if request.tool_choice is None:
+        request.tool_choice = "auto" if request.tools else "none"
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
@@ -347,187 +378,70 @@ async def create_chat_completion(request: ChatCompletionRequest):
         tools=request.tools,
         tool_choice=request.tool_choice,
     )
-    logger.debug(f"==== request ====\n{gen_params}")
+    logger.debug(f"==== request ====\n{request.model_dump_json()}")
 
     if request.stream:
         predict_stream_generator = predict_stream(request.model, gen_params)
-        output = await anext(predict_stream_generator)
-        if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-        logger.debug(f"First result output:\n{output}")
-
-        function_call = None
-        if output and request.tools:
-            try:
-                function_call = process_response(output, use_tool=True)
-            except:
-                logger.warning("Failed to parse tool call")
-
-        if isinstance(function_call, dict):
-            function_call = FunctionCallResponse(**function_call)
-            generate = parse_output_text(request.model, output, function_call=function_call)
-            return EventSourceResponse(generate, media_type="text/event-stream")
-        else:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
+        return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")
 
     response = ""
     async for response in generate_stream_glm4(gen_params):
         pass
-
-    if response["text"].startswith("\n"):
-        response["text"] = response["text"][1:]
-    response["text"] = response["text"].strip()
-
+    is_tool_call = is_return_tool_call(response["text"], request.tools)
     usage = UsageInfo()
-
-    function_call, finish_reason = None, "stop"
-    tool_calls = None
-    if request.tools:
-        try:
-            function_call = process_response(response["text"], use_tool=True)
-        except Exception as e:
-            logger.warning(f"Failed to parse tool call: {e}")
-
-    if isinstance(function_call, dict):
-        finish_reason = "tool_calls"
-        function_call_response = FunctionCallResponse(**function_call)
-        function_call_instance = FunctionCall(
-            name=function_call_response.name,
-            arguments=function_call_response.arguments
-        )
-        tool_calls = [
-            ChatCompletionMessageToolCall(
-                id=f"call_{int(time.time() * 1000)}",
-                function=function_call_instance,
-                type="function")]
-
-    message = ChatMessage(
-        role="assistant",
-        content=None if tool_calls else response["text"],
-        function_call=None,
-        tool_calls=tool_calls,
-    )
-
-    logger.debug(f"==== message ====\n{message}")
-
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=message,
-        finish_reason=finish_reason,
-    )
     task_usage = UsageInfo.model_validate(response["usage"])
     for usage_key, usage_value in task_usage.model_dump().items():
         setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+    return ChatCompletionResponse.reply(request.model, response["text"], response["finish_reason"], is_tool_call, usage)
 
-    return ChatCompletionResponse(
-        model=request.model,
-        id="",
-        choices=[choice_data],
-        object="chat.completion",
-        usage=usage
-    )
+
+def calc_max_tool_name_len(tools: Optional[List[dict]]) -> int:
+    max_tool_name_len = 0
+    if not tools:
+        return max_tool_name_len
+    tool_names = [tool['function']['name'] for tool in tools if 'function' in tool and 'name' in tool['function']]
+    max_tool_name_len = max((len(tool_name) for tool_name in tool_names), default=0)
+    return max_tool_name_len
+
+
+def is_return_tool_call(output: str, tools: Optional[List[dict]]) -> bool:
+    if not tools:
+        return False
+    output = output.strip()
+    tool_names = [tool['function']['name'] for tool in tools if 'function' in tool and 'name' in tool['function']]
+    return any(output.startswith(name) for name in tool_names)
 
 
 async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
-    function_name = None
+    tools = gen_params.get("tools")
+    max_tool_name_len = calc_max_tool_name_len(tools)
+    finish_reason = "stop"
+
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
        output = decoded_unicode
-        lines = output.strip().split("\n")
-
-        if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]):
-            is_function_call = True
-            function_name = lines[0].strip()
-
+        # read one extra char past the longest tool name, because the first generated char may be "\n"
+        if len(output) <= max_tool_name_len:
+            continue
+        if not is_function_call:
+            is_function_call = is_return_tool_call(output, tools)
         if is_function_call:
-            for char in delta_text:
-                function_call = {"name": function_name, "arguments": char}
-                message = DeltaMessage(
-                    content=None,
-                    role="assistant",
-                    function_call=function_call
-                )
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=message,
-                    finish_reason=None
-                )
-                chunk = ChatCompletionResponse(
-                    model=model_id,
-                    id="",
-                    choices=[choice_data],
-                    created=int(time.time()),
-                    object="chat.completion.chunk"
-                )
-                yield chunk.model_dump_json(exclude_unset=True)
+            continue
         else:
-            if len(output) > 7:
-                finish_reason = new_response.get("finish_reason", None)
-                if not has_send_first_chunk:
-                    message = DeltaMessage(
-                        content="",
-                        role="assistant",
-                        function_call=None,
-                    )
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
-                        delta=message,
-                        finish_reason=finish_reason
-                    )
-                    chunk = ChatCompletionResponse(
-                        model=model_id,
-                        id="",
-                        choices=[choice_data],
-                        created=int(time.time()),
-                        object="chat.completion.chunk"
-                    )
-                    yield chunk.model_dump_json(exclude_unset=True)
-
-                send_msg = delta_text if has_send_first_chunk else output
-                has_send_first_chunk = True
-                message = DeltaMessage(
-                    content=send_msg,
-                    role="assistant",
-                    function_call=None,
-                )
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=message,
-                    finish_reason=finish_reason
-                )
-                chunk = ChatCompletionResponse(
-                    model=model_id,
-                    id="",
-                    choices=[choice_data],
-                    created=int(time.time()),
-                    object="chat.completion.chunk"
-                )
-                yield chunk.model_dump_json(exclude_unset=True)
-
-    if is_function_call:
-        yield json.dumps({"text": output})
-    else:
-        yield '[DONE]'
-
-
-async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
-    delta = DeltaMessage(role="assistant", content=value)
-    if function_call is not None:
-        delta.function_call = function_call
-
-    choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=delta,
-        finish_reason=None
-    )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
-    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+            finish_reason = new_response["finish_reason"]
+            send_msg = delta_text if has_send_first_chunk else output[1:] if output.startswith("\n") else output
+            has_send_first_chunk = True
+            yield ChatCompletionResponse.stream_reply(model_id, send_msg, finish_reason)
+    # if the output never grew past the longest tool name, no chunk has been sent yet (has_send_first_chunk is False)
+    if is_function_call or not has_send_first_chunk:
+        yield ChatCompletionResponse.stream_reply(model_id, output, finish_reason, is_function_call)
     yield '[DONE]'
 
+
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    engine_args = AsyncEngineArgs(
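
Note (not part of the patch): a quick way to exercise the fixed tool-call path end to end. This is a minimal sketch against the standard `openai` v1 Python client; the base URL/port, the served model name ("glm-4"), and the `get_weather` tool are illustrative assumptions, adjust them to your deployment.

from openai import OpenAI

# Point the stock OpenAI client at the local GLM-4 server (URL/port assumed).
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

# A hypothetical tool definition, purely for illustration.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="glm-4",  # assumed model id; check what /v1/models reports
    messages=[{"role": "user", "content": "What is the weather in Beijing?"}],
    tools=tools,
    # tool_choice is left unset: with this patch the server defaults it to
    # "auto" when tools are present, instead of the old string default "None".
)

choice = response.choices[0]
if choice.finish_reason == "tool_calls":
    # With the fix, tool calls come back as OpenAI-style message.tool_calls
    # with generated "call-..." ids rather than as raw text content.
    for call in choice.message.tool_calls:
        print(call.id, call.function.name, call.function.arguments)
else:
    print(choice.message.content)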