diff --git a/basic_demo/openai_api_server.py b/basic_demo/openai_api_server.py
index 6d8499c..f03b120 100644
--- a/basic_demo/openai_api_server.py
+++ b/basic_demo/openai_api_server.py
@@ -1,5 +1,6 @@
 import os
 import time
+import logging
 from asyncio.log import logger
 
 import uvicorn
@@ -18,7 +19,6 @@ from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
-MAX_MODEL_LENGTH = 8192
 
 
 @asynccontextmanager
@@ -105,7 +105,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
 
 
@@ -212,7 +212,7 @@ async def generate_stream_glm4(params):
     torch.cuda.empty_cache()
 
 
-def process_messages(messages, tools=None, tool_choice="none"):
+def process_messages(messages, tools=None, tool_choice=None):
     _messages = messages
     messages = []
     msg_has_sys = False
@@ -227,7 +227,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
         ]
         return filtered_tools
 
-    if tool_choice != "none":
+    if tool_choice and tool_choice != "none":
         if isinstance(tool_choice, dict):
             tools = filter_tools(tool_choice, tools)
         if tools:
@@ -239,7 +239,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
                 }
             )
             msg_has_sys = True
-    # add to metadata
 
     if isinstance(tool_choice, dict) and tools:
         messages.append(
@@ -297,7 +296,8 @@ async def list_models():
 async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")
-
+    if request.tool_choice is None:
+        request.tool_choice = "auto" if request.tools else "none"
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
@@ -314,9 +314,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.stream:
         predict_stream_generator = predict_stream(request.model, gen_params)
         output = await anext(predict_stream_generator)
-        if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-        logger.debug(f"First result output:\n{output}")
+        if not output or 'get_' not in output:
+            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")
+        logger.debug(f"First result output: \n{output}")
 
         function_call = None
         if output and request.tools:
@@ -334,10 +334,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
             gen_params["messages"].append(ChatMessage(role="assistant", content=output))
             gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
             generate = predict(request.model, gen_params)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
         else:
             generate = parse_output_text(request.model, output)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
 
     response = ""
     async for response in generate_stream_glm4(gen_params):
@@ -454,7 +454,7 @@ async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
-    async for new_response in generate_stream_glm4(gen_params):
+    async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
@@ -531,6 +531,7 @@ async def parse_output_text(model_id: str, value: str):
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
     engine_args = AsyncEngineArgs(
         model=MODEL_PATH,
@@ -542,8 +543,7 @@ if __name__ == "__main__":
         enforce_eager=True,
         worker_use_ray=True,
         engine_use_ray=False,
-        disable_log_requests=True,
-        max_model_len=MAX_MODEL_LENGTH,
+        disable_log_requests=True
     )
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
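
For reference, a minimal client-side sketch (not part of the patch) of the new tool_choice fallback: when tool_choice is omitted, the server now uses "auto" if tools are supplied and "none" otherwise. It assumes the server above is running on localhost:8000 and uses the openai Python client; the model name and the get_current_weather tool are illustrative placeholders, not defined by this patch.

# Sketch only: exercises the changed tool_choice handling against the server above.
# Assumes localhost:8000; "glm-4" and get_current_weather are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool for illustration
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# tool_choice is deliberately omitted: with this patch the server falls back
# to "auto" because tools are present (and to "none" when no tools are given).
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "What is the weather in Beijing?"}],
    tools=tools,
    stream=False,
)
print(response.choices[0].message)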