add vision demo

commit 29480b7394 (parent b1f014f151)
@@ -50,7 +50,7 @@ Read this in [English](README_en.md)

| Precision | GPU Memory | Prefilling / First Response | Decode Speed | Remarks |
|------|----------|-----------------|------------------|------------|
| Int4 | 10267MiB | 0.1685s | 28.7101 tokens/s | Input length is 1000 |
-| Int4 | 14105MiB | 0.8629s | 40.7134 tokens/s | Input length is 8000 |
+| Int4 | 14105MiB | 0.8629s | 24.2370 tokens/s | Input length is 8000 |

### Minimum hardware requirements
@@ -42,15 +42,15 @@ The stress test data of relevant inference are as follows:

#### GLM-4V-9B

-| 精度 | 显存占用 | Prefilling / 首响 | Decode Speed | Remarks |
-|------|----------|-----------------|------------------|------------|
-| BF16 | 28131MiB | 0.1016s | 33.4660 tokens/s | Input length is 1000 |
-| BF16 | 33043MiB | 0.7935s | 39.2444 tokens/s | Input length is 8000 |
+| Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
+|-------|------------|------------|------------------|----------------------|
+| BF16 | 28131MiB | 0.1016s | 33.4660 tokens/s | Input length is 1000 |
+| BF16 | 33043MiB | 0.7935s | 39.2444 tokens/s | Input length is 8000 |

-| 精度 | 显存占用 | Prefilling / 首响 | Decode Speed | Remarks |
-|------|----------|-----------------|------------------|------------|
-| Int4 | 10267MiB | 0.1685s | 28.7101 tokens/s | Input length is 1000 |
-| Int4 | 14105MiB | 0.8629s | 40.7134 tokens/s | Input length is 8000 |
+| Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
+|-------|------------|------------|------------------|----------------------|
+| Int4 | 10267MiB | 0.1685s | 28.7101 tokens/s | Input length is 1000 |
+| Int4 | 14105MiB | 0.8629s | 24.2370 tokens/s | Input length is 8000 |
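
As a point of reference, latency numbers like these can be approximated with the transformers streaming API: the delay before the first streamed chunk approximates the prefilling / first-response time, and chunks per second after that approximate decode speed. The following is a minimal sketch, not the harness that produced the tables above; `model`, `tokenizer`, and `prompt` are assumed to already exist:

```python
import time
from threading import Thread

from transformers import TextIteratorStreamer

def measure_generation(model, tokenizer, prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate,
                    kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens})
    start = time.perf_counter()
    thread.start()
    first, chunks = None, 0
    for _ in streamer:  # each item is a decoded text chunk, roughly one token
        if first is None:
            first = time.perf_counter() - start  # ~ prefilling / first response
        chunks += 1
    if first is None:  # nothing was generated
        return None, 0.0
    total = time.perf_counter() - start
    return first, chunks / max(total - first, 1e-9)  # (first-token s, ~tokens/s)
```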

### Minimum hardware requirements
@@ -0,0 +1,112 @@
"""
This script creates a CLI demo with a transformers backend for the glm-4v-9b model,
allowing users to interact with the model through a command-line interface.

Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.

Note: The script includes a modification to handle markdown-to-plain-text conversion,
ensuring that the CLI interface displays formatted text correctly.
"""

import os
from threading import Thread

import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Model id or local path; can be overridden via the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True
)
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()
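
# Not part of the original demo: an Int4 variant (as in the README tables) can
# presumably be loaded with bitsandbytes 4-bit quantization instead of bf16, e.g.:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModel.from_pretrained(
#       MODEL_PATH,
#       trust_remote_code=True,
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#   ).eval()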


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last token is one of the model's EOS ids."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


if __name__ == "__main__":
    history = []
    max_length = 1024
    top_p = 0.8
    temperature = 0.6
    stop = StopOnTokens()
    uploaded = False
    image = None
    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
    image_path = input("Image Path:")
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception:
        print("Invalid image path. Continuing with text conversation.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])

        # Rebuild the message list from history. The image is attached only once,
        # to the latest user turn the first time through (guarded by `uploaded`).
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                if image and not uploaded:
                    messages[-1].update({"image": image})
                    uploaded = True
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        # Tokenize with the model's chat template; for glm-4v-9b the template
        # also packs the attached image into the returned inputs.
        model_inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=60,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generate_kwargs = {
            **model_inputs,
            "streamer": streamer,
            "max_new_tokens": max_length,
            "do_sample": True,
            "top_p": top_p,
            "temperature": temperature,
            "stopping_criteria": StoppingCriteriaList([stop]),
            "repetition_penalty": 1.2,
            # GLM-4 stop tokens: <|endoftext|>, <|user|>, <|observation|>
            "eos_token_id": [151329, 151336, 151338],
        }
        # Run generate() in a background thread so tokens can be printed as they stream.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        print("GLM-4:", end="", flush=True)
        for new_token in streamer:
            if new_token:
                print(new_token, end="", flush=True)
                history[-1][1] += new_token

        history[-1][1] = history[-1][1].strip()
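
For a quick sanity check without the interactive loop, the same pieces reduce to a single-turn call. Below is a minimal sketch reusing the `tokenizer` and `model` loaded above; the image path and question are placeholders:

```python
import torch
from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # placeholder path
messages = [{"role": "user", "image": image, "content": "Describe this picture."}]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256,
                            eos_token_id=[151329, 151336, 151338])
    # Keep only the newly generated tokens, dropping the prompt.
    reply_ids = output[0][inputs["input_ids"].shape[1]:]
    print(tokenizer.decode(reply_ids, skip_special_tokens=True))
```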