parent a2e501f43e
commit ef015f88f9

@@ -63,8 +63,8 @@ class ModelList(BaseModel):
 
 
 class FunctionCall(BaseModel):
-    name: str
-    arguments: str
+    name: Optional[str] = None
+    arguments: Optional[str] = None
 
 
 class ChoiceDeltaToolCallFunction(BaseModel):
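Making both fields optional matters for the streaming path added later in this commit: the first tool-call chunk carries only the function name with empty arguments, and subsequent chunks carry only an argument fragment with no name. A minimal sketch of the two shapes (the tool name get_weather is illustrative, not from this commit):

from typing import Optional
from pydantic import BaseModel

class FunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None

# First streamed chunk: name only, empty arguments.
print(FunctionCall(name="get_weather", arguments="").model_dump())
# Later chunks: argument fragment only, no name.
print(FunctionCall(name=None, arguments='{"city": "Be').model_dump())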
@@ -86,6 +86,9 @@ class ChatCompletionMessageToolCall(BaseModel):
 
 
 class ChatMessage(BaseModel):
+    # Explanation of the "function" field:
+    # With older OpenAI API versions, note that you need to add the "function" role here and add the corresponding role-conversion logic (mapping it to "observation") in the process_messages function.
+
     role: Literal["user", "assistant", "system", "tool"]
     content: Optional[str] = None
     function_call: Optional[ChoiceDeltaToolCallFunction] = None
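The conversion that the new comment describes is not part of this diff; a hypothetical sketch of what it might look like inside process_messages:

# Hypothetical sketch only; process_messages itself is not shown in this commit.
def convert_legacy_role(message: dict) -> dict:
    # Older OpenAI clients send tool results with role="function";
    # GLM-4 expects tool output under the "observation" role.
    if message.get("role") == "function":
        return {**message, "role": "observation"}
    return message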
@@ -129,7 +132,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
 
 
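The old default was the literal string "None": a truthy value that is neither a valid OpenAI tool_choice ("none", "auto", or a tool dict) nor a real null, and it would serialize as "tool_choice": "None". A quick illustration of the difference:

from typing import Optional, Union
from pydantic import BaseModel

class Req(BaseModel):
    tool_choice: Optional[Union[str, dict]] = None

print(Req().model_dump_json(exclude_unset=True))  # {} (the unset field is omitted)
print(Req(tool_choice="auto").model_dump_json())  # {"tool_choice":"auto"}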
@@ -143,17 +146,19 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores
 
 
-def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
+def process_response(output: str, tools, use_tool: bool = False) -> Union[str, dict]:
     lines = output.strip().split("\n")
     arguments_json = None
     special_tools = ["cogview", "simple_browser"]
+    tools = {tool['function']['name'] for tool in tools['tools']}
 
-    tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
+    # This is a simple tool-name comparison; it cannot catch every non-tool output, e.g. special cases where the arguments are misaligned.
+    ##TODO If you want to add more checks, you can refine the logic here.
 
-    if len(lines) >= 2 and tool_call_pattern.match(lines[0]):
+    if len(lines) >= 2 and lines[1].startswith("{"):
         function_name = lines[0].strip()
         arguments = "\n".join(lines[1:]).strip()
+        if function_name in tools or function_name in special_tools:
             try:
                 arguments_json = json.loads(arguments)
                 is_tool_call = True
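A self-contained sketch of the stricter gate this hunk introduces: the first line must now name a tool that was actually supplied in the request (or one of the special tools), not merely look like an identifier. The real process_response also builds the return value, which is elided here:

import json

def looks_like_tool_call(output: str, tools: dict) -> bool:
    names = {t['function']['name'] for t in tools['tools']}
    lines = output.strip().split("\n")
    if len(lines) >= 2 and lines[1].startswith("{"):
        name = lines[0].strip()
        if name in names or name in ["cogview", "simple_browser"]:
            try:
                json.loads("\n".join(lines[1:]).strip())
                return True
            except json.JSONDecodeError:
                return False
    return False

tools = {"tools": [{"type": "function", "function": {"name": "get_weather"}}]}
print(looks_like_tool_call('get_weather\n{"city": "Beijing"}', tools))  # True
print(looks_like_tool_call('Hello!\nHow are you?', tools))              # False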
@@ -228,7 +233,6 @@ async def generate_stream_glm4(params):
             "finish_reason": output.outputs[0].finish_reason,
         }
         yield ret
-
 
     gc.collect()
     torch.cuda.empty_cache()
@@ -366,7 +370,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         function_call = None
         if output and request.tools:
             try:
-                function_call = process_response(output, use_tool=True)
+                function_call = process_response(output, request.tools, use_tool=True)
             except:
                 logger.warning("Failed to parse tool call")
 
@@ -393,7 +397,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 function_call = process_response(response["text"], use_tool=True)
             except Exception as e:
                 logger.warning(f"Failed to parse tool call: {e}")
-
             if isinstance(function_call, dict):
                 finish_reason = "tool_calls"
                 function_call_response = ChoiceDeltaToolCallFunction(**function_call)
@@ -403,7 +406,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
             )
             tool_calls = [
                 ChatCompletionMessageToolCall(
-                    id=f"call_{int(time.time() * 1000)}",
+                    id=generate_id('call_', 24),
                     function=function_call_instance,
                     type="function")]
 
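generate_id replaces the timestamp-based id, which could collide when two calls land in the same millisecond. Its definition is not part of this diff; judging only from the call sites (a prefix plus a fixed-length suffix), a plausible implementation might be:

import random
import string

# Assumed implementation; the real generate_id is not shown in this commit.
def generate_id(prefix: str, k: int) -> str:
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"

print(generate_id('call_', 24))      # e.g. call_h9XkW3...
print(generate_id('chatcmpl-', 29))  # e.g. chatcmpl-Fq2...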
@@ -437,25 +440,64 @@ async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
-    function_name = None
     created_time = int(time.time())
+    function_name = None
     response_id = generate_id('chatcmpl-', 29)
     system_fingerprint = generate_id('fp_', 9)
+    tools = {tool['function']['name'] for tool in gen_params['tools']}
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
         lines = output.strip().split("\n")
 
-        if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]):
-            is_function_call = True
-            function_name = lines[0].strip()
+        # Check whether the output is a tool call.
+        # This is a simple tool-name comparison; it cannot catch every non-tool output, e.g. special cases where the arguments are misaligned.
+        ##TODO If you want to do more handling, refine the logic here.
 
+        if not is_function_call and len(lines) >= 2:
+            first_line = lines[0].strip()
+            if first_line in tools:
+                is_function_call = True
+                function_name = first_line
+
+        # Tool-call response
         if is_function_call:
-            for char in delta_text:
-                function_call = {"name": function_name, "arguments": char}
-                tool_call = ChatCompletionMessageToolCall(
-                    index=0,
+            if not has_send_first_chunk:
+                function_call = {"name": function_name, "arguments": ""}
+                tool_call = ChatCompletionMessageToolCall(
+                    index=0,
+                    id=generate_id('call_', 24),
+                    function=FunctionCall(**function_call),
+                    type="function"
+                )
+                message = DeltaMessage(
+                    content=None,
+                    role="assistant",
+                    function_call=None,
+                    tool_calls=[tool_call]
+                )
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=message,
+                    finish_reason=None
+                )
+                chunk = ChatCompletionResponse(
+                    model=model_id,
+                    id=response_id,
+                    choices=[choice_data],
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
+                    object="chat.completion.chunk"
+                )
+                yield ""
+                yield chunk.model_dump_json(exclude_unset=True)
+                has_send_first_chunk = True
+
+            function_call = {"name": None, "arguments": delta_text}
+            tool_call = ChatCompletionMessageToolCall(
+                index=0,
+                id=generate_id('call_', 24),
                 function=FunctionCall(**function_call),
                 type="function"
            )
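Two details of the streaming loop above are worth seeing in isolation. Each stream event carries the full text decoded so far, so the new suffix becomes the delta (this also explains why the old per-character loop was redundant); and the bare yield "" on the first tool-call chunk is an in-band signal to the caller, which appears to read the first item off this generator to decide whether a tool call is underway (that caller code is not part of this diff). A self-contained sketch of the delta slicing, with illustrative snapshots:

# Each element is the full text decoded so far, as generate_stream_glm4 yields it.
snapshots = ['get_', 'get_weather', 'get_weather\n{"ci', 'get_weather\n{"city": "Beijing"}']

output = ""
for decoded_unicode in snapshots:
    delta_text = decoded_unicode[len(output):]  # only the newly decoded suffix
    output = decoded_unicode
    print(repr(delta_text))
# 'get_' 'weather' '\n{"ci' 'ty": "Beijing"}'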
@@ -479,8 +521,13 @@ async def predict_stream(model_id, gen_params):
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
+
+        # The user requested a function call, but the framework has not yet determined whether this is one.
+        elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
+            continue
+
+        # Regular response
         else:
-            if len(output) > 7:
             finish_reason = new_response.get("finish_reason", None)
             if not has_send_first_chunk:
                 message = DeltaMessage(
@@ -502,11 +549,10 @@ async def predict_stream(model_id, gen_params):
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
 
-                send_msg = delta_text if has_send_first_chunk else output
                 has_send_first_chunk = True
             message = DeltaMessage(
-                content=send_msg,
+                content=delta_text,
                 role="assistant",
                 function_call=None,
             )
@@ -525,6 +571,7 @@ async def predict_stream(model_id, gen_params):
             )
             yield chunk.model_dump_json(exclude_unset=True)
 
+    # A tool call needs to return one extra chunk to align with the OpenAI interface.
     if is_function_call:
         yield ChatCompletionResponse(
             model=model_id,
@@ -544,6 +591,7 @@ async def predict_stream(model_id, gen_params):
             object="chat.completion.chunk",
             usage=None
         ).model_dump_json(exclude_unset=True)
+    yield '[DONE]'
 
 
 async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
@@ -562,6 +610,7 @@ async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
         object="chat.completion.chunk"
     )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    yield '[DONE]'
 
 
 if __name__ == "__main__":
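Both streaming generators now end with the literal [DONE] sentinel that OpenAI-compatible SSE clients stop on. A minimal consumer sketch, assuming each SSE data payload arrives as a string (as sse-starlette serves the yielded values in this demo):

import json

def consume(events):
    # `events` yields the data payload of each SSE event as a string.
    for data in events:
        if data == '[DONE]':
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        if delta.get("content"):
            print(delta["content"], end="")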
@@ -569,6 +618,7 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs(
         model=MODEL_PATH,
         tokenizer=MODEL_PATH,
+        # If you have multiple GPUs, set this to your GPU count.
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
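If you would rather derive the value than hard-code it, one option (an assumption, not part of this commit) is the visible device count; note that vLLM requires the model's attention-head count to be divisible by tensor_parallel_size:

import torch

# Fall back to 1 when no GPU is visible.
tensor_parallel_size = torch.cuda.device_count() or 1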
@@ -16,7 +16,7 @@ sentence_transformers>=2.7.0
 gradio>=4.33.0
 
 # openai demo
-openai>=1.31.1
+openai>=1.34.0
 einops>=0.7.0
 sse-starlette>=2.1.0
 