Merge pull request #28 from xusenlinzy/main
fix process response function
This commit is contained in:
commit
a263b69376
|
@ -18,6 +18,7 @@ from sse_starlette.sse import EventSourceResponse
|
|||
|
||||
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
|
||||
MODEL_PATH = 'THUDM/glm-4-9b'
|
||||
MAX_MODEL_LENGTH = 8192
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -141,14 +142,16 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
|
|||
|
||||
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
|
||||
content = ""
|
||||
for response in output.split(""):
|
||||
metadata, content = response.split("\n", maxsplit=1)
|
||||
for response in output.split("<|assistant|>"):
|
||||
if "\n" in response:
|
||||
metadata, content = response.split("\n", maxsplit=1)
|
||||
else:
|
||||
metadata, content = "", response
|
||||
if not metadata.strip():
|
||||
content = content.strip()
|
||||
else:
|
||||
if use_tool:
|
||||
content = "\n".join(content.split("\n")[1:-1])
|
||||
parameters = eval(content)
|
||||
parameters = eval(content.strip())
|
||||
content = {
|
||||
"name": metadata.strip(),
|
||||
"arguments": json.dumps(parameters, ensure_ascii=False)
|
||||
|
@ -257,8 +260,11 @@ def process_messages(messages, tools=None, tool_choice="none"):
|
|||
}
|
||||
)
|
||||
elif role == "assistant" and func_call is not None:
|
||||
for response in content.split(""):
|
||||
metadata, sub_content = response.split("\n", maxsplit=1)
|
||||
for response in content.split("<|assistant|>"):
|
||||
if "\n" in response:
|
||||
metadata, sub_content = response.split("\n", maxsplit=1)
|
||||
else:
|
||||
metadata, sub_content = "", response
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
|
@ -537,7 +543,8 @@ if __name__ == "__main__":
|
|||
enforce_eager=True,
|
||||
worker_use_ray=True,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True
|
||||
disable_log_requests=True,
|
||||
max_model_len=MAX_MODEL_LENGTH,
|
||||
)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||
|
|
Loading…
Reference in New Issue