Merge pull request #28 from xusenlinzy/main

fix process response function
This commit is contained in:
zR 2024-06-05 18:07:07 +08:00 committed by GitHub
commit a263b69376
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 14 additions and 7 deletions

View File

@@ -18,6 +18,7 @@ from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b'
MAX_MODEL_LENGTH = 8192
@asynccontextmanager
@@ -141,14 +142,16 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
content = ""
for response in output.split(""):
metadata, content = response.split("\n", maxsplit=1)
for response in output.split("<|assistant|>"):
if "\n" in response:
metadata, content = response.split("\n", maxsplit=1)
else:
metadata, content = "", response
if not metadata.strip():
content = content.strip()
else:
if use_tool:
content = "\n".join(content.split("\n")[1:-1])
parameters = eval(content)
parameters = eval(content.strip())
content = {
"name": metadata.strip(),
"arguments": json.dumps(parameters, ensure_ascii=False)
@@ -257,8 +260,11 @@ def process_messages(messages, tools=None, tool_choice="none"):
}
)
elif role == "assistant" and func_call is not None:
for response in content.split(""):
metadata, sub_content = response.split("\n", maxsplit=1)
for response in content.split("<|assistant|>"):
if "\n" in response:
metadata, sub_content = response.split("\n", maxsplit=1)
else:
metadata, sub_content = "", response
messages.append(
{
"role": role,
@@ -537,7 +543,8 @@ if __name__ == "__main__":
enforce_eager=True,
worker_use_ray=True,
engine_use_ray=False,
disable_log_requests=True
disable_log_requests=True,
max_model_len=MAX_MODEL_LENGTH,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)