fix: tool call bug

liuzhenghua-jk 2024-06-12 08:55:31 +08:00
parent adeeb0e8e0
commit b9ffe763c5
1 changed file with 127 additions and 213 deletions


@@ -1,11 +1,13 @@
import time
from asyncio.log import logger
import re
import uvicorn
import gc
import json
import torch
import random
import string
import logging
import torch
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
@@ -18,6 +20,7 @@ from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b-chat'
# max model length 128k
MAX_MODEL_LENGTH = 8192
@@ -40,6 +43,11 @@ app.add_middleware(
)
def generate_id(prefix: str) -> str:
suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=24))
return f"{prefix}-{suffix}"
class ModelCard(BaseModel):
id: str
object: str = "model"
@@ -72,22 +80,23 @@ class UsageInfo(BaseModel):
class ChatCompletionMessageToolCall(BaseModel):
id: str
id: Optional[str] = Field(default_factory=lambda: generate_id('call'))
function: FunctionCall
type: Literal["function"]
type: Optional[Literal["function"]] = 'function'
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "tool"]
role: Literal["user", "assistant", "system", "function", "tool"]
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
function_call: Optional[FunctionCall] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
role: Optional[Literal["user", "assistant", "function", "system"]] = None
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
function_call: Optional[FunctionCall] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class ChatCompletionResponseChoice(BaseModel):
@@ -104,12 +113,78 @@ class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
model: str
id: str
id: str = Field(default_factory=lambda: generate_id('chatcmpl'))
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
@staticmethod
def _convert_to_tool_calls_from_content(content: str) -> Union[List[ChatCompletionMessageToolCall], str]:
tool_calls = []
content = content.strip()
for response in content.split("<|assistant|>"):
if "\n" in response:
metadata, content = response.split("\n", maxsplit=1)
else:
metadata, content = "", response
if metadata.strip():
parameters = eval(content.strip())
function_call = FunctionCall(
name=metadata.strip(),
arguments=json.dumps(parameters, ensure_ascii=False)
)
tool_calls.append(ChatCompletionMessageToolCall(function=function_call))
return tool_calls if len(tool_calls) > 0 else content
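For reference, a minimal standalone sketch (not part of this commit) of the parsing that _convert_to_tool_calls_from_content performs above: each <|assistant|> segment carries the tool name on its first line and the arguments after it. The function name parse_tool_calls and the use of ast.literal_eval in place of eval are illustrative choices, not the committed code.

    import ast
    import json

    def parse_tool_calls(raw: str) -> list:
        # Split the raw GLM-4 completion on its segment marker; the first line
        # of each segment is the tool name, the remainder its arguments.
        calls = []
        for segment in raw.strip().split("<|assistant|>"):
            name, _, args = segment.partition("\n")
            if name.strip() and args.strip():
                calls.append({
                    "name": name.strip(),
                    # ast.literal_eval is a safer stand-in for the eval() above
                    "arguments": json.dumps(ast.literal_eval(args.strip()),
                                            ensure_ascii=False),
                })
        return calls

    # parse_tool_calls('get_weather\n{"city": "Beijing"}')
    # -> [{'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}]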
@staticmethod
def stream_reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False) -> str:
if content.startswith("\n"):
content = content[1:]
tool_calls = None
if use_tool:
parsed_tool_calls = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
if isinstance(parsed_tool_calls, list):
tool_calls = parsed_tool_calls
finish_reason = "tool_calls"
content = None
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=content, tool_calls=tool_calls),
finish_reason=finish_reason
)
return ChatCompletionResponse(
model=model_id,
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
).model_dump_json(exclude_none=True)
@staticmethod
def reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False, usage: UsageInfo = None) \
-> 'ChatCompletionResponse':
if content.startswith("\n"):
content = content[1:]
tool_calls = None
if use_tool:
parsed_tool_calls = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
if isinstance(parsed_tool_calls, list):
tool_calls = parsed_tool_calls
finish_reason = "tool_calls"
content = None
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=content, tool_calls=tool_calls),
finish_reason=finish_reason
)
return ChatCompletionResponse(
model=model_id,
choices=[choice_data],
created=int(time.time()),
object="chat.completion",
usage=usage
)
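A quick usage sketch of the new helpers (illustrative values; assumes the classes above are in scope):

    raw = 'get_weather\n{"city": "Beijing"}'   # raw completion text from GLM-4
    resp = ChatCompletionResponse.reply("glm-4", raw, "stop", use_tool=True)
    assert resp.choices[0].finish_reason == "tool_calls"
    assert resp.choices[0].message.content is None
    call = resp.choices[0].message.tool_calls[0]
    # call.function.name == "get_weather"; call.function.arguments is a JSON string

stream_reply follows the same path but wraps the result in a chat.completion.chunk choice and returns it as a JSON string ready for the SSE stream.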
class ChatCompletionRequest(BaseModel):
model: str
@@ -119,7 +194,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
tool_choice: Optional[Union[str, dict]] = "None"
tool_choice: Optional[Union[str, dict]] = None
repetition_penalty: Optional[float] = 1.1
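For context, two illustrative request payloads (values are examples, not from the diff). With the default corrected from the string "None" to None, a request that supplies tools but omits tool_choice is now filled in as "auto" by create_chat_completion below, and as "none" when no tools are given:

    request_with_tools = {
        "model": "glm-4",
        "messages": [{"role": "user", "content": "What's the weather in Beijing?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {"type": "object",
                               "properties": {"city": {"type": "string"}}},
            },
        }],
        # "tool_choice" omitted -> the server now defaults it to "auto"
    }

    request_forcing_one_tool = {
        **request_with_tools,
        "tool_choice": {"type": "function", "function": {"name": "get_weather"}},
    }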
@@ -133,48 +208,6 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
return scores
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
lines = output.strip().split("\n")
arguments_json = None
special_tools = ["cogview", "simple_browser"]
tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
if len(lines) >= 2 and tool_call_pattern.match(lines[0]):
function_name = lines[0].strip()
arguments = "\n".join(lines[1:]).strip()
try:
arguments_json = json.loads(arguments)
is_tool_call = True
except json.JSONDecodeError:
is_tool_call = function_name in special_tools
if is_tool_call and use_tool:
content = {
"name": function_name,
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
}
if function_name == "simple_browser":
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
match = search_pattern.match(arguments)
if match:
content["arguments"] = json.dumps({
"query": match.group(1),
"recency_days": int(match.group(2))
}, ensure_ascii=False)
elif function_name == "cogview":
content["arguments"] = json.dumps({
"prompt": arguments
}, ensure_ascii=False)
return content
return output.strip()
@torch.inference_mode()
async def generate_stream_glm4(params):
messages = params["messages"]
@@ -184,7 +217,6 @@ async def generate_stream_glm4(params):
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_tokens", 8192))
messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
params_dict = {
@@ -224,7 +256,7 @@ async def generate_stream_glm4(params):
torch.cuda.empty_cache()
def process_messages(messages, tools=None, tool_choice="none"):
def process_messages(messages, tools=None, tool_choice=None):
_messages = messages
processed_messages = []
msg_has_sys = False
@@ -239,7 +271,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
]
return filtered_tools
if tool_choice != "none":
if tool_choice and tool_choice != "none":
if isinstance(tool_choice, dict):
tools = filter_tools(tool_choice, tools)
if tools:
@@ -317,7 +349,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
return processed_messages
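The body of filter_tools is only partially visible above; a hypothetical standalone equivalent of what a forced tool_choice of the OpenAI dict form selects might look like this (select_forced_tool is an illustrative name, not the committed code):

    def select_forced_tool(tool_choice: dict, tools: list) -> list:
        # Keep only the tool whose function name matches the forced choice,
        # e.g. tool_choice = {"type": "function", "function": {"name": "get_weather"}}
        wanted = tool_choice.get("function", {}).get("name")
        return [t for t in tools
                if t.get("function", {}).get("name") == wanted]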
@app.get("/health")
async def health() -> Response:
"""Health check."""
@@ -334,8 +365,8 @@ async def list_models():
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")
if request.tool_choice is None:
request.tool_choice = "auto" if request.tools else "none"
gen_params = dict(
messages=request.messages,
temperature=request.temperature,
@@ -347,187 +378,70 @@ async def create_chat_completion(request: ChatCompletionRequest):
tools=request.tools,
tool_choice=request.tool_choice,
)
logger.debug(f"==== request ====\n{gen_params}")
logger.debug(f"==== request ====\n{request.model_dump_json()}")
if request.stream:
predict_stream_generator = predict_stream(request.model, gen_params)
output = await anext(predict_stream_generator)
if output:
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
logger.debug(f"First result output\n{output}")
function_call = None
if output and request.tools:
try:
function_call = process_response(output, use_tool=True)
except:
logger.warning("Failed to parse tool call")
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
generate = parse_output_text(request.model, output, function_call=function_call)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")
response = ""
async for response in generate_stream_glm4(gen_params):
pass
if response["text"].startswith("\n"):
response["text"] = response["text"][1:]
response["text"] = response["text"].strip()
is_tool_call = is_return_tool_call(response["text"], request.tools)
usage = UsageInfo()
function_call, finish_reason = None, "stop"
tool_calls = None
if request.tools:
try:
function_call = process_response(response["text"], use_tool=True)
except Exception as e:
logger.warning(f"Failed to parse tool call: {e}")
if isinstance(function_call, dict):
finish_reason = "tool_calls"
function_call_response = FunctionCallResponse(**function_call)
function_call_instance = FunctionCall(
name=function_call_response.name,
arguments=function_call_response.arguments
)
tool_calls = [
ChatCompletionMessageToolCall(
id=f"call_{int(time.time() * 1000)}",
function=function_call_instance,
type="function")]
message = ChatMessage(
role="assistant",
content=None if tool_calls else response["text"],
function_call=None,
tool_calls=tool_calls,
)
logger.debug(f"==== message ====\n{message}")
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason=finish_reason,
)
task_usage = UsageInfo.model_validate(response["usage"])
for usage_key, usage_value in task_usage.model_dump().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse.reply(request.model, response["text"], response["finish_reason"], is_tool_call, usage)
return ChatCompletionResponse(
model=request.model,
id="",
choices=[choice_data],
object="chat.completion",
usage=usage
)
def calc_max_tool_name_len(tools: Optional[List[dict]]) -> int:
max_tool_name_len = 0
if not tools:
return max_tool_name_len
tool_names = [tool['function']['name'] for tool in tools if 'function' in tool and 'name' in tool['function']]
max_tool_name_len = max(len(tool_name) for tool_name in tool_names)
return max_tool_name_len
def is_return_tool_call(output: str, tools: Optional[List[dict]]) -> bool:
if not tools:
return False
output = output.strip()
tool_names = [tool['function']['name'] for tool in tools if 'function' in tool and 'name' in tool['function']]
return any(output.startswith(name) for name in tool_names)
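A quick illustration (not part of the commit) of how the new prefix-based detection behaves, using a hypothetical get_weather tool and assuming the two functions above are in scope:

    tools = [{"type": "function",
              "function": {"name": "get_weather", "parameters": {}}}]

    assert calc_max_tool_name_len(tools) == 11          # len("get_weather")
    assert is_return_tool_call('get_weather\n{"city": "Beijing"}', tools)
    assert not is_return_tool_call("The weather in Beijing is sunny.", tools)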
async def predict_stream(model_id, gen_params):
output = ""
is_function_call = False
has_send_first_chunk = False
function_name = None
tools = gen_params.get("tools")
max_tool_name_len = calc_max_tool_name_len(tools)
finish_reason = "stop"
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
output = decoded_unicode
lines = output.strip().split("\n")
if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]):
is_function_call = True
function_name = lines[0].strip()
# read an extra char because the first generated char may be \n
if len(output) <= max_tool_name_len:
continue
if not is_function_call:
is_function_call = is_return_tool_call(output, tools)
if is_function_call:
for char in delta_text:
function_call = {"name": function_name, "arguments": char}
message = DeltaMessage(
content=None,
role="assistant",
function_call=function_call
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
continue
else:
if len(output) > 7:
finish_reason = new_response.get("finish_reason", None)
if not has_send_first_chunk:
message = DeltaMessage(
content="",
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
send_msg = delta_text if has_send_first_chunk else output
has_send_first_chunk = True
message = DeltaMessage(
content=send_msg,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
if is_function_call:
yield json.dumps({"text": output})
else:
yield '[DONE]'
async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
delta = DeltaMessage(role="assistant", content=value)
if function_call is not None:
delta.function_call = function_call
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=delta,
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
finish_reason = new_response["finish_reason"]
send_msg = delta_text if has_send_first_chunk else output[1:] if output.startswith("\n") else output
has_send_first_chunk = True
yield ChatCompletionResponse.stream_reply(model_id, send_msg, finish_reason)
# if the total output length is less than the max tool name length, has_send_first_chunk is still False
if is_function_call or not has_send_first_chunk:
yield ChatCompletionResponse.stream_reply(model_id, output, finish_reason, is_function_call)
yield '[DONE]'
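End to end, a hypothetical client-side check of the fixed behaviour, assuming the server runs on localhost:8000 and exposes the usual /v1 chat completions route; the model name, port, and tool definition are illustrative:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    resp = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
        tools=tools,   # tool_choice left unset -> server defaults it to "auto"
    )
    print(resp.choices[0].finish_reason)        # expected: "tool_calls"
    print(resp.choices[0].message.tool_calls)   # parsed name + JSON arguments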
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
engine_args = AsyncEngineArgs(