fix: #issue65
parent 1683d673d2
commit 20ef30a46b
@@ -1,5 +1,6 @@
 import os
 import time
+import logging
 from asyncio.log import logger
 
 import uvicorn
@@ -18,7 +19,6 @@ from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
-MAX_MODEL_LENGTH = 8192
 
 
 @asynccontextmanager
@@ -105,7 +105,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
     tools: Optional[Union[dict, List[dict]]] = None
-    tool_choice: Optional[Union[str, dict]] = "None"
+    tool_choice: Optional[Union[str, dict]] = None
     repetition_penalty: Optional[float] = 1.1
 
 
@@ -212,7 +212,7 @@ async def generate_stream_glm4(params):
     torch.cuda.empty_cache()
 
 
-def process_messages(messages, tools=None, tool_choice="none"):
+def process_messages(messages, tools=None, tool_choice=None):
     _messages = messages
     messages = []
     msg_has_sys = False
@@ -227,7 +227,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
         ]
         return filtered_tools
 
-    if tool_choice != "none":
+    if tool_choice and tool_choice != "none":
         if isinstance(tool_choice, dict):
             tools = filter_tools(tool_choice, tools)
         if tools:
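Note on the guard above: with the request default changed from the string "None" to a real None (see the ChatCompletionRequest and process_messages hunks), the old comparison tool_choice != "none" was also True for None, so requests that never asked for tools still entered the tool branch. A minimal sketch of the new semantics, using a hypothetical helper name that is not part of the file:

# Minimal sketch of the guard; `wants_tools` is a hypothetical helper,
# not a function in the server file.
from typing import Optional, Union


def wants_tools(tool_choice: Optional[Union[str, dict]]) -> bool:
    # Old check: tool_choice != "none"  -> True even for None.
    # New check: require a truthy value that is not the literal "none".
    return bool(tool_choice) and tool_choice != "none"


assert wants_tools(None) is False                        # field omitted
assert wants_tools("none") is False                      # tools explicitly disabled
assert wants_tools("auto") is True                       # model may pick a tool
assert wants_tools({"function": {"name": "x"}}) is True  # forced tool call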
@@ -239,7 +239,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
                 }
             )
             msg_has_sys = True
-
     # add to metadata
     if isinstance(tool_choice, dict) and tools:
         messages.append(
@@ -297,7 +296,8 @@ async def list_models():
 async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")
-
+    if request.tool_choice is None:
+        request.tool_choice = "auto" if request.tools else "none"
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
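The two added lines normalize an omitted tool_choice before gen_params is built, so downstream code only ever sees "auto", "none", or an explicit value. A rough, self-contained sketch of that step under the field definitions shown earlier (the trimmed model and helper below are illustrative stand-ins, not the file's own classes):

# Rough sketch of the normalization added above; RequestSketch and
# normalize_tool_choice are illustrative, not classes from the file.
from typing import List, Optional, Union

from pydantic import BaseModel


class RequestSketch(BaseModel):
    messages: List[dict]
    tools: Optional[Union[dict, List[dict]]] = None
    tool_choice: Optional[Union[str, dict]] = None  # default is now a real None


def normalize_tool_choice(request: RequestSketch) -> None:
    # Mirrors the added lines: a missing tool_choice becomes "auto" when tools
    # were supplied, otherwise "none".
    if request.tool_choice is None:
        request.tool_choice = "auto" if request.tools else "none"


req = RequestSketch(messages=[{"role": "user", "content": "hi"}])
normalize_tool_choice(req)
assert req.tool_choice == "none"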
@@ -314,9 +314,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.stream:
         predict_stream_generator = predict_stream(request.model, gen_params)
         output = await anext(predict_stream_generator)
-        if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-        logger.debug(f"First result output:\n{output}")
+        if not output or 'get_' not in output:
+            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")
+        logger.debug(f"First result output: \n{output}")
 
         function_call = None
         if output and request.tools:
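The reworked condition returns the SSE stream straight to the client unless the first chunk is non-empty and contains 'get_', which this demo appears to use as a cheap signal that the model is starting a tool call; only in that case does the handler fall through to the function-call branch below. A toy sketch of the gate (the helper name is made up for illustration):

# Toy sketch of the first-chunk gate; `should_stream_directly` is a made-up
# helper, and the 'get_' substring check simply mirrors the condition above.
def should_stream_directly(first_chunk: str) -> bool:
    # Mirrors: if not output or 'get_' not in output:
    return not first_chunk or 'get_' not in first_chunk


assert should_stream_directly("") is True                     # nothing yet -> just stream
assert should_stream_directly("Hello!") is True               # ordinary text -> stream
assert should_stream_directly("get_weather\n{...}") is False  # looks like a tool call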
@@ -334,10 +334,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
             gen_params["messages"].append(ChatMessage(role="assistant", content=output))
             gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
             generate = predict(request.model, gen_params)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
         else:
             generate = parse_output_text(request.model, output)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
 
     response = ""
     async for response in generate_stream_glm4(gen_params):
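sep controls the separator sse-starlette writes after each SSE field line; the library's default is "\r\n", and switching to "\n" accommodates clients that split events strictly on newlines. A minimal, self-contained sketch of the same call shape (the route and generator are illustrative, not taken from the file):

# Minimal sketch of streaming with sse-starlette and sep="\n"; the route and
# generator below are illustrative, not copied from the server file.
import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


async def demo_chunks():
    # Each yielded string becomes the data payload of one SSE event.
    for token in ["Hello", " ", "world", "[DONE]"]:
        yield token
        await asyncio.sleep(0.05)


@app.get("/stream")
async def stream():
    # sep="\n" ends each field line with "\n" instead of the default "\r\n".
    return EventSourceResponse(demo_chunks(), media_type="text/event-stream", sep="\n")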
@@ -454,7 +454,7 @@ async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
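For reference, the context lines above show the incremental-delta pattern predict_stream is built on: generate_stream_glm4 yields the full text generated so far, and the handler slices off only the new suffix. A small self-contained sketch of just that pattern (the fake stream stands in for the real engine):

# Small sketch of the delta-extraction pattern above: the generator yields
# cumulative text, and the consumer keeps only the newly generated suffix.
import asyncio


async def fake_cumulative_stream():
    # Stands in for generate_stream_glm4, which yields the full text so far.
    for text in ["He", "Hell", "Hello", "Hello world"]:
        yield {"text": text}
        await asyncio.sleep(0)


async def main():
    output = ""
    async for new_response in fake_cumulative_stream():
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(output):]  # only the new part
        output = decoded_unicode
        print(repr(delta_text))  # 'He', 'll', 'o', ' world'


asyncio.run(main())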
@@ -531,6 +531,7 @@ async def parse_output_text(model_id: str, value: str):
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
     engine_args = AsyncEngineArgs(
         model=MODEL_PATH,
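The added logging.basicConfig(level=logging.DEBUG) call (together with the new import logging at the top) is what makes the logger.debug calls, such as the "First result output" message, actually show up; without a configured root handler Python only emits WARNING and above. A tiny sketch under that assumption:

# Tiny sketch: debug messages from the server's logger only appear once the
# root logger is configured, which is what the added basicConfig call does.
import logging
from asyncio.log import logger  # the same logger object the server imports

logging.basicConfig(level=logging.DEBUG)
logger.debug("First result output: \n%s", "example chunk")  # now visible on stderr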
@@ -542,8 +543,7 @@ if __name__ == "__main__":
         enforce_eager=True,
         worker_use_ray=True,
         engine_use_ray=False,
-        disable_log_requests=True,
-        max_model_len=MAX_MODEL_LENGTH,
+        disable_log_requests=True
     )
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
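Dropping max_model_len=MAX_MODEL_LENGTH (and the MAX_MODEL_LENGTH constant removed near the top of the file) leaves vLLM to take the maximum sequence length from the model's own config instead of capping it at 8192. A hedged sketch of the resulting engine construction, repeating only arguments visible in this diff (the real file passes more, and the ray-related flags assume the vLLM version this demo targets):

# Sketch of the engine setup after the change: with no max_model_len, vLLM
# derives the context window from the model config. Only arguments visible in
# this diff are repeated; the real file passes additional ones.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

MODEL_PATH = 'THUDM/glm-4-9b-chat'

engine_args = AsyncEngineArgs(
    model=MODEL_PATH,
    enforce_eager=True,
    worker_use_ray=True,
    engine_use_ray=False,
    disable_log_requests=True
)
engine = AsyncLLMEngine.from_engine_args(engine_args)  # loads the model; requires a GPU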