fix: #issue65

liuzhenghua-jk 2024-06-06 11:29:23 +08:00
parent 1683d673d2
commit 20ef30a46b
1 changed file with 14 additions and 14 deletions


@@ -1,5 +1,6 @@
import os
import time
+import logging
from asyncio.log import logger
import uvicorn
@@ -18,7 +19,6 @@ from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b-chat'
-MAX_MODEL_LENGTH = 8192
@asynccontextmanager
@@ -105,7 +105,7 @@ class ChatCompletionRequest(BaseModel):
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
    repetition_penalty: Optional[float] = 1.1
@@ -212,7 +212,7 @@ async def generate_stream_glm4(params):
    torch.cuda.empty_cache()
-def process_messages(messages, tools=None, tool_choice="none"):
+def process_messages(messages, tools=None, tool_choice=None):
    _messages = messages
    messages = []
    msg_has_sys = False
@@ -227,7 +227,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
        ]
        return filtered_tools
-    if tool_choice != "none":
+    if tool_choice and tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
@@ -239,7 +239,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
                }
            )
            msg_has_sys = True
    # add to metadata
    if isinstance(tool_choice, dict) and tools:
        messages.append(
@@ -297,7 +296,8 @@ async def list_models():
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")
+    if request.tool_choice is None:
+        request.tool_choice = "auto" if request.tools else "none"
    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
@@ -314,9 +314,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        output = await anext(predict_stream_generator)
-        if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-        logger.debug(f"First result output\n{output}")
+        if not output or 'get_' not in output:
+            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")
+        logger.debug(f"First result output: \n{output}")
        function_call = None
        if output and request.tools:
@@ -334,10 +334,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
            gen_params["messages"].append(ChatMessage(role="assistant", content=output))
            gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
            generate = predict(request.model, gen_params)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
        else:
            generate = parse_output_text(request.model, output)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
    response = ""
    async for response in generate_stream_glm4(gen_params):
@@ -531,6 +531,7 @@ async def parse_output_text(model_id: str, value: str):
if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
@@ -542,8 +543,7 @@ if __name__ == "__main__":
        enforce_eager=True,
        worker_use_ray=True,
        engine_use_ray=False,
-        disable_log_requests=True,
-        max_model_len=MAX_MODEL_LENGTH,
+        disable_log_requests=True
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
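
Below is a minimal client-side sketch (not part of this commit) of the behaviour the patch targets: a request that supplies tools but omits tool_choice now falls back to "auto" on the server instead of the old string default "None". The /v1/chat/completions route, the "glm-4" model id, the get_weather tool, and a server listening on localhost:8000 are assumptions for illustration only.

# Hypothetical client sketch: assumes the patched server above is running via
# uvicorn on port 8000 and exposes an OpenAI-style /v1/chat/completions route.
import requests

payload = {
    "model": "glm-4",
    "messages": [{"role": "user", "content": "What is the weather in Beijing?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    # tool_choice is intentionally omitted: with this fix the server maps the
    # resulting None to "auto" (or to "none" when no tools are supplied).
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
print(resp.json())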