Create fastapi_call_4v_9b.py
添加对glm-4v-9b的基于fastapi的服务化封装
This commit is contained in:
parent
adeeb0e8e0
commit
ceca8aa3cc
|
@ -0,0 +1,112 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from threading import Thread
|
||||||
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
StoppingCriteria,
|
||||||
|
StoppingCriteriaList,
|
||||||
|
TextIteratorStreamer, AutoModel
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from PIL import Image
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
MODEL_PATH = os.environ.get('MODEL_PATH', '/root/.cache/modelscope/hub/ZhipuAI/glm-4v-9b')
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
MODEL_PATH,
|
||||||
|
trust_remote_code=True,
|
||||||
|
encode_special_tokens=True
|
||||||
|
)
|
||||||
|
model = AutoModel.from_pretrained(
|
||||||
|
MODEL_PATH,
|
||||||
|
trust_remote_code=True,
|
||||||
|
device_map="auto",
|
||||||
|
torch_dtype=torch.bfloat16
|
||||||
|
).eval()
|
||||||
|
|
||||||
|
class StopOnTokens(StoppingCriteria):
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
|
stop_ids = model.config.eos_token_id
|
||||||
|
for stop_id in stop_ids:
|
||||||
|
if input_ids[0][-1] == stop_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: list[Message]
|
||||||
|
temperature: float = 0.6
|
||||||
|
top_p: float = 0.8
|
||||||
|
max_tokens: int = 1024
|
||||||
|
image: str = None
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat(chat_request: ChatRequest):
|
||||||
|
messages = chat_request.messages
|
||||||
|
temperature = chat_request.temperature
|
||||||
|
top_p = chat_request.top_p
|
||||||
|
max_length = chat_request.max_tokens
|
||||||
|
image_data = chat_request.image
|
||||||
|
|
||||||
|
inputs = []
|
||||||
|
for message in messages:
|
||||||
|
inputs.append({"role": message.role, "content": message.content})
|
||||||
|
|
||||||
|
if image_data != "-1":
|
||||||
|
try:
|
||||||
|
image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
|
||||||
|
except:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||||
|
inputs[-1].update({"image": image})
|
||||||
|
|
||||||
|
model_inputs = tokenizer.apply_chat_template(
|
||||||
|
inputs,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True
|
||||||
|
).to(next(model.parameters()).device)
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
timeout=60,
|
||||||
|
skip_prompt=True,
|
||||||
|
skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
stop = StopOnTokens()
|
||||||
|
generate_kwargs = {
|
||||||
|
**model_inputs,
|
||||||
|
"streamer": streamer,
|
||||||
|
"max_new_tokens": max_length,
|
||||||
|
"do_sample": True,
|
||||||
|
"top_p": top_p,
|
||||||
|
"temperature": temperature,
|
||||||
|
"stopping_criteria": StoppingCriteriaList([stop]),
|
||||||
|
"repetition_penalty": 1.2,
|
||||||
|
"eos_token_id": [151329, 151336, 151338],
|
||||||
|
}
|
||||||
|
|
||||||
|
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
response = ""
|
||||||
|
for new_token in streamer:
|
||||||
|
if new_token:
|
||||||
|
response += new_token
|
||||||
|
|
||||||
|
return JSONResponse(content={"choices": [{"message": {"role": "assistant", "content": response.strip()}}]})
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
Loading…
Reference in New Issue