add glm4v openai server
This commit is contained in:
parent
0f6a7c94fe
commit
9f98825a63
|
@ -119,7 +119,8 @@ python vllm_cli_demo.py
|
|||
```
|
||||
|
||||
|
||||
+ 自行构建服务端,并使用 `OpenAI API` 的请求格式与 GLM-4-9B-Chat 模型进行对话。本 demo 支持 Function Call 和 All Tools功能。
|
||||
+ 自行构建服务端,并使用 `OpenAI API` 的请求格式与 GLM-4-9B-Chat GLM-4v-9B 或者模型进行对话。本 demo 支持 Function Call 和 All Tools功能。
|
||||
+ 修改 `open_api_server.py` 中模型路径 `MODEL_PATH`,可选择构建 GLM-4-9B-Chat 或者 GLM-4v-9B 服务端
|
||||
|
||||
启动服务端:
|
||||
|
||||
|
|
|
@ -126,6 +126,7 @@ python vllm_cli_demo.py
|
|||
|
||||
+ Build the server by yourself and use the request format of `OpenAI API` to communicate with the glm-4-9b model. This
|
||||
demo supports Function Call and All Tools functions.
|
||||
+ Modify the `MODEL_PATH` in `open_api_server.py`, and you can choose to build the GLM-4-9B-Chat or GLM-4v-9B server side.
|
||||
|
||||
Start the server:
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 1.0 MiB |
|
@ -0,0 +1,380 @@
|
|||
import gc
|
||||
import threading
|
||||
import time
|
||||
import base64
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Literal, Union, Tuple, Optional
|
||||
import torch
|
||||
import uvicorn
|
||||
import requests
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModel,
|
||||
TextIteratorStreamer
|
||||
)
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
An asynchronous context manager for managing the lifecycle of the FastAPI app.
|
||||
It ensures that GPU memory is cleared after the app's lifecycle ends, which is essential for efficient resource management in GPU environments.
|
||||
"""
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
"""
|
||||
A Pydantic model representing a model card, which provides metadata about a machine learning model.
|
||||
It includes fields like model ID, owner, and creation time.
|
||||
"""
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = None
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = []
|
||||
|
||||
|
||||
class ImageUrl(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class TextContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ImageUrlContent(BaseModel):
|
||||
type: Literal["image_url"]
|
||||
image_url: ImageUrl
|
||||
|
||||
|
||||
ContentItem = Union[TextContent, ImageUrlContent]
|
||||
|
||||
|
||||
class ChatMessageInput(BaseModel):
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: Union[str, List[ContentItem]]
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
role: Literal["assistant"]
|
||||
content: str = None
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessageInput]
|
||||
temperature: Optional[float] = 0.8
|
||||
top_p: Optional[float] = 0.8
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
# Additional parameters
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessageResponse
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
model: str
|
||||
object: Literal["chat.completion", "chat.completion.chunk"]
|
||||
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
"""
|
||||
An endpoint to list available models. It returns a list of model cards.
|
||||
This is useful for clients to query and understand what models are available for use.
|
||||
"""
|
||||
model_card = ModelCard(id="GLM-4v-9b")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
global model, tokenizer
|
||||
|
||||
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
|
||||
gen_params = dict(
|
||||
messages=request.messages,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
echo=False,
|
||||
stream=request.stream,
|
||||
repetition_penalty=request.repetition_penalty
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
generate = predict(request.model, gen_params)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
response = generate_glm4v(model, tokenizer, gen_params)
|
||||
|
||||
usage = UsageInfo()
|
||||
|
||||
message = ChatMessageResponse(
|
||||
role="assistant",
|
||||
content=response["text"],
|
||||
)
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
)
|
||||
task_usage = UsageInfo.model_validate(response["usage"])
|
||||
for usage_key, usage_value in task_usage.model_dump().items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)
|
||||
|
||||
|
||||
def predict(model_id: str, params: dict):
|
||||
global model, tokenizer
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
|
||||
previous_text = ""
|
||||
for new_response in generate_stream_glm4v(model, tokenizer, params):
|
||||
decoded_unicode = new_response["text"]
|
||||
delta_text = decoded_unicode[len(previous_text):]
|
||||
previous_text = decoded_unicode
|
||||
delta = DeltaMessage(content=delta_text, role="assistant")
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=delta)
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage())
|
||||
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
|
||||
|
||||
def generate_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: dict):
|
||||
"""
|
||||
Generates a response using the GLM-4v-9b model. It processes the chat history and image data, if any,
|
||||
and then invokes the model to generate a response.
|
||||
"""
|
||||
|
||||
response = None
|
||||
|
||||
for response in generate_stream_glm4v(model, tokenizer, params):
|
||||
pass
|
||||
return response
|
||||
|
||||
|
||||
def process_history_and_images(messages: List[ChatMessageInput]) -> Tuple[
|
||||
Optional[str], Optional[List[Tuple[str, str]]], Optional[List[Image.Image]]]:
|
||||
"""
|
||||
Process history messages to extract text, identify the last user query,
|
||||
and convert base64 encoded image URLs to PIL images.
|
||||
|
||||
Args:
|
||||
messages(List[ChatMessageInput]): List of ChatMessageInput objects.
|
||||
return: A tuple of three elements:
|
||||
- The last user query as a string.
|
||||
- Text history formatted as a list of tuples for the model.
|
||||
- List of PIL Image objects extracted from the messages.
|
||||
"""
|
||||
|
||||
formatted_history = []
|
||||
image_list = []
|
||||
last_user_query = ''
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
role = message.role
|
||||
content = message.content
|
||||
|
||||
if isinstance(content, list): # text
|
||||
text_content = ' '.join(item.text for item in content if isinstance(item, TextContent))
|
||||
else:
|
||||
text_content = content
|
||||
|
||||
if isinstance(content, list): # image
|
||||
for item in content:
|
||||
if isinstance(item, ImageUrlContent):
|
||||
image_url = item.image_url.url
|
||||
if image_url.startswith("data:image/jpeg;base64,"):
|
||||
base64_encoded_image = image_url.split("data:image/jpeg;base64,")[1]
|
||||
image_data = base64.b64decode(base64_encoded_image)
|
||||
image = Image.open(BytesIO(image_data)).convert('RGB')
|
||||
else:
|
||||
response = requests.get(image_url, verify=False)
|
||||
image = Image.open(BytesIO(response.content)).convert('RGB')
|
||||
image_list.append(image)
|
||||
|
||||
if role == 'user':
|
||||
if i == len(messages) - 1: # 最后一条用户消息
|
||||
last_user_query = text_content
|
||||
else:
|
||||
formatted_history.append((text_content, ''))
|
||||
elif role == 'assistant':
|
||||
if formatted_history:
|
||||
if formatted_history[-1][1] != '':
|
||||
assert False, f"the last query is answered. answer again. {formatted_history[-1][0]}, {formatted_history[-1][1]}, {text_content}"
|
||||
formatted_history[-1] = (formatted_history[-1][0], text_content)
|
||||
else:
|
||||
assert False, f"assistant reply before user"
|
||||
else:
|
||||
assert False, f"unrecognized role: {role}"
|
||||
|
||||
return last_user_query, formatted_history, image_list
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_stream_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: dict):
|
||||
messages = params["messages"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
max_new_tokens = int(params.get("max_tokens", 256))
|
||||
query, history, image_list = process_history_and_images(messages)
|
||||
|
||||
inputs = []
|
||||
for idx, (user_msg, model_msg) in enumerate(history):
|
||||
if idx == len(history) - 1 and not model_msg:
|
||||
inputs.append({"role": "user", "content": user_msg})
|
||||
if image_list and not uploaded:
|
||||
inputs[-1].update({"image": image_list[0]})
|
||||
uploaded = True
|
||||
break
|
||||
if user_msg:
|
||||
inputs.append({"role": "user", "content": user_msg})
|
||||
if model_msg:
|
||||
inputs.append({"role": "assistant", "content": model_msg})
|
||||
inputs.append({"role": "user", "content": query, "image": image_list[0]})
|
||||
|
||||
model_inputs = tokenizer.apply_chat_template(
|
||||
inputs,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True
|
||||
).to(next(model.parameters()).device)
|
||||
input_echo_len = len(model_inputs["input_ids"][0])
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer=tokenizer,
|
||||
timeout=60.0,
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True
|
||||
)
|
||||
gen_kwargs = {
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"do_sample": True if temperature > 1e-5 else False,
|
||||
"top_p": top_p if temperature > 1e-5 else 0,
|
||||
"top_k": 1,
|
||||
'streamer': streamer,
|
||||
}
|
||||
if temperature > 1e-5:
|
||||
gen_kwargs["temperature"] = temperature
|
||||
|
||||
generated_text = ""
|
||||
|
||||
def generate_text():
|
||||
with torch.no_grad():
|
||||
model.generate(**model_inputs, **gen_kwargs)
|
||||
|
||||
generation_thread = threading.Thread(target=generate_text)
|
||||
generation_thread.start()
|
||||
|
||||
total_len = input_echo_len
|
||||
for next_text in streamer:
|
||||
generated_text += next_text
|
||||
total_len = len(tokenizer.encode(generated_text))
|
||||
yield {
|
||||
"text": generated_text,
|
||||
"usage": {
|
||||
"prompt_tokens": input_echo_len,
|
||||
"completion_tokens": total_len - input_echo_len,
|
||||
"total_tokens": total_len,
|
||||
},
|
||||
}
|
||||
generation_thread.join()
|
||||
|
||||
yield {
|
||||
"text": generated_text,
|
||||
"usage": {
|
||||
"prompt_tokens": input_echo_len,
|
||||
"completion_tokens": total_len - input_echo_len,
|
||||
"total_tokens": total_len,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if __name__ == "__main__":
|
||||
MODEL_PATH = sys.argv[1]
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
MODEL_PATH,
|
||||
trust_remote_code=True,
|
||||
encode_special_tokens=True
|
||||
)
|
||||
model = AutoModel.from_pretrained(
|
||||
MODEL_PATH,
|
||||
torch_dtype=TORCH_TYPE,
|
||||
trust_remote_code=True,
|
||||
device_map="auto",
|
||||
).eval().to(DEVICE)
|
||||
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
|
@ -0,0 +1,681 @@
|
|||
import time
|
||||
from asyncio.log import logger
|
||||
import re
|
||||
import sys
|
||||
import uvicorn
|
||||
import gc
|
||||
import json
|
||||
import torch
|
||||
import random
|
||||
import string
|
||||
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
|
||||
from fastapi import FastAPI, HTTPException, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoTokenizer, LogitsProcessor
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
|
||||
|
||||
MAX_MODEL_LENGTH = 8192
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def generate_id(prefix: str, k=29) -> str:
|
||||
suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
|
||||
return f"{prefix}{suffix}"
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str = ""
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = None
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = ["glm-4"]
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ChoiceDeltaToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCall(BaseModel):
|
||||
index: Optional[int] = 0
|
||||
id: Optional[str] = None
|
||||
function: FunctionCall
|
||||
type: Optional[Literal["function"]] = 'function'
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
# “function” 字段解释:
|
||||
# 使用较老的OpenAI API版本需要注意在这里添加 function 字段并在 process_messages函数中添加相应角色转换逻辑为 observation
|
||||
|
||||
role: Literal["user", "assistant", "system", "tool"]
|
||||
content: Optional[str] = None
|
||||
function_call: Optional[ChoiceDeltaToolCallFunction] = None
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
function_call: Optional[ChoiceDeltaToolCallFunction] = None
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
|
||||
index: int
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
model: str
|
||||
id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
|
||||
object: Literal["chat.completion", "chat.completion.chunk"]
|
||||
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = 0.8
|
||||
top_p: Optional[float] = 0.8
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[Union[dict, List[dict]]] = None
|
||||
tool_choice: Optional[Union[str, dict]] = None
|
||||
repetition_penalty: Optional[float] = 1.1
|
||||
|
||||
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 5] = 5e4
|
||||
return scores
|
||||
|
||||
|
||||
def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]:
|
||||
lines = output.strip().split("\n")
|
||||
arguments_json = None
|
||||
special_tools = ["cogview", "simple_browser"]
|
||||
tools = {tool['function']['name'] for tool in tools} if tools else {}
|
||||
|
||||
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
|
||||
##TODO 如果你希望做更多判断,可以在这里进行逻辑完善。
|
||||
|
||||
if len(lines) >= 2 and lines[1].startswith("{"):
|
||||
function_name = lines[0].strip()
|
||||
arguments = "\n".join(lines[1:]).strip()
|
||||
if function_name in tools or function_name in special_tools:
|
||||
try:
|
||||
arguments_json = json.loads(arguments)
|
||||
is_tool_call = True
|
||||
except json.JSONDecodeError:
|
||||
is_tool_call = function_name in special_tools
|
||||
|
||||
if is_tool_call and use_tool:
|
||||
content = {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
|
||||
ensure_ascii=False)
|
||||
}
|
||||
if function_name == "simple_browser":
|
||||
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
|
||||
match = search_pattern.match(arguments)
|
||||
if match:
|
||||
content["arguments"] = json.dumps({
|
||||
"query": match.group(1),
|
||||
"recency_days": int(match.group(2))
|
||||
}, ensure_ascii=False)
|
||||
elif function_name == "cogview":
|
||||
content["arguments"] = json.dumps({
|
||||
"prompt": arguments
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return content
|
||||
return output.strip()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
async def generate_stream_glm4(params):
|
||||
messages = params["messages"]
|
||||
tools = params["tools"]
|
||||
tool_choice = params["tool_choice"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
max_new_tokens = int(params.get("max_tokens", 8192))
|
||||
|
||||
messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
params_dict = {
|
||||
"n": 1,
|
||||
"best_of": 1,
|
||||
"presence_penalty": 1.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": -1,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"use_beam_search": False,
|
||||
"length_penalty": 1,
|
||||
"early_stopping": False,
|
||||
"stop_token_ids": [151329, 151336, 151338],
|
||||
"ignore_eos": False,
|
||||
"max_tokens": max_new_tokens,
|
||||
"logprobs": None,
|
||||
"prompt_logprobs": None,
|
||||
"skip_special_tokens": True,
|
||||
}
|
||||
sampling_params = SamplingParams(**params_dict)
|
||||
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
|
||||
output_len = len(output.outputs[0].token_ids)
|
||||
input_len = len(output.prompt_token_ids)
|
||||
ret = {
|
||||
"text": output.outputs[0].text,
|
||||
"usage": {
|
||||
"prompt_tokens": input_len,
|
||||
"completion_tokens": output_len,
|
||||
"total_tokens": output_len + input_len
|
||||
},
|
||||
"finish_reason": output.outputs[0].finish_reason,
|
||||
}
|
||||
yield ret
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def process_messages(messages, tools=None, tool_choice="none"):
|
||||
_messages = messages
|
||||
processed_messages = []
|
||||
msg_has_sys = False
|
||||
|
||||
def filter_tools(tool_choice, tools):
|
||||
function_name = tool_choice.get('function', {}).get('name', None)
|
||||
if not function_name:
|
||||
return []
|
||||
filtered_tools = [
|
||||
tool for tool in tools
|
||||
if tool.get('function', {}).get('name') == function_name
|
||||
]
|
||||
return filtered_tools
|
||||
|
||||
if tool_choice != "none":
|
||||
if isinstance(tool_choice, dict):
|
||||
tools = filter_tools(tool_choice, tools)
|
||||
if tools:
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": None,
|
||||
"tools": tools
|
||||
}
|
||||
)
|
||||
msg_has_sys = True
|
||||
|
||||
if isinstance(tool_choice, dict) and tools:
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"metadata": tool_choice["function"]["name"],
|
||||
"content": ""
|
||||
}
|
||||
)
|
||||
|
||||
for m in _messages:
|
||||
role, content, func_call = m.role, m.content, m.function_call
|
||||
tool_calls = getattr(m, 'tool_calls', None)
|
||||
|
||||
if role == "function":
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "observation",
|
||||
"content": content
|
||||
}
|
||||
)
|
||||
elif role == "tool":
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "observation",
|
||||
"content": content,
|
||||
"function_call": True
|
||||
}
|
||||
)
|
||||
elif role == "assistant":
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"metadata": tool_call.function.name,
|
||||
"content": tool_call.function.arguments
|
||||
}
|
||||
)
|
||||
else:
|
||||
for response in content.split("\n"):
|
||||
if "\n" in response:
|
||||
metadata, sub_content = response.split("\n", maxsplit=1)
|
||||
else:
|
||||
metadata, sub_content = "", response
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"metadata": metadata,
|
||||
"content": sub_content.strip()
|
||||
}
|
||||
)
|
||||
else:
|
||||
if role == "system" and msg_has_sys:
|
||||
msg_has_sys = False
|
||||
continue
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
if not tools or tool_choice == "none":
|
||||
for m in _messages:
|
||||
if m.role == 'system':
|
||||
processed_messages.insert(0, {"role": m.role, "content": m.content})
|
||||
break
|
||||
return processed_messages
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="glm-4")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
|
||||
gen_params = dict(
|
||||
messages=request.messages,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
echo=False,
|
||||
stream=request.stream,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
logger.debug(f"==== request ====\n{gen_params}")
|
||||
|
||||
if request.stream:
|
||||
predict_stream_generator = predict_stream(request.model, gen_params)
|
||||
output = await anext(predict_stream_generator)
|
||||
if output:
|
||||
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
|
||||
logger.debug(f"First result output:\n{output}")
|
||||
|
||||
function_call = None
|
||||
if output and request.tools:
|
||||
try:
|
||||
function_call = process_response(output, request.tools, use_tool=True)
|
||||
except:
|
||||
logger.warning("Failed to parse tool call")
|
||||
|
||||
if isinstance(function_call, dict):
|
||||
function_call = ChoiceDeltaToolCallFunction(**function_call)
|
||||
generate = parse_output_text(request.model, output, function_call=function_call)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
else:
|
||||
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
|
||||
response = ""
|
||||
async for response in generate_stream_glm4(gen_params):
|
||||
pass
|
||||
|
||||
if response["text"].startswith("\n"):
|
||||
response["text"] = response["text"][1:]
|
||||
response["text"] = response["text"].strip()
|
||||
|
||||
usage = UsageInfo()
|
||||
|
||||
function_call, finish_reason = None, "stop"
|
||||
tool_calls = None
|
||||
if request.tools:
|
||||
try:
|
||||
function_call = process_response(response["text"], request.tools, use_tool=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse tool call: {e}")
|
||||
if isinstance(function_call, dict):
|
||||
finish_reason = "tool_calls"
|
||||
function_call_response = ChoiceDeltaToolCallFunction(**function_call)
|
||||
function_call_instance = FunctionCall(
|
||||
name=function_call_response.name,
|
||||
arguments=function_call_response.arguments
|
||||
)
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(
|
||||
id=generate_id('call_', 24),
|
||||
function=function_call_instance,
|
||||
type="function")]
|
||||
|
||||
message = ChatMessage(
|
||||
role="assistant",
|
||||
content=None if tool_calls else response["text"],
|
||||
function_call=None,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
logger.debug(f"==== message ====\n{message}")
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
task_usage = UsageInfo.model_validate(response["usage"])
|
||||
for usage_key, usage_value in task_usage.model_dump().items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
model=request.model,
|
||||
choices=[choice_data],
|
||||
object="chat.completion",
|
||||
usage=usage
|
||||
)
|
||||
|
||||
|
||||
async def predict_stream(model_id, gen_params):
|
||||
output = ""
|
||||
is_function_call = False
|
||||
has_send_first_chunk = False
|
||||
created_time = int(time.time())
|
||||
function_name = None
|
||||
response_id = generate_id('chatcmpl-', 29)
|
||||
system_fingerprint = generate_id('fp_', 9)
|
||||
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
|
||||
delta_text = ""
|
||||
async for new_response in generate_stream_glm4(gen_params):
|
||||
decoded_unicode = new_response["text"]
|
||||
delta_text += decoded_unicode[len(output):]
|
||||
output = decoded_unicode
|
||||
lines = output.strip().split("\n")
|
||||
|
||||
# 检查是否为工具
|
||||
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
|
||||
##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。
|
||||
|
||||
if not is_function_call and len(lines) >= 2:
|
||||
first_line = lines[0].strip()
|
||||
if first_line in tools:
|
||||
is_function_call = True
|
||||
function_name = first_line
|
||||
delta_text = lines[1]
|
||||
|
||||
# 工具调用返回
|
||||
if is_function_call:
|
||||
if not has_send_first_chunk:
|
||||
function_call = {"name": function_name, "arguments": ""}
|
||||
tool_call = ChatCompletionMessageToolCall(
|
||||
index=0,
|
||||
id=generate_id('call_', 24),
|
||||
function=FunctionCall(**function_call),
|
||||
type="function"
|
||||
)
|
||||
message = DeltaMessage(
|
||||
content=None,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=[tool_call]
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield ""
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
has_send_first_chunk = True
|
||||
|
||||
function_call = {"name": None, "arguments": delta_text}
|
||||
delta_text = ""
|
||||
tool_call = ChatCompletionMessageToolCall(
|
||||
index=0,
|
||||
id=None,
|
||||
function=FunctionCall(**function_call),
|
||||
type="function"
|
||||
)
|
||||
message = DeltaMessage(
|
||||
content=None,
|
||||
role=None,
|
||||
function_call=None,
|
||||
tool_calls=[tool_call]
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
# 用户请求了 Function Call 但是框架还没确定是否为Function Call
|
||||
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
|
||||
continue
|
||||
|
||||
# 常规返回
|
||||
else:
|
||||
finish_reason = new_response.get("finish_reason", None)
|
||||
if not has_send_first_chunk:
|
||||
message = DeltaMessage(
|
||||
content="",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
has_send_first_chunk = True
|
||||
|
||||
message = DeltaMessage(
|
||||
content=delta_text,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
delta_text = ""
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
# 工具调用需要额外返回一个字段以对齐 OpenAI 接口
|
||||
if is_function_call:
|
||||
yield ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
system_fingerprint=system_fingerprint,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(
|
||||
content=None,
|
||||
role=None,
|
||||
function_call=None,
|
||||
),
|
||||
finish_reason="tool_calls"
|
||||
)],
|
||||
created=created_time,
|
||||
object="chat.completion.chunk",
|
||||
usage=None
|
||||
).model_dump_json(exclude_unset=True)
|
||||
elif delta_text != "":
|
||||
message = DeltaMessage(
|
||||
content="",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
finish_reason = 'stop'
|
||||
message = DeltaMessage(
|
||||
content=delta_text,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
delta_text = ""
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
yield '[DONE]'
|
||||
else:
|
||||
yield '[DONE]'
|
||||
|
||||
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
|
||||
delta = DeltaMessage(role="assistant", content=value)
|
||||
if function_call is not None:
|
||||
delta.function_call = function_call
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=delta,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
choices=[choice_data],
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
yield '[DONE]'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MODEL_PATH = sys.argv[1]
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=MODEL_PATH,
|
||||
tokenizer=MODEL_PATH,
|
||||
# 如果你有多张显卡,可以在这里设置成你的显卡数量
|
||||
tensor_parallel_size=1,
|
||||
# dtype="bfloat16",
|
||||
dtype="half",
|
||||
trust_remote_code=True,
|
||||
# 占用显存的比例,请根据你的显卡显存大小设置合适的值,例如,如果你的显卡有80G,您只想使用24G,请按照24/80=0.3设置
|
||||
gpu_memory_utilization=0.9,
|
||||
enforce_eager=True,
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
max_model_len=MAX_MODEL_LENGTH,
|
||||
)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
|
@ -3,6 +3,7 @@ This script creates a OpenAI Request demo for the glm-4-9b model, just Use OpenA
|
|||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
import base64
|
||||
|
||||
base_url = "http://127.0.0.1:8000/v1/"
|
||||
client = OpenAI(api_key="EMPTY", base_url=base_url)
|
||||
|
@ -121,7 +122,86 @@ def simple_chat(use_stream=False):
|
|||
print("Error:", response.status_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# simple_chat(use_stream=False)
|
||||
function_chat(use_stream=False)
|
||||
def create_chat_completion(messages, use_stream=False):
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4v",
|
||||
messages=messages,
|
||||
stream=use_stream,
|
||||
max_tokens=256,
|
||||
temperature=0.4,
|
||||
presence_penalty=1.2,
|
||||
top_p=0.8,
|
||||
)
|
||||
if response:
|
||||
if use_stream:
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
else:
|
||||
print(response)
|
||||
else:
|
||||
print("Error:", response.status_code)
|
||||
|
||||
|
||||
def encode_image(image_path):
|
||||
"""
|
||||
Encodes an image file into a base64 string.
|
||||
Args:
|
||||
image_path (str): The path to the image file.
|
||||
|
||||
This function opens the specified image file, reads its content, and encodes it into a base64 string.
|
||||
The base64 encoding is used to send images over HTTP as text.
|
||||
"""
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
|
||||
def glm4v_simple_image_chat(use_stream=False, img_path=None):
|
||||
"""
|
||||
Facilitates a simple chat interaction involving an image.
|
||||
|
||||
Args:
|
||||
use_stream (bool): Specifies whether to use streaming for chat responses.
|
||||
img_path (str): Path to the image file to be included in the chat.
|
||||
|
||||
This function encodes the specified image and constructs a predefined conversation involving the image.
|
||||
It then calls `create_chat_completion` to generate a response from the model.
|
||||
The conversation includes asking about the content of the image and a follow-up question.
|
||||
"""
|
||||
|
||||
img_url = f"data:image/jpeg;base64,{encode_image(img_path)}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What’s in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": img_url
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The image displays a wooden boardwalk extending through a vibrant green grassy wetland. The sky is partly cloudy with soft, wispy clouds, indicating nice weather. Vegetation is seen on either side of the boardwalk, and trees are present in the background, suggesting that this area might be a natural reserve or park designed for ecological preservation and outdoor recreation. The boardwalk allows visitors to explore the area without disturbing the natural habitat.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Do you think this is a spring or winter photo?"
|
||||
},
|
||||
|
||||
|
||||
]
|
||||
create_chat_completion(messages=messages, use_stream=use_stream)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
simple_chat(use_stream=False)
|
||||
# function_chat(use_stream=False)
|
||||
# glm4v_simple_image_chat(use_stream=False, img_path="demo.jpg")
|
||||
|
||||
|
|
|
@ -1,682 +1,14 @@
|
|||
import time
|
||||
from asyncio.log import logger
|
||||
import re
|
||||
import uvicorn
|
||||
import gc
|
||||
import json
|
||||
import torch
|
||||
import random
|
||||
import string
|
||||
|
||||
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
|
||||
from fastapi import FastAPI, HTTPException, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoTokenizer, LogitsProcessor
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
# text-model THUDM/glm-4-9b-chat
|
||||
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
|
||||
MAX_MODEL_LENGTH = 8192
|
||||
|
||||
# vision-model THUDM/glm-4v-9b
|
||||
# MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
yield
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def generate_id(prefix: str, k=29) -> str:
|
||||
suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
|
||||
return f"{prefix}{suffix}"
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str = ""
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = None
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = ["glm-4"]
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ChoiceDeltaToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCall(BaseModel):
|
||||
index: Optional[int] = 0
|
||||
id: Optional[str] = None
|
||||
function: FunctionCall
|
||||
type: Optional[Literal["function"]] = 'function'
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
# “function” 字段解释:
|
||||
# 使用较老的OpenAI API版本需要注意在这里添加 function 字段并在 process_messages函数中添加相应角色转换逻辑为 observation
|
||||
|
||||
role: Literal["user", "assistant", "system", "tool"]
|
||||
content: Optional[str] = None
|
||||
function_call: Optional[ChoiceDeltaToolCallFunction] = None
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[Literal["user", "assistant", "system"]] = None
|
||||
content: Optional[str] = None
|
||||
function_call: Optional[ChoiceDeltaToolCallFunction] = None
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
|
||||
index: int
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
model: str
|
||||
id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
|
||||
object: Literal["chat.completion", "chat.completion.chunk"]
|
||||
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
|
||||
usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = 0.8
|
||||
top_p: Optional[float] = 0.8
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[Union[dict, List[dict]]] = None
|
||||
tool_choice: Optional[Union[str, dict]] = None
|
||||
repetition_penalty: Optional[float] = 1.1
|
||||
|
||||
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 5] = 5e4
|
||||
return scores
|
||||
|
||||
|
||||
def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]:
|
||||
lines = output.strip().split("\n")
|
||||
arguments_json = None
|
||||
special_tools = ["cogview", "simple_browser"]
|
||||
tools = {tool['function']['name'] for tool in tools} if tools else {}
|
||||
|
||||
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
|
||||
##TODO 如果你希望做更多判断,可以在这里进行逻辑完善。
|
||||
|
||||
if len(lines) >= 2 and lines[1].startswith("{"):
|
||||
function_name = lines[0].strip()
|
||||
arguments = "\n".join(lines[1:]).strip()
|
||||
if function_name in tools or function_name in special_tools:
|
||||
try:
|
||||
arguments_json = json.loads(arguments)
|
||||
is_tool_call = True
|
||||
except json.JSONDecodeError:
|
||||
is_tool_call = function_name in special_tools
|
||||
|
||||
if is_tool_call and use_tool:
|
||||
content = {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
|
||||
ensure_ascii=False)
|
||||
}
|
||||
if function_name == "simple_browser":
|
||||
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
|
||||
match = search_pattern.match(arguments)
|
||||
if match:
|
||||
content["arguments"] = json.dumps({
|
||||
"query": match.group(1),
|
||||
"recency_days": int(match.group(2))
|
||||
}, ensure_ascii=False)
|
||||
elif function_name == "cogview":
|
||||
content["arguments"] = json.dumps({
|
||||
"prompt": arguments
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return content
|
||||
return output.strip()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
async def generate_stream_glm4(params):
|
||||
messages = params["messages"]
|
||||
tools = params["tools"]
|
||||
tool_choice = params["tool_choice"]
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
max_new_tokens = int(params.get("max_tokens", 8192))
|
||||
|
||||
messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
params_dict = {
|
||||
"n": 1,
|
||||
"best_of": 1,
|
||||
"presence_penalty": 1.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": -1,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"use_beam_search": False,
|
||||
"length_penalty": 1,
|
||||
"early_stopping": False,
|
||||
"stop_token_ids": [151329, 151336, 151338],
|
||||
"ignore_eos": False,
|
||||
"max_tokens": max_new_tokens,
|
||||
"logprobs": None,
|
||||
"prompt_logprobs": None,
|
||||
"skip_special_tokens": True,
|
||||
}
|
||||
sampling_params = SamplingParams(**params_dict)
|
||||
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
|
||||
output_len = len(output.outputs[0].token_ids)
|
||||
input_len = len(output.prompt_token_ids)
|
||||
ret = {
|
||||
"text": output.outputs[0].text,
|
||||
"usage": {
|
||||
"prompt_tokens": input_len,
|
||||
"completion_tokens": output_len,
|
||||
"total_tokens": output_len + input_len
|
||||
},
|
||||
"finish_reason": output.outputs[0].finish_reason,
|
||||
}
|
||||
yield ret
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def process_messages(messages, tools=None, tool_choice="none"):
|
||||
_messages = messages
|
||||
processed_messages = []
|
||||
msg_has_sys = False
|
||||
|
||||
def filter_tools(tool_choice, tools):
|
||||
function_name = tool_choice.get('function', {}).get('name', None)
|
||||
if not function_name:
|
||||
return []
|
||||
filtered_tools = [
|
||||
tool for tool in tools
|
||||
if tool.get('function', {}).get('name') == function_name
|
||||
]
|
||||
return filtered_tools
|
||||
|
||||
if tool_choice != "none":
|
||||
if isinstance(tool_choice, dict):
|
||||
tools = filter_tools(tool_choice, tools)
|
||||
if tools:
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": None,
|
||||
"tools": tools
|
||||
}
|
||||
)
|
||||
msg_has_sys = True
|
||||
|
||||
if isinstance(tool_choice, dict) and tools:
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"metadata": tool_choice["function"]["name"],
|
||||
"content": ""
|
||||
}
|
||||
)
|
||||
|
||||
for m in _messages:
|
||||
role, content, func_call = m.role, m.content, m.function_call
|
||||
tool_calls = getattr(m, 'tool_calls', None)
|
||||
|
||||
if role == "function":
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "observation",
|
||||
"content": content
|
||||
}
|
||||
)
|
||||
elif role == "tool":
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "observation",
|
||||
"content": content,
|
||||
"function_call": True
|
||||
}
|
||||
)
|
||||
elif role == "assistant":
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"metadata": tool_call.function.name,
|
||||
"content": tool_call.function.arguments
|
||||
}
|
||||
)
|
||||
else:
|
||||
for response in content.split("\n"):
|
||||
if "\n" in response:
|
||||
metadata, sub_content = response.split("\n", maxsplit=1)
|
||||
else:
|
||||
metadata, sub_content = "", response
|
||||
processed_messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"metadata": metadata,
|
||||
"content": sub_content.strip()
|
||||
}
|
||||
)
|
||||
else:
|
||||
if role == "system" and msg_has_sys:
|
||||
msg_has_sys = False
|
||||
continue
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
if not tools or tool_choice == "none":
|
||||
for m in _messages:
|
||||
if m.role == 'system':
|
||||
processed_messages.insert(0, {"role": m.role, "content": m.content})
|
||||
break
|
||||
return processed_messages
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="glm-4")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
|
||||
gen_params = dict(
|
||||
messages=request.messages,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_tokens=request.max_tokens or 1024,
|
||||
echo=False,
|
||||
stream=request.stream,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
logger.debug(f"==== request ====\n{gen_params}")
|
||||
|
||||
if request.stream:
|
||||
predict_stream_generator = predict_stream(request.model, gen_params)
|
||||
output = await anext(predict_stream_generator)
|
||||
if output:
|
||||
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
|
||||
logger.debug(f"First result output:\n{output}")
|
||||
|
||||
function_call = None
|
||||
if output and request.tools:
|
||||
try:
|
||||
function_call = process_response(output, request.tools, use_tool=True)
|
||||
except:
|
||||
logger.warning("Failed to parse tool call")
|
||||
|
||||
if isinstance(function_call, dict):
|
||||
function_call = ChoiceDeltaToolCallFunction(**function_call)
|
||||
generate = parse_output_text(request.model, output, function_call=function_call)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
else:
|
||||
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
|
||||
response = ""
|
||||
async for response in generate_stream_glm4(gen_params):
|
||||
pass
|
||||
|
||||
if response["text"].startswith("\n"):
|
||||
response["text"] = response["text"][1:]
|
||||
response["text"] = response["text"].strip()
|
||||
|
||||
usage = UsageInfo()
|
||||
|
||||
function_call, finish_reason = None, "stop"
|
||||
tool_calls = None
|
||||
if request.tools:
|
||||
try:
|
||||
function_call = process_response(response["text"], request.tools, use_tool=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse tool call: {e}")
|
||||
if isinstance(function_call, dict):
|
||||
finish_reason = "tool_calls"
|
||||
function_call_response = ChoiceDeltaToolCallFunction(**function_call)
|
||||
function_call_instance = FunctionCall(
|
||||
name=function_call_response.name,
|
||||
arguments=function_call_response.arguments
|
||||
)
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(
|
||||
id=generate_id('call_', 24),
|
||||
function=function_call_instance,
|
||||
type="function")]
|
||||
|
||||
message = ChatMessage(
|
||||
role="assistant",
|
||||
content=None if tool_calls else response["text"],
|
||||
function_call=None,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
logger.debug(f"==== message ====\n{message}")
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
task_usage = UsageInfo.model_validate(response["usage"])
|
||||
for usage_key, usage_value in task_usage.model_dump().items():
|
||||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
model=request.model,
|
||||
choices=[choice_data],
|
||||
object="chat.completion",
|
||||
usage=usage
|
||||
)
|
||||
|
||||
|
||||
async def predict_stream(model_id, gen_params):
|
||||
output = ""
|
||||
is_function_call = False
|
||||
has_send_first_chunk = False
|
||||
created_time = int(time.time())
|
||||
function_name = None
|
||||
response_id = generate_id('chatcmpl-', 29)
|
||||
system_fingerprint = generate_id('fp_', 9)
|
||||
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
|
||||
delta_text = ""
|
||||
async for new_response in generate_stream_glm4(gen_params):
|
||||
decoded_unicode = new_response["text"]
|
||||
delta_text += decoded_unicode[len(output):]
|
||||
output = decoded_unicode
|
||||
lines = output.strip().split("\n")
|
||||
|
||||
# 检查是否为工具
|
||||
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
|
||||
##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。
|
||||
|
||||
if not is_function_call and len(lines) >= 2:
|
||||
first_line = lines[0].strip()
|
||||
if first_line in tools:
|
||||
is_function_call = True
|
||||
function_name = first_line
|
||||
delta_text = lines[1]
|
||||
|
||||
# 工具调用返回
|
||||
if is_function_call:
|
||||
if not has_send_first_chunk:
|
||||
function_call = {"name": function_name, "arguments": ""}
|
||||
tool_call = ChatCompletionMessageToolCall(
|
||||
index=0,
|
||||
id=generate_id('call_', 24),
|
||||
function=FunctionCall(**function_call),
|
||||
type="function"
|
||||
)
|
||||
message = DeltaMessage(
|
||||
content=None,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=[tool_call]
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield ""
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
has_send_first_chunk = True
|
||||
|
||||
function_call = {"name": None, "arguments": delta_text}
|
||||
delta_text = ""
|
||||
tool_call = ChatCompletionMessageToolCall(
|
||||
index=0,
|
||||
id=None,
|
||||
function=FunctionCall(**function_call),
|
||||
type="function"
|
||||
)
|
||||
message = DeltaMessage(
|
||||
content=None,
|
||||
role=None,
|
||||
function_call=None,
|
||||
tool_calls=[tool_call]
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
# 用户请求了 Function Call 但是框架还没确定是否为Function Call
|
||||
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
|
||||
continue
|
||||
|
||||
# 常规返回
|
||||
else:
|
||||
finish_reason = new_response.get("finish_reason", None)
|
||||
if not has_send_first_chunk:
|
||||
message = DeltaMessage(
|
||||
content="",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
has_send_first_chunk = True
|
||||
|
||||
message = DeltaMessage(
|
||||
content=delta_text,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
delta_text = ""
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
# 工具调用需要额外返回一个字段以对齐 OpenAI 接口
|
||||
if is_function_call:
|
||||
yield ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
system_fingerprint=system_fingerprint,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(
|
||||
content=None,
|
||||
role=None,
|
||||
function_call=None,
|
||||
),
|
||||
finish_reason="tool_calls"
|
||||
)],
|
||||
created=created_time,
|
||||
object="chat.completion.chunk",
|
||||
usage=None
|
||||
).model_dump_json(exclude_unset=True)
|
||||
elif delta_text != "":
|
||||
message = DeltaMessage(
|
||||
content="",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
finish_reason = 'stop'
|
||||
message = DeltaMessage(
|
||||
content=delta_text,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
)
|
||||
delta_text = ""
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=message,
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
id=response_id,
|
||||
choices=[choice_data],
|
||||
created=created_time,
|
||||
system_fingerprint=system_fingerprint,
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield chunk.model_dump_json(exclude_unset=True)
|
||||
yield '[DONE]'
|
||||
else:
|
||||
yield '[DONE]'
|
||||
|
||||
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
|
||||
delta = DeltaMessage(role="assistant", content=value)
|
||||
if function_call is not None:
|
||||
delta.function_call = function_call
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=delta,
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionResponse(
|
||||
model=model_id,
|
||||
choices=[choice_data],
|
||||
object="chat.completion.chunk"
|
||||
)
|
||||
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
|
||||
yield '[DONE]'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=MODEL_PATH,
|
||||
tokenizer=MODEL_PATH,
|
||||
# 如果你有多张显卡,可以在这里设置成你的显卡数量
|
||||
tensor_parallel_size=1,
|
||||
dtype="bfloat16",
|
||||
trust_remote_code=True,
|
||||
# 占用显存的比例,请根据你的显卡显存大小设置合适的值,例如,如果你的显卡有80G,您只想使用24G,请按照24/80=0.3设置
|
||||
gpu_memory_utilization=0.9,
|
||||
enforce_eager=True,
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
max_model_len=MAX_MODEL_LENGTH,
|
||||
)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
|
||||
if '4v' in MODEL_PATH.lower():
|
||||
subprocess.run(["python", "glm4v_server.py", MODEL_PATH])
|
||||
else:
|
||||
subprocess.run(["python", "glm_server.py", MODEL_PATH])
|
||||
|
|
Loading…
Reference in New Issue