fix: tool call bug

liuzhenghua-jk 2024-06-12 08:55:31 +08:00
parent adeeb0e8e0
commit b9ffe763c5
1 changed file with 127 additions and 213 deletions


@@ -1,11 +1,13 @@
 import time
 from asyncio.log import logger
-import re
 import uvicorn
 import gc
 import json
-import torch
+import random
+import string
+import logging
+import torch

 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from fastapi import FastAPI, HTTPException, Response
 from fastapi.middleware.cors import CORSMiddleware
@@ -18,6 +20,7 @@ from sse_starlette.sse import EventSourceResponse
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+# max model length 128k
 MAX_MODEL_LENGTH = 8192
@@ -40,6 +43,11 @@ app.add_middleware(
 )


+def generate_id(prefix: str) -> str:
+    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=24))
+    return f"{prefix}-{suffix}"
+
+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
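For reference, the new `generate_id` helper can be exercised on its own; it produces OpenAI-style identifiers such as `chatcmpl-<24 alphanumerics>` (the prefixes `call` and `chatcmpl` are the ones used later in this diff):

```python
import random
import string

def generate_id(prefix: str) -> str:
    # 24 random alphanumeric characters after the prefix, OpenAI-style
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=24))
    return f"{prefix}-{suffix}"

print(generate_id("call"))      # e.g. call-kT3vA9...
print(generate_id("chatcmpl"))  # e.g. chatcmpl-9XaBqL...
```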
@@ -72,22 +80,23 @@ class UsageInfo(BaseModel):

 class ChatCompletionMessageToolCall(BaseModel):
-    id: str
+    id: Optional[str] = Field(default_factory=lambda: generate_id('call'))
     function: FunctionCall
-    type: Literal["function"]
+    type: Optional[Literal["function"]] = 'function'


 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system", "tool"]
+    role: Literal["user", "assistant", "system", "function", "tool"]
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[FunctionCall] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Literal["user", "assistant", "function", "system"]] = None
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[FunctionCall] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class ChatCompletionResponseChoice(BaseModel):
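With the relaxed defaults above, a tool-call object can now be built from just a `FunctionCall`; `id` and `type` fill themselves in. A minimal sketch, assuming pydantic v2 (the file already uses `model_dump_json`/`model_validate`) and with `FunctionCall` trimmed to the two fields used here:

```python
import random
import string
from typing import Literal, Optional
from pydantic import BaseModel, Field

def generate_id(prefix: str) -> str:
    return f"{prefix}-{''.join(random.choices(string.ascii_letters + string.digits, k=24))}"

# Trimmed stand-ins for the models defined in this file.
class FunctionCall(BaseModel):
    name: str
    arguments: str

class ChatCompletionMessageToolCall(BaseModel):
    id: Optional[str] = Field(default_factory=lambda: generate_id('call'))
    function: FunctionCall
    type: Optional[Literal["function"]] = 'function'

call = ChatCompletionMessageToolCall(
    function=FunctionCall(name="get_weather", arguments='{"city": "Beijing"}')
)
print(call.model_dump_json())
# {"id":"call-...","function":{"name":"get_weather","arguments":"{\"city\": \"Beijing\"}"},"type":"function"}
```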
@@ -104,12 +113,78 @@ class ChatCompletionResponseStreamChoice(BaseModel):

 class ChatCompletionResponse(BaseModel):
     model: str
-    id: str
+    id: str = Field(default_factory=lambda: generate_id('chatcmpl'))
     object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     usage: Optional[UsageInfo] = None

+    @staticmethod
+    def _convert_to_tool_calls_from_content(content: str) -> Union[List[ChatCompletionMessageToolCall], str]:
+        tool_calls = []
+        content = content.strip()
+        for response in content.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if metadata.strip():
+                parameters = eval(content.strip())
+                function_call = FunctionCall(
+                    name=metadata.strip(),
+                    arguments=json.dumps(parameters, ensure_ascii=False)
+                )
+                tool_calls.append(ChatCompletionMessageToolCall(function=function_call))
+        return tool_calls if len(tool_calls) > 0 else content
+
+    @staticmethod
+    def stream_reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False) -> str:
+        if content.startswith("\n"):
+            content = content[1:]
+        tool_calls = None
+        if use_tool:
+            parsed_tool_calls = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
+            if isinstance(parsed_tool_calls, list):
+                tool_calls = parsed_tool_calls
+                finish_reason = "tool_calls"
+                content = None
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(role="assistant", content=content, tool_calls=tool_calls),
+            finish_reason=finish_reason
+        )
+        return ChatCompletionResponse(
+            model=model_id,
+            choices=[choice_data],
+            created=int(time.time()),
+            object="chat.completion.chunk"
+        ).model_dump_json(exclude_none=True)
+
+    @staticmethod
+    def reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False, usage: UsageInfo = None) \
+            -> 'ChatCompletionResponse':
+        if content.startswith("\n"):
+            content = content[1:]
+        tool_calls = None
+        if use_tool:
+            parsed_tool_calls = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
+            if isinstance(parsed_tool_calls, list):
+                tool_calls = parsed_tool_calls
+                finish_reason = "tool_calls"
+                content = None
+        choice_data = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content=content, tool_calls=tool_calls),
+            finish_reason=finish_reason
+        )
+        return ChatCompletionResponse(
+            model=model_id,
+            choices=[choice_data],
+            created=int(time.time()),
+            object="chat.completion",
+            usage=usage
+        )
+

 class ChatCompletionRequest(BaseModel):
     model: str
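`_convert_to_tool_calls_from_content` relies on GLM-4's tool-call convention: each call is emitted as a function name on the first line followed by a Python-dict-like argument literal, with multiple calls separated by `<|assistant|>`. A standalone sketch of the same parsing idea (not the class method itself; it returns plain `(name, json_args)` tuples and swaps the server's `eval` for `ast.literal_eval`):

```python
import ast
import json
from typing import List, Tuple, Union

def parse_glm4_tool_calls(content: str) -> Union[List[Tuple[str, str]], str]:
    """Illustrative stand-in for _convert_to_tool_calls_from_content.

    Returns a list of (function_name, json_arguments) pairs when the text
    looks like tool calls, otherwise the original text.
    """
    tool_calls = []
    content = content.strip()
    for segment in content.split("<|assistant|>"):
        if "\n" in segment:
            metadata, body = segment.split("\n", maxsplit=1)
        else:
            metadata, body = "", segment
        if metadata.strip():
            # the server calls eval() on the argument text; literal_eval is the safer equivalent
            parameters = ast.literal_eval(body.strip())
            tool_calls.append((metadata.strip(), json.dumps(parameters, ensure_ascii=False)))
    return tool_calls if tool_calls else content

print(parse_glm4_tool_calls('get_weather\n{"city": "Beijing"}'))  # [('get_weather', '{"city": "Beijing"}')]
print(parse_glm4_tool_calls("Hello there"))                       # Hello there
```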
@@ -119,7 +194,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
@@ -133,48 +208,6 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores


-def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
-    lines = output.strip().split("\n")
-    arguments_json = None
-    special_tools = ["cogview", "simple_browser"]
-    tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
-
-    if len(lines) >= 2 and tool_call_pattern.match(lines[0]):
-        function_name = lines[0].strip()
-        arguments = "\n".join(lines[1:]).strip()
-        try:
-            arguments_json = json.loads(arguments)
-            is_tool_call = True
-        except json.JSONDecodeError:
-            is_tool_call = function_name in special_tools
-
-        if is_tool_call and use_tool:
-            content = {
-                "name": function_name,
-                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
-            }
-            if function_name == "simple_browser":
-                search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
-                match = search_pattern.match(arguments)
-                if match:
-                    content["arguments"] = json.dumps({
-                        "query": match.group(1),
-                        "recency_days": int(match.group(2))
-                    }, ensure_ascii=False)
-            elif function_name == "cogview":
-                content["arguments"] = json.dumps({
-                    "prompt": arguments
-                }, ensure_ascii=False)
-
-            return content
-    return output.strip()
-
-
 @torch.inference_mode()
 async def generate_stream_glm4(params):
     messages = params["messages"]
@@ -184,7 +217,6 @@ async def generate_stream_glm4(params):
     repetition_penalty = float(params.get("repetition_penalty", 1.0))
     top_p = float(params.get("top_p", 1.0))
     max_new_tokens = int(params.get("max_tokens", 8192))
-
     messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
     params_dict = {
@@ -224,7 +256,7 @@ async def generate_stream_glm4(params):
         torch.cuda.empty_cache()


-def process_messages(messages, tools=None, tool_choice="none"):
+def process_messages(messages, tools=None, tool_choice=None):
     _messages = messages
     processed_messages = []
     msg_has_sys = False
@@ -239,7 +271,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
         ]
         return filtered_tools

-    if tool_choice != "none":
+    if tool_choice and tool_choice != "none":
         if isinstance(tool_choice, dict):
             tools = filter_tools(tool_choice, tools)
         if tools:
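When `tool_choice` is a dict in the OpenAI style (`{"type": "function", "function": {"name": ...}}`), `filter_tools` narrows the tool list before it is injected into the system prompt. The body of `filter_tools` is only partially visible in this hunk, so the matching rule below is an assumption for illustration only:

```python
from typing import List

def filter_tools(tool_choice: dict, tools: List[dict]) -> List[dict]:
    # Assumed behaviour: keep only the tool whose function name matches tool_choice.
    wanted = tool_choice.get("function", {}).get("name")
    return [t for t in tools if t.get("function", {}).get("name") == wanted]

tools = [
    {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
    {"type": "function", "function": {"name": "get_news", "parameters": {}}},
]
tool_choice = {"type": "function", "function": {"name": "get_weather"}}

if tool_choice and tool_choice != "none":
    if isinstance(tool_choice, dict):
        tools = filter_tools(tool_choice, tools)

print([t["function"]["name"] for t in tools])  # ['get_weather']
```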
@@ -317,7 +349,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
     return processed_messages

-
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -334,8 +365,8 @@ async def list_models():
 async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")
+    if request.tool_choice is None:
+        request.tool_choice = "auto" if request.tools else "none"
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
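The two added lines normalize an unspecified `tool_choice` before it reaches `process_messages`: requests that carry tools default to `"auto"`, everything else to `"none"`. A small sketch of that rule in isolation (the helper name is made up for illustration):

```python
from typing import List, Optional, Union

def normalize_tool_choice(tool_choice: Optional[Union[str, dict]],
                          tools: Optional[Union[dict, List[dict]]]) -> Union[str, dict]:
    # Mirrors the two added lines: an unspecified tool_choice becomes "auto"
    # when tools are supplied, otherwise "none".
    if tool_choice is None:
        return "auto" if tools else "none"
    return tool_choice

assert normalize_tool_choice(None, [{"type": "function"}]) == "auto"
assert normalize_tool_choice(None, None) == "none"
assert normalize_tool_choice("none", [{"type": "function"}]) == "none"
```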
@@ -347,187 +378,70 @@ async def create_chat_completion(request: ChatCompletionRequest):
         tools=request.tools,
         tool_choice=request.tool_choice,
     )
-    logger.debug(f"==== request ====\n{gen_params}")
+    logger.debug(f"==== request ====\n{request.model_dump_json()}")

     if request.stream:
         predict_stream_generator = predict_stream(request.model, gen_params)
-        output = await anext(predict_stream_generator)
-        if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-        logger.debug(f"First result output\n{output}")
-        function_call = None
-        if output and request.tools:
-            try:
-                function_call = process_response(output, use_tool=True)
-            except:
-                logger.warning("Failed to parse tool call")
-        if isinstance(function_call, dict):
-            function_call = FunctionCallResponse(**function_call)
-            generate = parse_output_text(request.model, output, function_call=function_call)
-            return EventSourceResponse(generate, media_type="text/event-stream")
-        else:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
+        return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")

     response = ""
     async for response in generate_stream_glm4(gen_params):
         pass
+    is_tool_call = is_return_tool_call(response["text"], request.tools)
-    if response["text"].startswith("\n"):
-        response["text"] = response["text"][1:]
-    response["text"] = response["text"].strip()

     usage = UsageInfo()
-    function_call, finish_reason = None, "stop"
-    tool_calls = None
-    if request.tools:
-        try:
-            function_call = process_response(response["text"], use_tool=True)
-        except Exception as e:
-            logger.warning(f"Failed to parse tool call: {e}")
-    if isinstance(function_call, dict):
-        finish_reason = "tool_calls"
-        function_call_response = FunctionCallResponse(**function_call)
-        function_call_instance = FunctionCall(
-            name=function_call_response.name,
-            arguments=function_call_response.arguments
-        )
-        tool_calls = [
-            ChatCompletionMessageToolCall(
-                id=f"call_{int(time.time() * 1000)}",
-                function=function_call_instance,
-                type="function")]
-    message = ChatMessage(
-        role="assistant",
-        content=None if tool_calls else response["text"],
-        function_call=None,
-        tool_calls=tool_calls,
-    )
-    logger.debug(f"==== message ====\n{message}")
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=message,
-        finish_reason=finish_reason,
-    )
     task_usage = UsageInfo.model_validate(response["usage"])
     for usage_key, usage_value in task_usage.model_dump().items():
         setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-    return ChatCompletionResponse(
-        model=request.model,
-        id="",
-        choices=[choice_data],
-        object="chat.completion",
-        usage=usage
-    )
+    return ChatCompletionResponse.reply(request.model, response["text"], response["finish_reason"], is_tool_call, usage)
+
+
+def calc_max_tool_name_len(tools: Optional[List[dict]]) -> int:
+    max_tool_name_len = 0
+    if not tools:
+        return max_tool_name_len
+    tool_names = [tool['function']['name'] for tool in tools if 'function' in tool and 'name' in tool['function']]
+    max_tool_name_len = max(len(tool_name) for tool_name in tool_names)
+    return max_tool_name_len
+
+
+def is_return_tool_call(output: str, tools: Optional[List[dict]]) -> bool:
+    if not tools:
+        return False
+    output = output.strip()
+    tool_names = [tool['function']['name'] for tool in tools if 'function' in tool and 'name' in tool['function']]
+    return any(output.startswith(name) for name in tool_names)


 async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
-    function_name = None
+    tools = gen_params.get("tools")
+    max_tool_name_len = calc_max_tool_name_len(tools)
+    finish_reason = "stop"
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
-        lines = output.strip().split("\n")
-        if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]):
-            is_function_call = True
-            function_name = lines[0].strip()
+        # read an extra char because the first generate char may be \n
+        if len(output) <= max_tool_name_len:
+            continue
+        if not is_function_call:
+            is_function_call = is_return_tool_call(output, tools)
         if is_function_call:
-            for char in delta_text:
-                function_call = {"name": function_name, "arguments": char}
-                message = DeltaMessage(
-                    content=None,
-                    role="assistant",
-                    function_call=function_call
-                )
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=message,
-                    finish_reason=None
-                )
-                chunk = ChatCompletionResponse(
-                    model=model_id,
-                    id="",
-                    choices=[choice_data],
-                    created=int(time.time()),
-                    object="chat.completion.chunk"
-                )
-                yield chunk.model_dump_json(exclude_unset=True)
+            continue
         else:
-            if len(output) > 7:
-                finish_reason = new_response.get("finish_reason", None)
-                if not has_send_first_chunk:
-                    message = DeltaMessage(
-                        content="",
-                        role="assistant",
-                        function_call=None,
-                    )
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
-                        delta=message,
-                        finish_reason=finish_reason
-                    )
-                    chunk = ChatCompletionResponse(
-                        model=model_id,
-                        id="",
-                        choices=[choice_data],
-                        created=int(time.time()),
-                        object="chat.completion.chunk"
-                    )
-                    yield chunk.model_dump_json(exclude_unset=True)
-                send_msg = delta_text if has_send_first_chunk else output
-                has_send_first_chunk = True
-                message = DeltaMessage(
-                    content=send_msg,
-                    role="assistant",
-                    function_call=None,
-                )
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=message,
-                    finish_reason=finish_reason
-                )
-                chunk = ChatCompletionResponse(
-                    model=model_id,
-                    id="",
-                    choices=[choice_data],
-                    created=int(time.time()),
-                    object="chat.completion.chunk"
-                )
-                yield chunk.model_dump_json(exclude_unset=True)
-    if is_function_call:
-        yield json.dumps({"text": output})
-    else:
-        yield '[DONE]'
-
-
-async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
-    delta = DeltaMessage(role="assistant", content=value)
-    if function_call is not None:
-        delta.function_call = function_call
-    choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=delta,
-        finish_reason=None
-    )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
-    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+            finish_reason = new_response["finish_reason"]
+            send_msg = delta_text if has_send_first_chunk else output[1:] if output.startswith("\n") else output
+            has_send_first_chunk = True
+            yield ChatCompletionResponse.stream_reply(model_id, send_msg, finish_reason)
+    # if the total output length less than the max tool name length, has_send_first_chunk = False
+    if is_function_call or not has_send_first_chunk:
+        yield ChatCompletionResponse.stream_reply(model_id, output, finish_reason, is_function_call)
     yield '[DONE]'


 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
     engine_args = AsyncEngineArgs(
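The rewritten `predict_stream` buffers output until it is longer than the longest tool name, then uses `is_return_tool_call` to decide whether to keep streaming text deltas or to stay silent and emit a single `tool_calls` chunk at the end. A self-contained sketch of that detection loop; `fake_model` stands in for `generate_stream_glm4`, and the tool schema and values are invented for illustration:

```python
import asyncio
from typing import List, Optional

def calc_max_tool_name_len(tools: Optional[List[dict]]) -> int:
    if not tools:
        return 0
    names = [t['function']['name'] for t in tools if 'function' in t and 'name' in t['function']]
    return max(len(n) for n in names)

def is_return_tool_call(output: str, tools: Optional[List[dict]]) -> bool:
    if not tools:
        return False
    output = output.strip()
    names = [t['function']['name'] for t in tools if 'function' in t and 'name' in t['function']]
    return any(output.startswith(n) for n in names)

async def fake_model(chunks):
    # stands in for generate_stream_glm4: yields the growing generated text
    text = ""
    for piece in chunks:
        text += piece
        yield {"text": text, "finish_reason": "stop"}

async def demo():
    tools = [{"type": "function", "function": {"name": "get_weather"}}]
    max_len = calc_max_tool_name_len(tools)  # 11 for "get_weather"
    output, is_function_call = "", False
    async for resp in fake_model(['\nget_', 'weather\n{"city"', ': "Beijing"}']):
        output = resp["text"]
        if len(output) <= max_len:
            continue  # too short to tell a tool name from ordinary text yet
        if not is_function_call:
            is_function_call = is_return_tool_call(output, tools)
        # ordinary text would be streamed here; a tool call suppresses deltas
    print(is_function_call)  # True -> one final chunk with tool_calls is emitted

asyncio.run(demo())
```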