Create fastapi_call_4v_9b.py

Add a FastAPI-based service wrapper for glm-4v-9b
dongfangduoshou123 2024-06-10 11:17:35 +08:00 committed by GitHub
parent adeeb0e8e0
commit ceca8aa3cc
1 changed file with 112 additions and 0 deletions


@@ -0,0 +1,112 @@
import base64
import io
import os
from threading import Thread
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

app = FastAPI()
MODEL_PATH = os.environ.get('MODEL_PATH', '/root/.cache/modelscope/hub/ZhipuAI/glm-4v-9b')

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True
)
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is one of the model's EOS tokens."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: list[Message]
    temperature: float = 0.6
    top_p: float = 0.8
    max_tokens: int = 1024
    # Base64-encoded image attached to the last message; None (or the legacy "-1") means text-only.
    image: Optional[str] = None


@app.post("/v1/chat/completions")
async def chat(chat_request: ChatRequest):
    messages = chat_request.messages
    temperature = chat_request.temperature
    top_p = chat_request.top_p
    max_length = chat_request.max_tokens
    image_data = chat_request.image

    inputs = []
    for message in messages:
        inputs.append({"role": message.role, "content": message.content})
    # Attach the image to the last message; skip when no image was sent
    # (the "-1" sentinel is kept for backward compatibility).
    if image_data and image_data != "-1":
        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image data")
        inputs[-1].update({"image": image})
    model_inputs = tokenizer.apply_chat_template(
        inputs,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(next(model.parameters()).device)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True
    )
    stop = StopOnTokens()
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
        # glm-4v-9b special end-of-sequence token ids.
        "eos_token_id": [151329, 151336, 151338],
    }
    # Run generation in a background thread and drain the streamer into a complete response.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    response = ""
    for new_token in streamer:
        if new_token:
            response += new_token
    return JSONResponse(content={"choices": [{"message": {"role": "assistant", "content": response.strip()}}]})


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
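
For reference, a minimal client sketch against this endpoint (assumes the server runs on 127.0.0.1:8000; example.jpg is a hypothetical local image):

import base64
import requests

# Encode a local image as base64 for the "image" field; hypothetical file name.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "model": "glm-4v-9b",
    "messages": [{"role": "user", "content": "Describe this image."}],
    "image": image_b64,
}
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])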