glm4/basic_demo/trans_cli_vision_demo.py

"""
This script creates a CLI demo with transformers backend for the glm-4v-9b model,
allowing users to interact with the model through a command-line interface.
Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.
Note: The script includes a modification to handle markdown to plain text conversion,
ensuring that the CLI interface displays formatted text correctly.
"""
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    AutoModel,
    BitsAndBytesConfig,
)
from PIL import Image

MODEL_PATH = "THUDM/glm-4v-9b"
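
# Load the GLM-4V tokenizer; trust_remote_code is required because the repo
# ships custom tokenizer/model code, and encode_special_tokens keeps GLM's chat
# control tokens (e.g. <|user|>) encoded as special tokens.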
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True
)

## For BF16 inference
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    # attn_implementation="flash_attention_2", # Use Flash Attention
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

## For INT4 inference
# model = AutoModel.from_pretrained(
#     MODEL_PATH,
#     trust_remote_code=True,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     torch_dtype=torch.bfloat16,
#     low_cpu_mem_usage=True
# ).eval()
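
# Enable exactly one of the two loading paths above; the INT4 path loads the
# weights in 4-bit via bitsandbytes, cutting weight memory to roughly a quarter
# of BF16 at some quality cost.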
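

# Custom stopping criterion: halt generation once the last generated token
# matches any of the model's EOS token ids.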
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


if __name__ == "__main__":
    history = []
    max_length = 1024
    top_p = 0.8
    temperature = 0.6
    stop = StopOnTokens()
    uploaded = False
    image = None
    print("Welcome to the GLM-4V-9B CLI chat. Type your messages below.")
    image_path = input("Image Path: ")
    try:
        image = Image.open(image_path).convert("RGB")
    except OSError:
        print("Invalid image path. Continuing with text conversation.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])
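
        # Rebuild the messages list from history. The image is attached only to
        # the first request (tracked via `uploaded`); later turns replay the
        # conversation as text-only history.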
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                if image and not uploaded:
                    messages[-1].update({"image": image})
                    uploaded = True
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
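
        # apply_chat_template renders the chat turns (and the attached image)
        # into model-ready input tensors and moves them to the model's device.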
        model_inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(next(model.parameters()).device)
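
        # Stream tokens as they are generated: generate() runs on a background
        # thread and pushes decoded text into the streamer, which the main
        # thread iterates over below.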
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=60,
            skip_prompt=True,
            skip_special_tokens=True
        )
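
        # Sampling settings; the hardcoded eos_token_id values are GLM-4's stop
        # tokens (end-of-text plus role-switch tokens).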
        generate_kwargs = {
            **model_inputs,
            "streamer": streamer,
            "max_new_tokens": max_length,
            "do_sample": True,
            "top_p": top_p,
            "temperature": temperature,
            "stopping_criteria": StoppingCriteriaList([stop]),
            "repetition_penalty": 1.2,
            "eos_token_id": [151329, 151336, 151338],
        }
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        print("GLM-4V:", end="", flush=True)
        for new_token in streamer:
            if new_token:
                print(new_token, end="", flush=True)
                history[-1][1] += new_token
        history[-1][1] = history[-1][1].strip()