Follow-up on code explanations in the OpenAI server

with liuzhenghua
zR 2024-06-18 18:13:12 +08:00
parent a2e501f43e
commit ef015f88f9
2 changed files with 117 additions and 67 deletions
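
For context: the file changed below implements an OpenAI-compatible /v1/chat/completions endpoint with tool calling. A minimal client sketch with the openai SDK, assuming the server runs locally on port 8000 and serves a model registered as "glm-4" (both assumptions; the get_weather tool is illustrative only):

# Hypothetical client call against the OpenAI-compatible server changed below.
# base_url, api_key, model name and the get_weather tool are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Query the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "What is the weather in Beijing?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)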

View File

@@ -63,8 +63,8 @@ class ModelList(BaseModel):
class FunctionCall(BaseModel):
name: str
arguments: str
name: Optional[str] = None
arguments: Optional[str] = None
class ChoiceDeltaToolCallFunction(BaseModel):
@@ -86,6 +86,9 @@ class ChatCompletionMessageToolCall(BaseModel):
class ChatMessage(BaseModel):
# Explanation of the "function" role field:
# With older OpenAI API versions, you need to add a "function" entry here and add the corresponding role-conversion logic (to "observation") in the process_messages function.
role: Literal["user", "assistant", "system", "tool"]
content: Optional[str] = None
function_call: Optional[ChoiceDeltaToolCallFunction] = None
@@ -129,7 +132,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
tool_choice: Optional[Union[str, dict]] = "None"
tool_choice: Optional[Union[str, dict]] = None
repetition_penalty: Optional[float] = 1.1
@@ -143,43 +146,45 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
return scores
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
def process_response(output: str, tools, use_tool: bool = False) -> Union[str, dict]:
lines = output.strip().split("\n")
arguments_json = None
special_tools = ["cogview", "simple_browser"]
tools = {tool['function']['name'] for tool in tools['tools']}
tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
# This is a simple tool-name check; it cannot guarantee that every non-tool output is intercepted, e.g. special cases such as misaligned arguments.
# TODO: if you want stricter validation, extend the logic here.
if len(lines) >= 2 and tool_call_pattern.match(lines[0]):
if len(lines) >= 2 and lines[1].startswith("{"):
function_name = lines[0].strip()
arguments = "\n".join(lines[1:]).strip()
if function_name in tools or function_name in special_tools:
try:
arguments_json = json.loads(arguments)
is_tool_call = True
except json.JSONDecodeError:
is_tool_call = function_name in special_tools
try:
arguments_json = json.loads(arguments)
is_tool_call = True
except json.JSONDecodeError:
is_tool_call = function_name in special_tools
if is_tool_call and use_tool:
content = {
"name": function_name,
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
ensure_ascii=False)
}
if function_name == "simple_browser":
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
match = search_pattern.match(arguments)
if match:
if is_tool_call and use_tool:
content = {
"name": function_name,
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
ensure_ascii=False)
}
if function_name == "simple_browser":
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
match = search_pattern.match(arguments)
if match:
content["arguments"] = json.dumps({
"query": match.group(1),
"recency_days": int(match.group(2))
}, ensure_ascii=False)
elif function_name == "cogview":
content["arguments"] = json.dumps({
"query": match.group(1),
"recency_days": int(match.group(2))
"prompt": arguments
}, ensure_ascii=False)
elif function_name == "cogview":
content["arguments"] = json.dumps({
"prompt": arguments
}, ensure_ascii=False)
return content
return content
return output.strip()
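
To illustrate the parsing above: GLM-4 emits a tool call as the function name on the first line followed by the JSON arguments. A self-contained sketch of the same check (the sample output and tool name are illustrative, not from the diff):

import json
import re

output = 'get_weather\n{"city": "Beijing"}'  # illustrative model output
registered = {"get_weather"}  # tool names taken from the request, like the new tools set above

lines = output.strip().split("\n")
name_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')

result = output.strip()
if len(lines) >= 2 and name_pattern.match(lines[0]) and lines[0].strip() in registered:
    try:
        arguments = json.loads("\n".join(lines[1:]))
        result = {"name": lines[0].strip(),
                  "arguments": json.dumps(arguments, ensure_ascii=False)}
    except json.JSONDecodeError:
        pass  # not valid JSON: fall back to returning the plain text
print(result)  # {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}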
@@ -228,7 +233,6 @@ async def generate_stream_glm4(params):
"finish_reason": output.outputs[0].finish_reason,
}
yield ret
gc.collect()
torch.cuda.empty_cache()
@@ -366,7 +370,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
function_call = None
if output and request.tools:
try:
function_call = process_response(output, use_tool=True)
function_call = process_response(output, request.tools, use_tool=True)
except:
logger.warning("Failed to parse tool call")
@@ -393,7 +397,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
function_call = process_response(response["text"], use_tool=True)
except Exception as e:
logger.warning(f"Failed to parse tool call: {e}")
if isinstance(function_call, dict):
finish_reason = "tool_calls"
function_call_response = ChoiceDeltaToolCallFunction(**function_call)
@@ -403,7 +406,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
)
tool_calls = [
ChatCompletionMessageToolCall(
id=f"call_{int(time.time() * 1000)}",
id=generate_id('call_', 24),
function=function_call_instance,
type="function")]
@@ -437,31 +440,40 @@ async def predict_stream(model_id, gen_params):
output = ""
is_function_call = False
has_send_first_chunk = False
function_name = None
created_time = int(time.time())
function_name = None
response_id = generate_id('chatcmpl-', 29)
system_fingerprint = generate_id('fp_', 9)
tools = {tool['function']['name'] for tool in gen_params['tools']}
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
output = decoded_unicode
lines = output.strip().split("\n")
if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]):
is_function_call = True
function_name = lines[0].strip()
# Check whether the output is a tool call.
# This is a simple tool-name check; it cannot guarantee that every non-tool output is intercepted, e.g. special cases such as misaligned arguments.
# TODO: if you want more thorough handling, extend the logic here.
if not is_function_call and len(lines) >= 2:
first_line = lines[0].strip()
if first_line in tools:
is_function_call = True
function_name = first_line
# Tool-call response path
if is_function_call:
for char in delta_text:
function_call = {"name": function_name, "arguments": char}
if not has_send_first_chunk:
function_call = {"name": function_name, "arguments": ""}
tool_call = ChatCompletionMessageToolCall(
index=0,
id=generate_id('call_', 24),
function=FunctionCall(**function_call),
type="function"
)
message = DeltaMessage(
content=None,
role=None,
role="assistant",
function_call=None,
tool_calls=[tool_call]
)
@@ -478,35 +490,48 @@ async def predict_stream(model_id, gen_params):
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield ""
yield chunk.model_dump_json(exclude_unset=True)
else:
if len(output) > 7:
finish_reason = new_response.get("finish_reason", None)
if not has_send_first_chunk:
message = DeltaMessage(
content="",
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
send_msg = delta_text if has_send_first_chunk else output
has_send_first_chunk = True
function_call = {"name": None, "arguments": delta_text}
tool_call = ChatCompletionMessageToolCall(
index=0,
id=generate_id('call_', 24),
function=FunctionCall(**function_call),
type="function"
)
message = DeltaMessage(
content=None,
role=None,
function_call=None,
tool_calls=[tool_call]
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
# The user requested a function call, but the framework has not yet determined whether the output is one
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
continue
# Regular (non-tool) response path
else:
finish_reason = new_response.get("finish_reason", None)
if not has_send_first_chunk:
message = DeltaMessage(
content=send_msg,
content="",
role="assistant",
function_call=None,
)
@@ -524,7 +549,29 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
has_send_first_chunk = True
message = DeltaMessage(
content=delta_text,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
# A tool call needs one extra chunk to align with the OpenAI interface
if is_function_call:
yield ChatCompletionResponse(
model=model_id,
@@ -544,6 +591,7 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk",
usage=None
).model_dump_json(exclude_unset=True)
yield '[DONE]'
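
On the client side, the tool-call chunks streamed above arrive as delta.tool_calls fragments that have to be concatenated until the stream ends with [DONE]. A hedged sketch with the openai SDK (endpoint, model name, prompt and tool definition are assumptions):

# Hypothetical streaming consumer; base_url, model name and tool definition are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {"type": "object",
                       "properties": {"city": {"type": "string"}}},
    },
}]

name, arguments = None, []
stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "What is the weather in Beijing?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        call = delta.tool_calls[0]
        if call.function.name:
            name = call.function.name                   # sent once in the first tool chunk
        if call.function.arguments:
            arguments.append(call.function.arguments)   # streamed in small pieces
    elif delta.content:
        print(delta.content, end="")                    # regular text path
if name:
    print({"name": name, "arguments": "".join(arguments)})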
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
@@ -562,6 +610,7 @@ async def parse_output_text(model_id: str, value: str, function_call: ChoiceDelt
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
if __name__ == "__main__":
@@ -569,6 +618,7 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs(
model=MODEL_PATH,
tokenizer=MODEL_PATH,
# If you have multiple GPUs, set this to the number of GPUs you have
tensor_parallel_size=1,
dtype="bfloat16",
trust_remote_code=True,

View File

@@ -16,7 +16,7 @@ sentence_transformers>=2.7.0
gradio>=4.33.0
# openai demo
openai>=1.31.1
openai>=1.34.0
einops>=0.7.0
sse-starlette>=2.1.0