diff --git a/basic_demo/fastapi_call_4v_9b.py b/basic_demo/fastapi_call_4v_9b.py
new file mode 100644
index 0000000..d18d281
--- /dev/null
+++ b/basic_demo/fastapi_call_4v_9b.py
@@ -0,0 +1,127 @@
+import os
+import base64
+import io
+from threading import Thread
+from typing import Optional
+
+import torch
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
+from PIL import Image
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer,
+)
+
+app = FastAPI()
+
+MODEL_PATH = os.environ.get('MODEL_PATH', '/root/.cache/modelscope/hub/ZhipuAI/glm-4v-9b')
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    encode_special_tokens=True
+)
+model = AutoModel.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    device_map="auto",
+    torch_dtype=torch.bfloat16
+).eval()
+
+
+class StopOnTokens(StoppingCriteria):
+    """Stop generation as soon as the last generated token is one of the model's EOS tokens."""
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = model.config.eos_token_id
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: list[Message]
+    temperature: float = 0.6
+    top_p: float = 0.8
+    max_tokens: int = 1024
+    # Base64-encoded image attached to the last message; None or "-1" means no image.
+    image: Optional[str] = None
+
+
+@app.post("/v1/chat/completions")
+async def chat(chat_request: ChatRequest):
+    messages = chat_request.messages
+    temperature = chat_request.temperature
+    top_p = chat_request.top_p
+    max_length = chat_request.max_tokens
+    image_data = chat_request.image
+
+    inputs = []
+    for message in messages:
+        inputs.append({"role": message.role, "content": message.content})
+
+    # Decode the optional base64 image and attach it to the last message.
+    if image_data and image_data != "-1":
+        try:
+            image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
+        except Exception:
+            raise HTTPException(status_code=400, detail="Invalid image data")
+        inputs[-1].update({"image": image})
+
+    model_inputs = tokenizer.apply_chat_template(
+        inputs,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_tensors="pt",
+        return_dict=True
+    ).to(next(model.parameters()).device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer=tokenizer,
+        timeout=60,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    stop = StopOnTokens()
+    generate_kwargs = {
+        **model_inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_length,
+        "do_sample": True,
+        "top_p": top_p,
+        "temperature": temperature,
+        "stopping_criteria": StoppingCriteriaList([stop]),
+        "repetition_penalty": 1.2,
+        "eos_token_id": [151329, 151336, 151338],
+    }
+
+    # Run generation in a background thread and drain the streamer here,
+    # returning the full completion in a single (non-streaming) response.
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    response = ""
+    for new_token in streamer:
+        if new_token:
+            response += new_token
+    t.join()
+
+    return JSONResponse(content={"choices": [{"message": {"role": "assistant", "content": response.strip()}}]})
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
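
For reference, a minimal sketch of exercising the new endpoint, assuming the server from this diff is running on localhost:8000; the model name, prompt, and demo.jpg path are illustrative placeholders, not part of the patch:

    import base64
    import requests

    # Encode a local image as base64, matching the `image` field of ChatRequest.
    with open("demo.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()

    payload = {
        "model": "glm-4v-9b",
        "messages": [{"role": "user", "content": "Describe this image."}],
        "image": image_b64,
    }
    resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])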