glm4/basic_demo/fastapi_call_4v_9b.py

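"""FastAPI demo that serves GLM-4V-9B behind an OpenAI-style
/v1/chat/completions endpoint. An optional base64-encoded image in the
request body is decoded and attached to the last chat message."""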

import base64
import io
import os
from threading import Thread

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

app = FastAPI()

MODEL_PATH = os.environ.get('MODEL_PATH', '/root/.cache/modelscope/hub/ZhipuAI/glm-4v-9b')
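
# Load the tokenizer and model once at startup. trust_remote_code pulls in
# GLM-4V's custom modeling code; device_map="auto" places the weights on the
# available GPU(s).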
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True
)
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()
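

# Stop generation as soon as the last sampled token is one of the model's
# EOS tokens (model.config.eos_token_id is a list of ids for GLM-4V).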
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
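

# OpenAI-style request schema. `image` is an optional base64-encoded image;
# clients send "-1" (or omit the field) for text-only requests.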
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: list[Message]
    temperature: float = 0.6
    top_p: float = 0.8
    max_tokens: int = 1024
    image: str | None = None
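

# Minimal OpenAI-compatible chat endpoint: build the conversation, attach the
# decoded image to the last message, then generate and return one completion.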
@app.post("/v1/chat/completions")
async def chat(chat_request: ChatRequest):
    messages = chat_request.messages
    temperature = chat_request.temperature
    top_p = chat_request.top_p
    max_length = chat_request.max_tokens
    image_data = chat_request.image

    inputs = []
    for message in messages:
        inputs.append({"role": message.role, "content": message.content})

    # A missing field or the "-1" sentinel means a text-only request.
    if image_data is not None and image_data != "-1":
        try:
            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image data")
        inputs[-1].update({"image": image})
    model_inputs = tokenizer.apply_chat_template(
        inputs,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(next(model.parameters()).device)
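
    # Stream tokens out of the generation thread; skip_prompt drops the echoed
    # input and skip_special_tokens strips GLM-4V's control tokens.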
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True
    )
    stop = StopOnTokens()
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
        # GLM-4's <|endoftext|>, <|user|>, and <|observation|> token ids.
        "eos_token_id": [151329, 151336, 151338],
    }

    # Run generation in a background thread and collect the streamed tokens
    # into a single non-streaming response.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    response = ""
    for new_token in streamer:
        if new_token:
            response += new_token
    t.join()

    return JSONResponse(content={"choices": [{"message": {"role": "assistant", "content": response.strip()}}]})


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
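
# Example client call (a sketch: assumes the server is running locally and
# that "demo.jpg" is a hypothetical image file to send):
#
#   import base64, requests
#
#   with open("demo.jpg", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode()
#
#   resp = requests.post(
#       "http://127.0.0.1:8000/v1/chat/completions",
#       json={
#           "model": "glm-4v-9b",
#           "messages": [{"role": "user", "content": "Describe this image."}],
#           "image": image_b64,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])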