Add GLM-4V-9B vision model gradio webui
This commit is contained in:
parent
0b67f9338e
commit
35ba249d28
basic_demo
|
@ -0,0 +1,121 @@
|
||||||
|
"""
|
||||||
|
This script creates a Gradio demo with a Transformers backend for the glm-4v-9b model, allowing users to interact with the model through a Gradio web UI.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- Run the script to start the Gradio server.
|
||||||
|
- Interact with the model via the web UI.
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- Gradio package
|
||||||
|
- Type `pip install gradio` to install Gradio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import gradio as gr
|
||||||
|
from threading import Thread
|
||||||
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
StoppingCriteria,
|
||||||
|
StoppingCriteriaList,
|
||||||
|
TextIteratorStreamer, AutoModel, BitsAndBytesConfig
|
||||||
|
)
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
MODEL_PATH,
|
||||||
|
trust_remote_code=True,
|
||||||
|
encode_special_tokens=True
|
||||||
|
)
|
||||||
|
model = AutoModel.from_pretrained(
|
||||||
|
MODEL_PATH,
|
||||||
|
trust_remote_code=True,
|
||||||
|
device_map="auto",
|
||||||
|
torch_dtype=torch.bfloat16
|
||||||
|
).eval()
|
||||||
|
|
||||||
|
class StopOnTokens(StoppingCriteria):
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
|
stop_ids = model.config.eos_token_id
|
||||||
|
for stop_id in stop_ids:
|
||||||
|
if input_ids[0][-1] == stop_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_image(image_path=None, image_url=None):
|
||||||
|
if image_path:
|
||||||
|
return Image.open(image_path).convert("RGB")
|
||||||
|
elif image_url:
|
||||||
|
response = requests.get(image_url)
|
||||||
|
return Image.open(BytesIO(response.content)).convert("RGB")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def chatbot(image_path=None, image_url=None, assistant_prompt=""):
|
||||||
|
image = get_image(image_path, image_url)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "assistant", "content": assistant_prompt},
|
||||||
|
{"role": "user", "content": "", "image": image}
|
||||||
|
]
|
||||||
|
|
||||||
|
model_inputs = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True
|
||||||
|
).to(next(model.parameters()).device)
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
timeout=60,
|
||||||
|
skip_prompt=True,
|
||||||
|
skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_kwargs = {
|
||||||
|
**model_inputs,
|
||||||
|
"streamer": streamer,
|
||||||
|
"max_new_tokens": 1024,
|
||||||
|
"do_sample": True,
|
||||||
|
"top_p": 0.8,
|
||||||
|
"temperature": 0.6,
|
||||||
|
"stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
|
||||||
|
"repetition_penalty": 1.2,
|
||||||
|
"eos_token_id": [151329, 151336, 151338],
|
||||||
|
}
|
||||||
|
|
||||||
|
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
response = ""
|
||||||
|
for new_token in streamer:
|
||||||
|
if new_token:
|
||||||
|
response += new_token
|
||||||
|
|
||||||
|
return image, response.strip()
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
demo.title = "GLM-4V-9B Image Recognition Demo"
|
||||||
|
demo.description = """
|
||||||
|
This demo uses the GLM-4V-9B model to got image infomation.
|
||||||
|
"""
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath")
|
||||||
|
image_url_input = gr.Textbox(label="Image URL (Low-Priority)")
|
||||||
|
assistant_prompt_input = gr.Textbox(label="Assistant Prompt (You Can Change It)", value="这是什么?")
|
||||||
|
submit_button = gr.Button("Submit")
|
||||||
|
with gr.Column():
|
||||||
|
chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response")
|
||||||
|
image_output = gr.Image(label="Image Preview")
|
||||||
|
|
||||||
|
submit_button.click(chatbot,
|
||||||
|
inputs=[image_path_input, image_url_input, assistant_prompt_input],
|
||||||
|
outputs=[image_output, chatbot_output])
|
||||||
|
|
||||||
|
demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False)
|
Loading…
Reference in New Issue