""" This script creates a CLI demo with transformers backend for the glm-4-9b model, allowing users to interact with the model through a command-line interface. Usage: - Run the script to start the CLI demo. - Interact with the model by typing questions and receiving responses. Note: The script includes a modification to handle markdown to plain text conversion, ensuring that the CLI interface displays formatted text correctly. """ import os import torch from threading import Thread from typing import Union from pathlib import Path from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer ) ModelType = Union[PreTrainedModel, PeftModelForCausalLM] TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b') def load_model_and_tokenizer( model_dir: Union[str, Path], trust_remote_code: bool = True ) -> tuple[ModelType, TokenizerType]: model_dir = Path(model_dir).expanduser().resolve() if (model_dir / 'adapter_config.json').exists(): model = AutoPeftModelForCausalLM.from_pretrained( model_dir, trust_remote_code=trust_remote_code, device_map='auto') tokenizer_dir = model.peft_config['default'].base_model_name_or_path else: model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=trust_remote_code, device_map='auto') tokenizer_dir = model_dir tokenizer = AutoTokenizer.from_pretrained( tokenizer_dir, trust_remote_code=trust_remote_code, encode_special_tokens=True, use_fast=False ) return model, tokenizer model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True) class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: stop_ids = model.config.eos_token_id for stop_id in stop_ids: if input_ids[0][-1] == stop_id: return True return False if __name__ == "__main__": history = [] max_length = 8192 top_p = 0.8 temperature = 0.6 stop = StopOnTokens() print("Welcome to the GLM-4-9B CLI chat. Type your messages below.") while True: user_input = input("\nYou: ") if user_input.lower() in ["exit", "quit"]: break history.append([user_input, ""]) messages = [] for idx, (user_msg, model_msg) in enumerate(history): if idx == len(history) - 1 and not model_msg: messages.append({"role": "user", "content": user_msg}) break if user_msg: messages.append({"role": "user", "content": user_msg}) if model_msg: messages.append({"role": "assistant", "content": model_msg}) model_inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_tensors="pt" ).to(model.device) streamer = TextIteratorStreamer( tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True ) generate_kwargs = { "input_ids": model_inputs, "streamer": streamer, "max_new_tokens": max_length, "do_sample": True, "top_p": top_p, "temperature": temperature, "stopping_criteria": StoppingCriteriaList([stop]), "repetition_penalty": 1.2, "eos_token_id": model.config.eos_token_id, } t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() print("GLM-4:", end="", flush=True) for new_token in streamer: if new_token: print(new_token, end="", flush=True) history[-1][1] += new_token history[-1][1] = history[-1][1].strip()