fix: #issue65
parent 1683d673d2
commit 20ef30a46b
@@ -1,5 +1,6 @@
 import os
 import time
+import logging
 from asyncio.log import logger
 
 import uvicorn
@@ -18,7 +19,6 @@ from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
 MAX_MODEL_LENGTH = 8192
 
-
 @asynccontextmanager
@@ -105,7 +105,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
 
 
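The old default is the literal string "None", which is truthy and never equal to the lowercase "none" the rest of the code compares against, so an omitted tool_choice was silently treated as a real tool choice. A minimal sketch of the difference (the Demo model and print calls are illustrative; only the field type and the two defaults come from the hunk above):

```python
from typing import Optional, Union
from pydantic import BaseModel

class Demo(BaseModel):
    # old default: the literal string "None" -- truthy and != "none"
    old_tool_choice: Optional[Union[str, dict]] = "None"
    # new default: a real None -- easy to detect and replace downstream
    new_tool_choice: Optional[Union[str, dict]] = None

d = Demo()
print(bool(d.old_tool_choice), d.old_tool_choice != "none")  # True True -> looks like a tool choice
print(d.new_tool_choice is None)                             # True -> caller can fall back to "auto"/"none"
```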
@@ -212,7 +212,7 @@ async def generate_stream_glm4(params):
     torch.cuda.empty_cache()
 
 
-def process_messages(messages, tools=None, tool_choice="none"):
+def process_messages(messages, tools=None, tool_choice=None):
     _messages = messages
     messages = []
     msg_has_sys = False
@@ -227,7 +227,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
         ]
         return filtered_tools
 
-    if tool_choice != "none":
+    if tool_choice and tool_choice != "none":
         if isinstance(tool_choice, dict):
             tools = filter_tools(tool_choice, tools)
         if tools:
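With the default now None, the guard has to tolerate a missing value as well as the explicit "none" string. Restated in isolation (the function name is illustrative; the truth test is the one from the hunk above):

```python
def tools_requested(tool_choice) -> bool:
    # same truth test as the new guard above
    return bool(tool_choice) and tool_choice != "none"

print(tools_requested(None))                                   # False -- omitted choice, skip tool handling
print(tools_requested("none"))                                 # False -- explicitly disabled
print(tools_requested("auto"))                                 # True
print(tools_requested({"function": {"name": "get_weather"}}))  # True  -- a specific tool was named
```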
@@ -239,7 +239,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
                 }
             )
             msg_has_sys = True
-
     # add to metadata
     if isinstance(tool_choice, dict) and tools:
         messages.append(
@@ -297,7 +296,8 @@ async def list_models():
 async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")
-
+    if request.tool_choice is None:
+        request.tool_choice = "auto" if request.tools else "none"
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
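The new branch resolves an omitted tool_choice before gen_params is built: "auto" when tools were supplied, "none" otherwise, while an explicit client value is left untouched. A stand-alone restatement (names are illustrative):

```python
def resolve_tool_choice(tool_choice, tools):
    # mirrors the two added lines: only a missing (None) value is rewritten
    if tool_choice is None:
        return "auto" if tools else "none"
    return tool_choice

print(resolve_tool_choice(None, [{"type": "function"}]))    # auto
print(resolve_tool_choice(None, None))                      # none
print(resolve_tool_choice("none", [{"type": "function"}]))  # none -- explicit choice wins
```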
@@ -314,9 +314,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.stream:
         predict_stream_generator = predict_stream(request.model, gen_params)
         output = await anext(predict_stream_generator)
-        if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-        logger.debug(f"First result output:\n{output}")
+        if not output or 'get_' not in output:
+            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")
+        logger.debug(f"First result output: \n{output}")
 
         function_call = None
         if output and request.tools:
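The replacement condition is a heuristic: the first chunk pulled from predict_stream is inspected, and unless it contains 'get_' (the naming convention used for the demo's tool functions), the stream is handed straight back to the client instead of being parsed as a tool call. A sketch of the check on its own (sample strings are illustrative):

```python
def looks_like_tool_call(first_chunk: str) -> bool:
    # heuristic from the hunk above: demo tools are named get_*,
    # so a first chunk without "get_" is treated as plain chat text
    return bool(first_chunk) and "get_" in first_chunk

print(looks_like_tool_call('get_weather\n{"city": "Beijing"}'))  # True  -> try to parse a function call
print(looks_like_tool_call("Hello! How can I help you?"))        # False -> stream the text back as-is
```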
@@ -334,10 +334,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 gen_params["messages"].append(ChatMessage(role="assistant", content=output))
                 gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
                 generate = predict(request.model, gen_params)
-                return EventSourceResponse(generate, media_type="text/event-stream")
+                return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
             else:
                 generate = parse_output_text(request.model, output)
-                return EventSourceResponse(generate, media_type="text/event-stream")
+                return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
 
         response = ""
         async for response in generate_stream_glm4(gen_params):
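Both streaming returns now pass sep="\n" to EventSourceResponse, overriding whatever separator sse_starlette would otherwise use when framing events ("\r\n", if that recollection of the library default is right). Purely as an illustration of what the argument changes on the wire; the frame helper below is not part of sse_starlette:

```python
def frame(data: str, sep: str) -> str:
    # one SSE event: "data: ..." terminated by a blank line built from sep
    return f"data: {data}{sep}{sep}"

print(repr(frame('{"delta": "Hi"}', "\r\n")))  # assumed library-default framing
print(repr(frame('{"delta": "Hi"}', "\n")))    # framing requested by sep="\n" above
```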
@@ -531,6 +531,7 @@ async def parse_output_text(model_id: str, value: str):
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
     engine_args = AsyncEngineArgs(
         model=MODEL_PATH,
@@ -542,8 +543,7 @@ if __name__ == "__main__":
         enforce_eager=True,
         worker_use_ray=True,
-        engine_use_ray=False,
-        disable_log_requests=True,
         max_model_len=MAX_MODEL_LENGTH,
+        disable_log_requests=True
     )
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
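The engine_use_ray keyword is dropped here, presumably because recent vLLM releases no longer define it on AsyncEngineArgs, and passing an unknown keyword to the dataclass fails before the server starts. A quick way to check against the locally installed vLLM (assumes vllm is importable; nothing else in this snippet comes from the diff):

```python
import dataclasses
from vllm.engine.arg_utils import AsyncEngineArgs

field_names = {f.name for f in dataclasses.fields(AsyncEngineArgs)}
# On versions where this prints False, keeping engine_use_ray=False would
# make the AsyncEngineArgs(...) call above raise a TypeError.
print("engine_use_ray" in field_names)
```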