fix process response function
This commit is contained in:
parent
4d5194d758
commit
11e244c0f2
|
@@ -18,6 +18,7 @@ from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
|
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
|
||||||
MODEL_PATH = 'THUDM/glm-4-9b'
|
MODEL_PATH = 'THUDM/glm-4-9b'
|
||||||
|
MAX_MODEL_LENGTH = 8192
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
@@ -141,14 +142,16 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
|
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
|
||||||
content = ""
|
content = ""
|
||||||
for response in output.split(""):
|
for response in output.split("<|assistant|>"):
|
||||||
metadata, content = response.split("\n", maxsplit=1)
|
if "\n" in response:
|
||||||
|
metadata, content = response.split("\n", maxsplit=1)
|
||||||
|
else:
|
||||||
|
metadata, content = "", response
|
||||||
if not metadata.strip():
|
if not metadata.strip():
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
else:
|
else:
|
||||||
if use_tool:
|
if use_tool:
|
||||||
content = "\n".join(content.split("\n")[1:-1])
|
parameters = eval(content.strip())
|
||||||
parameters = eval(content)
|
|
||||||
content = {
|
content = {
|
||||||
"name": metadata.strip(),
|
"name": metadata.strip(),
|
||||||
"arguments": json.dumps(parameters, ensure_ascii=False)
|
"arguments": json.dumps(parameters, ensure_ascii=False)
|
||||||
|
@@ -257,8 +260,11 @@ def process_messages(messages, tools=None, tool_choice="none"):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif role == "assistant" and func_call is not None:
|
elif role == "assistant" and func_call is not None:
|
||||||
for response in content.split(""):
|
for response in content.split("<|assistant|>"):
|
||||||
metadata, sub_content = response.split("\n", maxsplit=1)
|
if "\n" in response:
|
||||||
|
metadata, sub_content = response.split("\n", maxsplit=1)
|
||||||
|
else:
|
||||||
|
metadata, sub_content = "", response
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": role,
|
"role": role,
|
||||||
|
@@ -537,7 +543,8 @@ if __name__ == "__main__":
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
worker_use_ray=True,
|
worker_use_ray=True,
|
||||||
engine_use_ray=False,
|
engine_use_ray=False,
|
||||||
disable_log_requests=True
|
disable_log_requests=True,
|
||||||
|
max_model_len=MAX_MODEL_LENGTH,
|
||||||
)
|
)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||||
|
|
Loading…
Reference in New Issue