parent a2e501f43e
commit ef015f88f9
```diff
@@ -63,8 +63,8 @@ class ModelList(BaseModel):
 
 
 class FunctionCall(BaseModel):
-    name: str
-    arguments: str
+    name: Optional[str] = None
+    arguments: Optional[str] = None
 
 
 class ChoiceDeltaToolCallFunction(BaseModel):
```
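Making both FunctionCall fields optional matters for streaming: in OpenAI-style tool-call deltas, only the first chunk carries the function name, and later chunks carry argument fragments without a name, so a strict `name: str` would fail validation on every argument-only delta. A minimal sketch of the two shapes this change permits (field names as in the model above):

```python
from typing import Optional
from pydantic import BaseModel


class FunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


# First streamed chunk: names the function, no arguments yet.
first_delta = FunctionCall(name="get_weather", arguments="")
# Later chunks: argument fragments only; name stays None.
later_delta = FunctionCall(arguments='{"city": "Bei')
```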
```diff
@@ -86,6 +86,9 @@ class ChatCompletionMessageToolCall(BaseModel):
 
 
 class ChatMessage(BaseModel):
+    # Note on the "function" role:
+    # Older OpenAI API versions need a "function" role added here, plus matching role-conversion logic to "observation" in process_messages.
+
     role: Literal["user", "assistant", "system", "tool"]
     content: Optional[str] = None
     function_call: Optional[ChoiceDeltaToolCallFunction] = None
```
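The new comment points at role conversion in process_messages, which is not shown in this diff. A hypothetical sketch of what that conversion could look like (the helper name and shape are assumptions, not part of the commit):

```python
# Hypothetical sketch: map the legacy "function" role used by older
# OpenAI API versions onto GLM-4's "observation" role before the
# messages are turned into a prompt.
def convert_legacy_roles(messages: list) -> list:
    converted = []
    for msg in messages:
        if msg.get("role") == "function":  # legacy tool-result role
            msg = {**msg, "role": "observation"}
        converted.append(msg)
    return converted
```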
```diff
@@ -129,7 +132,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
 
 
```
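The old default was the string "None", which is truthy, is not Python's None, and does not match the lowercase "none" the OpenAI spec uses, so any code testing the field behaved as if the caller had explicitly set a tool choice. A quick illustration:

```python
tool_choice = "None"            # old default: a non-empty string
print(bool(tool_choice))        # True  - truthy
print(tool_choice != "none")    # True  - comparison is case-sensitive

tool_choice = None              # fixed default
print(bool(tool_choice))        # False
```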
```diff
@@ -143,43 +146,45 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores
 
 
-def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
+def process_response(output: str, tools, use_tool: bool = False) -> Union[str, dict]:
     lines = output.strip().split("\n")
     arguments_json = None
     special_tools = ["cogview", "simple_browser"]
+    tools = {tool['function']['name'] for tool in tools['tools']}
 
-    tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
+    # This is a simple tool-name check; it cannot intercept every non-tool output (e.g. misaligned arguments and other edge cases).
+    # TODO: extend the logic here if you want stricter validation.
 
-    if len(lines) >= 2 and tool_call_pattern.match(lines[0]):
+    if len(lines) >= 2 and lines[1].startswith("{"):
         function_name = lines[0].strip()
         arguments = "\n".join(lines[1:]).strip()
-
-        try:
-            arguments_json = json.loads(arguments)
-            is_tool_call = True
-        except json.JSONDecodeError:
-            is_tool_call = function_name in special_tools
-
-        if is_tool_call and use_tool:
-            content = {
-                "name": function_name,
-                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
-                                        ensure_ascii=False)
-            }
-            if function_name == "simple_browser":
-                search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
-                match = search_pattern.match(arguments)
-                if match:
-                    content["arguments"] = json.dumps({
-                        "query": match.group(1),
-                        "recency_days": int(match.group(2))
-                    }, ensure_ascii=False)
-            elif function_name == "cogview":
-                content["arguments"] = json.dumps({
-                    "prompt": arguments
-                }, ensure_ascii=False)
-
-            return content
+        if function_name in tools or function_name in special_tools:
+            try:
+                arguments_json = json.loads(arguments)
+                is_tool_call = True
+            except json.JSONDecodeError:
+                is_tool_call = function_name in special_tools
+
+            if is_tool_call and use_tool:
+                content = {
+                    "name": function_name,
+                    "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                            ensure_ascii=False)
+                }
+                if function_name == "simple_browser":
+                    search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
+                    match = search_pattern.match(arguments)
+                    if match:
+                        content["arguments"] = json.dumps({
+                            "query": match.group(1),
+                            "recency_days": int(match.group(2))
+                        }, ensure_ascii=False)
+                elif function_name == "cogview":
+                    content["arguments"] = json.dumps({
+                        "prompt": arguments
+                    }, ensure_ascii=False)
+
+                return content
     return output.strip()
```
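A usage sketch for the new signature. The set comprehension reads `tools['tools']`, so the second argument is assumed to be a dict wrapping the tool list; the tool definition and model output below are made up for illustration:

```python
tools_payload = {
    "tools": [
        {"type": "function",
         "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {}}}}
    ]
}

# A tool call: the first line names a registered tool, the rest is JSON.
result = process_response('get_weather\n{"city": "Beijing"}', tools_payload, use_tool=True)
# -> {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}

# Plain text falls through and is returned unchanged.
assert process_response("Hello!", tools_payload) == "Hello!"
```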
```diff
@@ -228,7 +233,6 @@ async def generate_stream_glm4(params):
             "finish_reason": output.outputs[0].finish_reason,
         }
         yield ret
-
     gc.collect()
     torch.cuda.empty_cache()
 
```
```diff
@@ -366,7 +370,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     function_call = None
     if output and request.tools:
         try:
-            function_call = process_response(output, use_tool=True)
+            function_call = process_response(output, request.tools, use_tool=True)
         except:
             logger.warning("Failed to parse tool call")
 
```
```diff
@@ -393,7 +397,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 function_call = process_response(response["text"], use_tool=True)
             except Exception as e:
                 logger.warning(f"Failed to parse tool call: {e}")
-
             if isinstance(function_call, dict):
                 finish_reason = "tool_calls"
                 function_call_response = ChoiceDeltaToolCallFunction(**function_call)
```
```diff
@@ -403,7 +406,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         )
         tool_calls = [
             ChatCompletionMessageToolCall(
-                id=f"call_{int(time.time() * 1000)}",
+                id=generate_id('call_', 24),
                 function=function_call_instance,
                 type="function")]
 
```
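generate_id replaces the millisecond-timestamp id, which could collide when two calls land in the same millisecond. Its definition is elsewhere in the commit and not shown in these hunks; a plausible sketch that matches how it is called (prefix plus a fixed number of random characters, OpenAI-style):

```python
import random
import string


# Plausible sketch only; the real helper is defined elsewhere in the commit.
def generate_id(prefix: str, k: int) -> str:
    suffix = "".join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"


print(generate_id('call_', 24))  # e.g. call_hV3n8... (24 random characters)
```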
```diff
@@ -437,31 +440,40 @@ async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
-    function_name = None
     created_time = int(time.time())
+    function_name = None
     response_id = generate_id('chatcmpl-', 29)
     system_fingerprint = generate_id('fp_', 9)
+    tools = {tool['function']['name'] for tool in gen_params['tools']}
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
         lines = output.strip().split("\n")
 
-        if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]):
-            is_function_call = True
-            function_name = lines[0].strip()
+        # Check whether the output is a tool call.
+        # This is a simple tool-name comparison and cannot catch every non-tool output (e.g. misaligned arguments and other edge cases).
+        # TODO: extend the logic here if you want more checks.
+        if not is_function_call and len(lines) >= 2:
+            first_line = lines[0].strip()
+            if first_line in tools:
+                is_function_call = True
+                function_name = first_line
 
+        # Tool-call response
         if is_function_call:
-            for char in delta_text:
-                function_call = {"name": function_name, "arguments": char}
+            if not has_send_first_chunk:
+                function_call = {"name": function_name, "arguments": ""}
                 tool_call = ChatCompletionMessageToolCall(
                     index=0,
                     id=generate_id('call_', 24),
                     function=FunctionCall(**function_call),
                     type="function"
                 )
                 message = DeltaMessage(
                     content=None,
-                    role=None,
+                    role="assistant",
                     function_call=None,
                     tool_calls=[tool_call]
                 )
```
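The old branch looped over delta_text one character at a time and re-sent the function name with every character. The rewrite follows the OpenAI streaming convention instead: one first chunk that names the function with empty arguments, then name-less argument deltas (continued in the next hunk). Roughly, the FunctionCall payloads now form a sequence like:

```python
# First chunk: establishes the tool call and its name.
first = {"name": "get_weather", "arguments": ""}
# Subsequent chunks: argument fragments only.
fragments = [
    {"name": None, "arguments": '{"city"'},
    {"name": None, "arguments": ': "Beijing"}'},
]
```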
```diff
@@ -478,35 +490,48 @@ async def predict_stream(model_id, gen_params):
                     system_fingerprint=system_fingerprint,
                     object="chat.completion.chunk"
                 )
+                yield ""
                 yield chunk.model_dump_json(exclude_unset=True)
-        else:
-            if len(output) > 7:
-                finish_reason = new_response.get("finish_reason", None)
-                if not has_send_first_chunk:
-                    message = DeltaMessage(
-                        content="",
-                        role="assistant",
-                        function_call=None,
-                    )
-                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
-                        delta=message,
-                        finish_reason=finish_reason
-                    )
-                    chunk = ChatCompletionResponse(
-                        model=model_id,
-                        id=response_id,
-                        choices=[choice_data],
-                        created=created_time,
-                        system_fingerprint=system_fingerprint,
-                        object="chat.completion.chunk"
-                    )
-                    yield chunk.model_dump_json(exclude_unset=True)
-
-                send_msg = delta_text if has_send_first_chunk else output
                 has_send_first_chunk = True
+
+            function_call = {"name": None, "arguments": delta_text}
+            tool_call = ChatCompletionMessageToolCall(
+                index=0,
+                id=generate_id('call_', 24),
+                function=FunctionCall(**function_call),
+                type="function"
+            )
+            message = DeltaMessage(
+                content=None,
+                role=None,
+                function_call=None,
+                tool_calls=[tool_call]
+            )
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=message,
+                finish_reason=None
+            )
+            chunk = ChatCompletionResponse(
+                model=model_id,
+                id=response_id,
+                choices=[choice_data],
+                created=created_time,
+                system_fingerprint=system_fingerprint,
+                object="chat.completion.chunk"
+            )
+            yield chunk.model_dump_json(exclude_unset=True)
+
+        # The user requested a function call, but the framework has not yet decided whether this is one
+        elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
+            continue
+
+        # Regular response
+        else:
+            finish_reason = new_response.get("finish_reason", None)
+            if not has_send_first_chunk:
                 message = DeltaMessage(
-                    content=send_msg,
+                    content="",
                     role="assistant",
                     function_call=None,
                 )
```
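On the client side these chunks accumulate exactly the way the openai Python package expects (hence the version bump in the requirements below). A sketch, assuming the server from this file is running locally on port 8000 and serving a model named "glm-4"; both values are assumptions:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="glm-4",  # assumed model name
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=[{"type": "function",
            "function": {"name": "get_weather",
                         "parameters": {"type": "object", "properties": {}}}}],
    stream=True,
)

name, arguments = None, ""
for chunk in stream:
    delta = chunk.choices[0].delta
    for tc in delta.tool_calls or []:
        name = tc.function.name or name           # set by the first chunk
        arguments += tc.function.arguments or ""  # grown by later chunks
print(name, arguments)
```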
```diff
@@ -524,7 +549,29 @@ async def predict_stream(model_id, gen_params):
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
+                has_send_first_chunk = True
+
+            message = DeltaMessage(
+                content=delta_text,
+                role="assistant",
+                function_call=None,
+            )
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=message,
+                finish_reason=finish_reason
+            )
+            chunk = ChatCompletionResponse(
+                model=model_id,
+                id=response_id,
+                choices=[choice_data],
+                created=created_time,
+                system_fingerprint=system_fingerprint,
+                object="chat.completion.chunk"
+            )
+            yield chunk.model_dump_json(exclude_unset=True)
 
     # Tool calls need one extra chunk to align with the OpenAI API
     if is_function_call:
         yield ChatCompletionResponse(
             model=model_id,
```
```diff
@@ -544,6 +591,7 @@ async def predict_stream(model_id, gen_params):
         object="chat.completion.chunk",
         usage=None
     ).model_dump_json(exclude_unset=True)
+    yield '[DONE]'
 
 
 async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
```
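The added `yield '[DONE]'` emits the sentinel that terminates an OpenAI-compatible stream; without it, clients that read events until `[DONE]` never stop. Presumably the generator is wrapped in an SSE response via sse-starlette (listed in the requirements below), roughly like this sketch:

```python
from sse_starlette.sse import EventSourceResponse


# Sketch: wrapping the generator in an SSE response. Each yielded string
# becomes one "data: ..." frame; the final '[DONE]' frame tells
# OpenAI-compatible clients to stop reading.
def as_sse(generator):
    return EventSourceResponse(generator, media_type="text/event-stream")

# Usage inside the endpoint (names as in this file):
# return as_sse(predict_stream(request.model, gen_params))
```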
```diff
@@ -562,6 +610,7 @@ async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
         object="chat.completion.chunk"
     )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    yield '[DONE]'
 
 
 if __name__ == "__main__":
```
```diff
@@ -569,6 +618,7 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs(
         model=MODEL_PATH,
         tokenizer=MODEL_PATH,
+        # If you have multiple GPUs, set this to the number of GPUs you have
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
```
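Per the new comment, tensor_parallel_size shards the model across GPUs. A sketch for a multi-GPU box (the GPU count is illustrative, and MODEL_PATH is an assumption standing in for the value defined elsewhere in this file):

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine

MODEL_PATH = "THUDM/glm-4-9b-chat"  # assumed; defined elsewhere in this file

engine_args = AsyncEngineArgs(
    model=MODEL_PATH,
    tokenizer=MODEL_PATH,
    tensor_parallel_size=4,  # e.g. a machine with 4 GPUs
    dtype="bfloat16",
    trust_remote_code=True,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```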
The commit also bumps the minimum openai client version in the requirements file:

```diff
@@ -16,7 +16,7 @@ sentence_transformers>=2.7.0
 gradio>=4.33.0
 
 # openai demo
-openai>=1.31.1
+openai>=1.34.0
 einops>=0.7.0
 sse-starlette>=2.1.0
 
```
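With the bumped client installed, the non-streaming tool-call path above can be exercised end to end. A sketch (the local URL and model name are assumptions):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="glm-4",  # assumed model name
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=[{"type": "function",
            "function": {"name": "get_weather",
                         "parameters": {"type": "object", "properties": {}}}}],
)

choice = resp.choices[0]
if choice.finish_reason == "tool_calls":
    call = choice.message.tool_calls[0]
    print(call.function.name, call.function.arguments)
```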