glm4/basic_demo/openai_api_server.py

463 lines
16 KiB
Python
Raw Normal View History

2024-06-05 10:22:16 +08:00
import time
from asyncio.log import logger
import uvicorn
import gc
import json
2024-06-12 08:55:31 +08:00
import random
import string
import logging
2024-06-05 10:22:16 +08:00
2024-06-12 08:55:31 +08:00
import torch
2024-06-05 10:22:16 +08:00
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse
# Override the class-level default keep-alive ping interval used by every
# EventSourceResponse created in this process.
# NOTE(review): the unit is presumably seconds - confirm against sse-starlette.
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
2024-06-07 22:14:00 +08:00
2024-06-06 10:00:11 +08:00
# Model identifier handed to both AutoTokenizer.from_pretrained and the vLLM
# engine arguments in the __main__ block.
MODEL_PATH = 'THUDM/glm-4-9b-chat'

# Context length passed to the vLLM engine as max_model_len.
# The original note here says the maximum model length is 128k; this demo
# configures 8192.
MAX_MODEL_LENGTH = 8192
2024-06-05 10:22:16 +08:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook for the FastAPI app.

    Nothing is done before the yield; after it, cached CUDA memory is
    released when a CUDA device is available.
    """
    yield
    # Code below the yield runs once the application is shutting down.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
# ASGI application; `lifespan` handles the CUDA cache cleanup after the yield.
app = FastAPI(lifespan=lifespan)

# CORS is fully open: every origin, method and header is accepted and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
2024-06-12 08:55:31 +08:00
def generate_id(prefix: str) -> str:
    """Return ``"<prefix>-<token>"`` where token is 24 random alphanumerics.

    The token is drawn with ``random.choices`` from ASCII letters and digits.
    """
    alphabet = string.ascii_letters + string.digits
    token = ''.join(random.choices(alphabet, k=24))
    return f"{prefix}-{token}"
2024-06-05 10:22:16 +08:00
# One entry of the /v1/models listing.
class ModelCard(BaseModel):
    # Model identifier reported to clients (list_models uses "glm-4").
    id: str
    object: str = "model"
    # Unix timestamp (whole seconds) taken when the card is instantiated.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    # The remaining fields default to None and are not populated in this file.
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None
# Response body of GET /v1/models: a list wrapper around ModelCard entries.
class ModelList(BaseModel):
    object: str = "list"
    # The previous default was the literal list ["glm-4"]: a shared mutable
    # default whose element is a str rather than a ModelCard. Build a fresh,
    # correctly typed list for each instance instead.
    data: List[ModelCard] = Field(default_factory=lambda: [ModelCard(id="glm-4")])
# A function invocation: the function name plus its arguments as a string.
class FunctionCall(BaseModel):
    name: str
    # ChatCompletionResponse fills this with a JSON-encoded string
    # (json.dumps of the parsed parameters).
    arguments: str
2024-06-05 10:22:16 +08:00
# Variant of FunctionCall in which both fields are optional.
# NOTE(review): not referenced anywhere else in the visible part of this file.
class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None
2024-06-07 22:14:00 +08:00
# Token accounting attached to a non-streaming completion response.
class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    # generate_stream_glm4 reports this as prompt_tokens + completion_tokens.
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
# A single tool call carried by an assistant message or a streamed delta.
class ChatCompletionMessageToolCall(BaseModel):
    # Auto-generated as "call-<24 random alphanumerics>" when not supplied.
    id: Optional[str] = Field(default_factory=lambda: generate_id('call'))
    function: FunctionCall
    # "function" is the only accepted tool type.
    type: Optional[Literal["function"]] = 'function'
2024-06-07 22:14:00 +08:00
2024-06-05 10:22:16 +08:00
# One chat message, used both for request history and for the non-streaming
# response message.
class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function", "tool"]
    # May be None, e.g. ChatCompletionResponse.reply sets it to None when the
    # reply consists of tool calls.
    content: Optional[str] = None
    # Accepted on input; process_messages reads it but does not use it.
    function_call: Optional[FunctionCall] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
2024-06-05 10:22:16 +08:00
# Incremental message payload carried by a streaming chunk; every field is
# optional.
class DeltaMessage(BaseModel):
    # Unlike ChatMessage.role, "tool" is not an accepted value here.
    role: Optional[Literal["user", "assistant", "function", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCall] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
2024-06-05 10:22:16 +08:00
# One choice of a non-streaming ("chat.completion") response.
class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    # Required here; "tool_calls" is set by ChatCompletionResponse.reply when
    # the output was parsed into tool calls.
    finish_reason: Literal["stop", "length", "tool_calls"]
2024-06-05 10:22:16 +08:00
# One choice of a streaming ("chat.completion.chunk") response.
class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    # None is an accepted value, but the field has no default, so it must
    # always be passed explicitly.
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
    index: int
# Top-level response body for both full completions ("chat.completion") and
# streamed chunks ("chat.completion.chunk"), plus the static helpers that
# build either form from raw model text.
class ChatCompletionResponse(BaseModel):
    model: str
    # Auto-generated as "chatcmpl-<24 random alphanumerics>" when not supplied.
    id: str = Field(default_factory=lambda: generate_id('chatcmpl'))
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None

    @staticmethod
    def _parse_tool_arguments(raw: str):
        """Parse the argument payload of a tool call without executing code.

        The text comes straight from the model, so it must never be passed
        to ``eval``. JSON is tried first; ``ast.literal_eval`` is the fallback
        for Python-literal syntax (single quotes, ``True``/``None``).

        Raises:
            ValueError, SyntaxError, TypeError, RecursionError: when the text
                is neither valid JSON nor a Python literal.
        """
        import ast  # function-scope import: only this helper needs it

        try:
            return json.loads(raw)
        except ValueError:
            return ast.literal_eval(raw)

    @staticmethod
    def _convert_to_tool_calls_from_content(content: str) -> Union[List[ChatCompletionMessageToolCall], str]:
        """Turn raw model output into tool calls, or hand the text back.

        The output is split on ``<|assistant|>``. For each segment the first
        line is taken as the function name (metadata) and the remainder as
        its arguments. Segments with an empty first line are skipped.

        Returns:
            A non-empty list of tool calls, or the stripped input text when
            no tool call is present or any argument payload cannot be parsed
            or JSON-serialised (callers then treat the output as plain text).
        """
        text = content.strip()
        tool_calls = []
        for response in text.split("<|assistant|>"):
            if "\n" in response:
                metadata, arguments = response.split("\n", maxsplit=1)
            else:
                metadata, arguments = "", response
            name = metadata.strip()
            if not name:
                continue
            try:
                parameters = ChatCompletionResponse._parse_tool_arguments(arguments.strip())
                serialized = json.dumps(parameters, ensure_ascii=False)
            except (ValueError, SyntaxError, TypeError, RecursionError):
                # Malformed arguments: degrade to plain text instead of
                # failing the whole request.
                return text
            tool_calls.append(
                ChatCompletionMessageToolCall(function=FunctionCall(name=name, arguments=serialized))
            )
        return tool_calls if tool_calls else text

    @staticmethod
    def _split_content_and_tool_calls(content: str, finish_reason: str, use_tool: bool):
        """Shared preprocessing for ``reply`` and ``stream_reply``.

        Drops one leading newline from ``content`` and, when ``use_tool`` is
        set and the text parses into tool calls, replaces the content with
        ``None`` and the finish reason with ``"tool_calls"``.

        Returns:
            A ``(content, tool_calls, finish_reason)`` tuple.
        """
        if content.startswith("\n"):
            content = content[1:]
        tool_calls = None
        if use_tool:
            parsed = ChatCompletionResponse._convert_to_tool_calls_from_content(content)
            if isinstance(parsed, list):
                tool_calls = parsed
                finish_reason = "tool_calls"
                content = None
        return content, tool_calls, finish_reason

    @staticmethod
    def stream_reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False) -> str:
        """Build one streaming chunk and return it as a JSON string.

        Fields that are ``None`` are omitted from the JSON.

        NOTE(review): the leading-newline strip is applied to every chunk
        passed in, so a mid-stream delta that begins with a newline loses it;
        callers currently rely on the strip for the first chunk.
        """
        content, tool_calls, finish_reason = ChatCompletionResponse._split_content_and_tool_calls(
            content, finish_reason, use_tool
        )
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role="assistant", content=content, tool_calls=tool_calls),
            finish_reason=finish_reason
        )
        return ChatCompletionResponse(
            model=model_id,
            choices=[choice_data],
            created=int(time.time()),
            object="chat.completion.chunk"
        ).model_dump_json(exclude_none=True)

    @staticmethod
    def reply(model_id: str, content: str, finish_reason: str, use_tool: bool = False,
              usage: Optional[UsageInfo] = None) -> 'ChatCompletionResponse':
        """Build the complete (non-streaming) response object.

        Args:
            model_id: Value echoed back in the ``model`` field.
            content: Raw generated text.
            finish_reason: Finish reason reported by the generator; replaced
                with ``"tool_calls"`` when tool calls are extracted.
            use_tool: Whether to try to parse ``content`` as tool calls.
            usage: Optional token accounting to attach.
        """
        content, tool_calls, finish_reason = ChatCompletionResponse._split_content_and_tool_calls(
            content, finish_reason, use_tool
        )
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=content, tool_calls=tool_calls),
            finish_reason=finish_reason
        )
        return ChatCompletionResponse(
            model=model_id,
            choices=[choice_data],
            created=int(time.time()),
            object="chat.completion",
            usage=usage
        )
2024-06-05 10:22:16 +08:00
2024-06-07 22:14:00 +08:00
# Request body of POST /v1/chat/completions.
class ChatCompletionRequest(BaseModel):
    # Echoed back in the response; model selection is not based on it in this
    # file (the engine always loads MODEL_PATH).
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    # When unset or 0, create_chat_completion falls back to 1024.
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    # When None, create_chat_completion resolves it to "auto" if tools are
    # given and to "none" otherwise.
    tool_choice: Optional[Union[str, dict]] = None
    repetition_penalty: Optional[float] = 1.1
2024-06-05 10:22:16 +08:00
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that repairs score tensors containing NaN or Inf.

    When any NaN or infinite value is present, the tensor is zeroed in place
    and index 5 along the last dimension is set to 5e4. Otherwise the scores
    pass through untouched.
    """

    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        scores_are_valid = not (torch.isnan(scores).any() or torch.isinf(scores).any())
        if scores_are_valid:
            return scores
        # In-place reset, then force a single very large logit at index 5.
        scores.zero_()
        scores[..., 5] = 5e4
        return scores
# NOTE(review): this decorator wraps an async generator function, so the
# inference-mode context may only cover creation of the generator object and
# not the iteration itself - confirm against the torch documentation.
@torch.inference_mode()
async def generate_stream_glm4(params):
    """Stream generation results from the global vLLM engine.

    Args:
        params: Dict with required keys ``messages``, ``tools`` and
            ``tool_choice`` and optional keys ``temperature``,
            ``repetition_penalty``, ``top_p`` (each defaulting to 1.0) and
            ``max_tokens`` (defaulting to 8192).

    Yields:
        Dicts with ``text`` (cumulative output text of the first sequence),
        ``usage`` (prompt / completion / total token counts) and
        ``finish_reason``, one per engine output.

    Relies on the module-level ``tokenizer`` and ``engine`` created in the
    ``__main__`` block.
    """
    import uuid  # function-scope import: only needed for the request id

    messages = params["messages"]
    tools = params["tools"]
    tool_choice = params["tool_choice"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 8192))

    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    params_dict = {
        "n": 1,
        "best_of": 1,
        # Fixed value; not configurable through the request.
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        # NOTE(review): presumably GLM-4 end-of-turn special token ids -
        # confirm against the tokenizer.
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)

    # A wall-clock timestamp is not guaranteed to be unique when several
    # requests arrive concurrently; a UUID avoids request id collisions.
    request_id = uuid.uuid4().hex
    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=request_id):
        output_len = len(output.outputs[0].token_ids)
        input_len = len(output.prompt_token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret

    # Runs once the engine stream is exhausted.
    gc.collect()
    torch.cuda.empty_cache()
2024-06-12 08:55:31 +08:00
def process_messages(messages, tools=None, tool_choice=None):
    """Convert OpenAI-style chat messages into the dicts fed to the chat template.

    Args:
        messages: Iterable of message objects exposing ``role`` and
            ``content`` (and optionally ``tool_calls``).
        tools: Optional list of tool specs (dicts with a ``function`` entry).
        tool_choice: ``None``, ``"none"``, another string such as ``"auto"``,
            or a dict naming a single function to restrict ``tools`` to.

    Returns:
        A new list of dicts:

        * When tools are active (``tool_choice`` is set and not ``"none"``,
          and ``tools`` is non-empty after filtering), the list starts with a
          system entry carrying ``tools`` and the caller's first system
          message is dropped. A dict ``tool_choice`` additionally adds an
          empty assistant entry whose ``metadata`` is the chosen function
          name.
        * ``function`` / ``tool`` messages become ``observation`` entries.
        * Assistant tool calls become one assistant entry per call; plain
          assistant text becomes one assistant entry per line.
        * When there are no tools or ``tool_choice == "none"``, the first
          system message is placed at index 0 exactly once.
    """
    processed_messages = []
    # True while the tools system entry has been emitted and the caller's
    # first system message still has to be skipped.
    msg_has_sys = False

    def filter_tools(tool_choice, tools):
        """Keep only the tools whose function name matches ``tool_choice``."""
        function_name = tool_choice.get('function', {}).get('name', None)
        if not function_name:
            return []
        # ``tools`` may legitimately be None when the client forces a tool
        # without listing any.
        return [
            tool for tool in (tools or [])
            if tool.get('function', {}).get('name') == function_name
        ]

    if tool_choice and tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
            processed_messages.append(
                {
                    "role": "system",
                    "content": None,
                    "tools": tools
                }
            )
            msg_has_sys = True
            if isinstance(tool_choice, dict):
                # Non-empty filtered tools guarantee that the function name
                # exists, so this lookup cannot raise KeyError.
                # NOTE(review): this entry precedes the conversation
                # messages, as in the original code - confirm that is the
                # intended position.
                processed_messages.append(
                    {
                        "role": "assistant",
                        "metadata": tool_choice["function"]["name"],
                        "content": ""
                    }
                )

    # Without active tools the first system message belongs at the front.
    # It is held back here and inserted once at the end; previously it was
    # appended in the loop AND inserted again, producing a duplicate.
    hoist_system = not tools or tool_choice == "none"
    hoisted_system = None

    for m in messages:
        role, content = m.role, m.content
        tool_calls = getattr(m, 'tool_calls', None)

        if role == "function":
            processed_messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )
        elif role == "tool":
            processed_messages.append(
                {
                    "role": "observation",
                    "content": content,
                    "function_call": True
                }
            )
        elif role == "assistant":
            if tool_calls:
                for tool_call in tool_calls:
                    processed_messages.append(
                        {
                            "role": "assistant",
                            "metadata": tool_call.function.name,
                            "content": tool_call.function.arguments
                        }
                    )
            else:
                # ``content`` may be None (e.g. an assistant turn echoed back
                # without text); treat it as empty instead of crashing.
                # NOTE(review): every line becomes its own assistant entry,
                # which is the existing behaviour - confirm it is intended.
                for line in (content or "").split("\n"):
                    processed_messages.append(
                        {
                            "role": role,
                            "metadata": "",
                            "content": line.strip()
                        }
                    )
        elif role == "system" and msg_has_sys:
            # Replaced by the tools system entry emitted above.
            msg_has_sys = False
        elif role == "system" and hoist_system and hoisted_system is None:
            hoisted_system = {"role": role, "content": content}
        else:
            processed_messages.append({"role": role, "content": content})

    if hoisted_system is not None:
        processed_messages.insert(0, hoisted_system)
    return processed_messages
2024-06-05 10:22:16 +08:00
# Liveness probe: always answers with an empty 200 response and does not
# touch the engine.
@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)
# Lists the models served by this process: a single card with id "glm-4".
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    return ModelList(data=[ModelCard(id="glm-4")])
# Chat completion endpoint.
#   * Rejects requests with no messages or whose last message is from the
#     assistant (HTTP 400).
#   * Resolves a missing tool_choice to "auto" when tools are given, else
#     "none".
#   * With stream=True returns a server-sent-event stream driven by
#     predict_stream; otherwise drains the generator and returns the final
#     result as a single ChatCompletionResponse.
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if not request.messages or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    if request.tool_choice is None:
        request.tool_choice = "auto" if request.tools else "none"

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
    logger.debug(f"==== request ====\n{request.model_dump_json()}")

    if request.stream:
        predict_stream_generator = predict_stream(request.model, gen_params)
        return EventSourceResponse(predict_stream_generator, media_type="text/event-stream", sep="\n")

    # Only the last yielded item matters: it holds the cumulative text.
    response = None
    async for response in generate_stream_glm4(gen_params):
        pass
    if response is None:
        # The generator yielded nothing; previously this fell through and
        # failed with a TypeError when indexing the empty-string placeholder.
        raise HTTPException(status_code=500, detail="Model produced no output")

    is_tool_call = is_return_tool_call(response["text"], request.tools)
    usage = UsageInfo.model_validate(response["usage"])
    return ChatCompletionResponse.reply(request.model, response["text"], response["finish_reason"], is_tool_call, usage)
2024-06-05 10:22:16 +08:00
2024-06-12 08:55:31 +08:00
def calc_max_tool_name_len(tools: Optional[List[dict]]) -> int:
    """Return the length of the longest tool function name.

    Tools without a ``function`` entry, or whose ``function`` has no
    ``name``, are ignored.

    Returns:
        The maximum name length, or 0 when ``tools`` is empty/None or no
        tool carries a function name (the latter previously raised
        ``ValueError`` from ``max()`` on an empty sequence).
    """
    if not tools:
        return 0
    return max(
        (
            len(tool['function']['name'])
            for tool in tools
            if 'function' in tool and 'name' in tool['function']
        ),
        default=0,
    )
def is_return_tool_call(output: str, tools: Optional[List[dict]]) -> bool:
    """Tell whether ``output`` looks like a tool call.

    The whitespace-stripped output is considered a tool call when it starts
    with the function name of any tool in ``tools``. Tools lacking a
    ``function`` entry or a ``name`` inside it are ignored.
    """
    if not tools:
        return False

    known_names = []
    for tool in tools:
        if 'function' in tool and 'name' in tool['function']:
            known_names.append(tool['function']['name'])

    candidate = output.strip()
    for name in known_names:
        if candidate.startswith(name):
            return True
    return False
2024-06-05 10:22:16 +08:00
2024-06-05 10:22:16 +08:00
async def predict_stream(model_id, gen_params):
    """Yield streaming chunks (JSON strings) followed by ``'[DONE]'``.

    Output is buffered until it is longer than the longest tool name, so a
    tool call can be recognised before any text is sent. Plain text is then
    streamed as deltas; a detected tool call is sent as one final chunk once
    generation ends.
    """
    tools = gen_params.get("tools")
    name_budget = calc_max_tool_name_len(tools)

    full_text = ""
    tool_call_detected = False
    first_chunk_sent = False
    finish_reason = "stop"

    async for step in generate_stream_glm4(gen_params):
        latest = step["text"]
        increment = latest[len(full_text):]
        full_text = latest

        # Keep buffering while the text could still be just a tool name
        # (one extra char is read because the first char may be "\n").
        if len(full_text) <= name_budget:
            continue
        # Once a tool call is recognised nothing is streamed until the end.
        if tool_call_detected:
            continue
        tool_call_detected = is_return_tool_call(full_text, tools)
        if tool_call_detected:
            continue

        finish_reason = step["finish_reason"]
        if first_chunk_sent:
            payload = increment
        elif full_text.startswith("\n"):
            # First chunk: flush the whole buffer minus a leading newline.
            payload = full_text[1:]
        else:
            payload = full_text
        first_chunk_sent = True
        yield ChatCompletionResponse.stream_reply(model_id, payload, finish_reason)

    # Either a tool call was held back, or the total output never exceeded
    # the buffering threshold and nothing has been sent yet.
    if tool_call_detected or not first_chunk_sent:
        yield ChatCompletionResponse.stream_reply(model_id, full_text, finish_reason, tool_call_detected)

    yield '[DONE]'
2024-06-12 08:55:31 +08:00
2024-06-05 10:22:16 +08:00
# Script entry point: load the tokenizer, build the async vLLM engine and
# start the HTTP server. `tokenizer` and `engine` become module-level names
# that generate_stream_glm4 reads.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        tensor_parallel_size=1,
        dtype="bfloat16",
        trust_remote_code=True,
        # Fraction of GPU memory to use. Choose a value that fits your GPU:
        # for example, on an 80G card where you only want to use 24G,
        # set it to 24/80 = 0.3.
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        worker_use_ray=False,
        engine_use_ray=False,
        disable_log_requests=True,
        max_model_len=MAX_MODEL_LENGTH,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    # Listens on all interfaces, port 8000, with a single worker process.
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)