Merge pull request #204 from linnnnnzf/main

修正非流式输出和process_response tools参数的bug
This commit is contained in:
zR 2024-06-18 23:03:42 +08:00 committed by GitHub
commit 7c62112619
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 3 deletions

View File

@ -146,11 +146,11 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
return scores
def process_response(output: str, tools, use_tool: bool = False) -> Union[str, dict]:
def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]:
lines = output.strip().split("\n")
arguments_json = None
special_tools = ["cogview", "simple_browser"]
tools = {tool['function']['name'] for tool in tools['tools']}
tools = {tool['function']['name'] for tool in tools}
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
##TODO 如果你希望做更多判断,可以在这里进行逻辑完善。
@ -394,7 +394,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
tool_calls = None
if request.tools:
try:
function_call = process_response(response["text"], use_tool=True)
function_call = process_response(response["text"], request.tools, use_tool=True)
except Exception as e:
logger.warning(f"Failed to parse tool call: {e}")
if isinstance(function_call, dict):