"""
|
2024-10-29 01:41:10 +08:00
|
|
|
This script creates a CLI demo with transformers backend for the glm-4-9b-chat model,
|
2024-06-05 10:22:16 +08:00
|
|
|
allowing users to interact with the model through a command-line interface.
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
- Run the script to start the CLI demo.
|
|
|
|
- Interact with the model by typing questions and receiving responses.
|
|
|
|
|
|
|
|
Note: The script includes a modification to handle markdown to plain text conversion,
|
|
|
|
ensuring that the CLI interface displays formatted text correctly.
|
2024-06-24 23:45:04 +08:00
|
|
|
|
|
|
|
If you use flash attention, you should install the flash-attn and add attn_implementation="flash_attention_2" in model loading.
|
2024-10-29 01:40:11 +08:00
|
|
|
|
2024-06-05 10:22:16 +08:00
|
|
|
"""

import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

MODEL_PATH = "THUDM/glm-4-9b-chat"

# trust_remote_code=True is needed when loading `glm-4-9b-chat`.
# It is not needed when loading `glm-4-9b-chat-hf`.
# This applies to both the tokenizer and the model.

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
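
# A minimal sketch (an assumption, not part of the original script) of the equivalent
# load for the Hugging Face-native `glm-4-9b-chat-hf` checkpoint, which should not
# require trust_remote_code (repo id assumed to be "THUDM/glm-4-9b-chat-hf"):
#
#   tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat-hf")
#   model = AutoModelForCausalLM.from_pretrained(
#       "THUDM/glm-4-9b-chat-hf",
#       torch_dtype=torch.bfloat16,
#       device_map="auto",
#   ).eval()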

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,  # attn_implementation="flash_attention_2",  # Use Flash Attention
    torch_dtype=torch.bfloat16,  # flash-attn requires bfloat16 or float16
    trust_remote_code=True,
    device_map="auto",
).eval()
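
# A minimal sketch of enabling flash attention, following the docstring note above
# (assumes the flash-attn package is installed, e.g. `pip install flash-attn`):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_PATH,
#       attn_implementation="flash_attention_2",
#       torch_dtype=torch.bfloat16,  # flash-attn supports only bfloat16/float16
#       trust_remote_code=True,
#       device_map="auto",
#   ).eval()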


class StopOnTokens(StoppingCriteria):
    # Stop generation as soon as the last generated token is one of the model's EOS token ids.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


if __name__ == "__main__":
    history = []
    max_length = 8192
    top_p = 0.8
    temperature = 0.6
    stop = StopOnTokens()

    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])

        # Rebuild the chat messages from the running history; the last (still pending)
        # turn contributes only the user message.
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        model_inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=60,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": max_length,
            "do_sample": True,
            "top_p": top_p,
            "temperature": temperature,
            "stopping_criteria": StoppingCriteriaList([stop]),
            "repetition_penalty": 1.2,
            "eos_token_id": model.config.eos_token_id,
        }
        # Run generation in a background thread and stream tokens to the console as they arrive.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        print("GLM-4:", end="", flush=True)
        for new_token in streamer:
            if new_token:
                print(new_token, end="", flush=True)
                history[-1][1] += new_token

        history[-1][1] = history[-1][1].strip()