跟进OpenAI server部分代码解释

with liuzhenghua
This commit is contained in:
zR 2024-06-18 18:13:12 +08:00
parent a2e501f43e
commit ef015f88f9
2 changed files with 117 additions and 67 deletions

View File

@@ -63,8 +63,8 @@ class ModelList(BaseModel):
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
name: str name: Optional[str] = None
arguments: str arguments: Optional[str] = None
class ChoiceDeltaToolCallFunction(BaseModel): class ChoiceDeltaToolCallFunction(BaseModel):
@@ -86,6 +89,9 @@ class ChatCompletionMessageToolCall(BaseModel):
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
# “function” 字段解释:
# 使用较老的OpenAI API版本需要注意在这里添加 function 字段并在 process_messages函数中添加相应角色转换逻辑为 observation
role: Literal["user", "assistant", "system", "tool"] role: Literal["user", "assistant", "system", "tool"]
content: Optional[str] = None content: Optional[str] = None
function_call: Optional[ChoiceDeltaToolCallFunction] = None function_call: Optional[ChoiceDeltaToolCallFunction] = None
@@ -129,7 +132,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
stream: Optional[bool] = False stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None tools: Optional[Union[dict, List[dict]]] = None
tool_choice: Optional[Union[str, dict]] = "None" tool_choice: Optional[Union[str, dict]] = None
repetition_penalty: Optional[float] = 1.1 repetition_penalty: Optional[float] = 1.1
@@ -143,17 +146,19 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
return scores return scores
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]: def process_response(output: str, tools, use_tool: bool = False) -> Union[str, dict]:
lines = output.strip().split("\n") lines = output.strip().split("\n")
arguments_json = None arguments_json = None
special_tools = ["cogview", "simple_browser"] special_tools = ["cogview", "simple_browser"]
tools = {tool['function']['name'] for tool in tools['tools']}
tool_call_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') # 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
##TODO 如果你希望做更多判断,可以在这里进行逻辑完善。
if len(lines) >= 2 and tool_call_pattern.match(lines[0]): if len(lines) >= 2 and lines[1].startswith("{"):
function_name = lines[0].strip() function_name = lines[0].strip()
arguments = "\n".join(lines[1:]).strip() arguments = "\n".join(lines[1:]).strip()
if function_name in tools or function_name in special_tools:
try: try:
arguments_json = json.loads(arguments) arguments_json = json.loads(arguments)
is_tool_call = True is_tool_call = True
@@ -228,7 +233,6 @@ async def generate_stream_glm4(params):
"finish_reason": output.outputs[0].finish_reason, "finish_reason": output.outputs[0].finish_reason,
} }
yield ret yield ret
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -366,7 +370,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
function_call = None function_call = None
if output and request.tools: if output and request.tools:
try: try:
function_call = process_response(output, use_tool=True) function_call = process_response(output, request.tools, use_tool=True)
except: except:
logger.warning("Failed to parse tool call") logger.warning("Failed to parse tool call")
@@ -393,7 +397,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
function_call = process_response(response["text"], use_tool=True) function_call = process_response(response["text"], use_tool=True)
except Exception as e: except Exception as e:
logger.warning(f"Failed to parse tool call: {e}") logger.warning(f"Failed to parse tool call: {e}")
if isinstance(function_call, dict): if isinstance(function_call, dict):
finish_reason = "tool_calls" finish_reason = "tool_calls"
function_call_response = ChoiceDeltaToolCallFunction(**function_call) function_call_response = ChoiceDeltaToolCallFunction(**function_call)
@@ -403,7 +406,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
) )
tool_calls = [ tool_calls = [
ChatCompletionMessageToolCall( ChatCompletionMessageToolCall(
id=f"call_{int(time.time() * 1000)}", id=generate_id('call_', 24),
function=function_call_instance, function=function_call_instance,
type="function")] type="function")]
@@ -437,25 +440,64 @@ async def predict_stream(model_id, gen_params):
output = "" output = ""
is_function_call = False is_function_call = False
has_send_first_chunk = False has_send_first_chunk = False
function_name = None
created_time = int(time.time()) created_time = int(time.time())
function_name = None
response_id = generate_id('chatcmpl-', 29) response_id = generate_id('chatcmpl-', 29)
system_fingerprint = generate_id('fp_', 9) system_fingerprint = generate_id('fp_', 9)
tools = {tool['function']['name'] for tool in gen_params['tools']}
async for new_response in generate_stream_glm4(gen_params): async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"] decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):] delta_text = decoded_unicode[len(output):]
output = decoded_unicode output = decoded_unicode
lines = output.strip().split("\n") lines = output.strip().split("\n")
if not is_function_call and len(lines) >= 2 and re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', lines[0]): # 检查是否为工具
is_function_call = True # 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
function_name = lines[0].strip() ##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。
if not is_function_call and len(lines) >= 2:
first_line = lines[0].strip()
if first_line in tools:
is_function_call = True
function_name = first_line
# 工具调用返回
if is_function_call: if is_function_call:
for char in delta_text: if not has_send_first_chunk:
function_call = {"name": function_name, "arguments": char} function_call = {"name": function_name, "arguments": ""}
tool_call = ChatCompletionMessageToolCall( tool_call = ChatCompletionMessageToolCall(
index=0, index=0,
id=generate_id('call_', 24),
function=FunctionCall(**function_call),
type="function"
)
message = DeltaMessage(
content=None,
role="assistant",
function_call=None,
tool_calls=[tool_call]
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=None
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield ""
yield chunk.model_dump_json(exclude_unset=True)
has_send_first_chunk = True
function_call = {"name": None, "arguments": delta_text}
tool_call = ChatCompletionMessageToolCall(
index=0,
id=generate_id('call_', 24),
function=FunctionCall(**function_call), function=FunctionCall(**function_call),
type="function" type="function"
) )
@@ -479,8 +521,13 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk" object="chat.completion.chunk"
) )
yield chunk.model_dump_json(exclude_unset=True) yield chunk.model_dump_json(exclude_unset=True)
# 用户请求了 Function Call 但是框架还没确定是否为Function Call
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
continue
# 常规返回
else: else:
if len(output) > 7:
finish_reason = new_response.get("finish_reason", None) finish_reason = new_response.get("finish_reason", None)
if not has_send_first_chunk: if not has_send_first_chunk:
message = DeltaMessage( message = DeltaMessage(
@@ -502,11 +549,10 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk" object="chat.completion.chunk"
) )
yield chunk.model_dump_json(exclude_unset=True) yield chunk.model_dump_json(exclude_unset=True)
send_msg = delta_text if has_send_first_chunk else output
has_send_first_chunk = True has_send_first_chunk = True
message = DeltaMessage( message = DeltaMessage(
content=send_msg, content=delta_text,
role="assistant", role="assistant",
function_call=None, function_call=None,
) )
@@ -525,6 +571,7 @@ async def predict_stream(model_id, gen_params):
) )
yield chunk.model_dump_json(exclude_unset=True) yield chunk.model_dump_json(exclude_unset=True)
# 工具调用需要额外返回一个字段以对齐 OpenAI 接口
if is_function_call: if is_function_call:
yield ChatCompletionResponse( yield ChatCompletionResponse(
model=model_id, model=model_id,
@@ -544,6 +591,7 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk", object="chat.completion.chunk",
usage=None usage=None
).model_dump_json(exclude_unset=True) ).model_dump_json(exclude_unset=True)
yield '[DONE]'
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None): async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
@@ -562,6 +610,7 @@ async def parse_output_text(model_id: str, value: str, function_call: ChoiceDelt
object="chat.completion.chunk" object="chat.completion.chunk"
) )
yield "{}".format(chunk.model_dump_json(exclude_unset=True)) yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
if __name__ == "__main__": if __name__ == "__main__":
@@ -569,6 +618,7 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=MODEL_PATH, model=MODEL_PATH,
tokenizer=MODEL_PATH, tokenizer=MODEL_PATH,
# 如果你有多张显卡,可以在这里设置成你的显卡数量
tensor_parallel_size=1, tensor_parallel_size=1,
dtype="bfloat16", dtype="bfloat16",
trust_remote_code=True, trust_remote_code=True,

View File

@@ -16,7 +16,7 @@ sentence_transformers>=2.7.0
gradio>=4.33.0 gradio>=4.33.0
# openai demo # openai demo
openai>=1.31.1 openai>=1.34.0
einops>=0.7.0 einops>=0.7.0
sse-starlette>=2.1.0 sse-starlette>=2.1.0