glm4/intel_device_demo/openvino/openvino_cli_demo.py

122 lines
4.2 KiB
Python
Raw Normal View History

2024-06-28 16:09:51 +08:00
import argparse
from typing import List, Tuple
from threading import Thread
import torch
from optimum.intel.openvino import OVModelForCausalLM
from transformers import (AutoTokenizer, AutoConfig,
TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation on any of the given token ids.

    Only the most recent token of the first sequence in the batch is
    inspected, which matches how this demo calls ``generate`` (one prompt
    at a time).
    """

    def __init__(self, token_ids):
        """Store the ids that should end generation.

        Args:
            token_ids: Iterable of integer token ids treated as stop tokens.
        """
        self.token_ids = token_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True when the last generated token is a stop token.

        Args:
            input_ids: Token ids generated so far; indexed as
                ``input_ids[0][-1]`` to read the newest token of the first row.
            scores: Unused; accepted to satisfy the StoppingCriteria call
                signature.
            **kwargs: Unused extra arguments forwarded by the caller.

        Returns:
            True if the newest token of the first sequence is one of
            ``self.token_ids``, otherwise False.
        """
        # Nothing generated yet: indexing [-1] would raise, so just continue.
        if input_ids.shape[-1] == 0:
            return False
        # Convert once to a plain int so the membership test below does not
        # build a temporary tensor for every candidate stop id.
        last_token = int(input_ids[0][-1])
        return last_token in self.token_ids
if __name__ == "__main__":
    # ---- Command-line interface ------------------------------------------
    # add_help=False because -h/--help is registered explicitly below.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('-h',
                        '--help',
                        action='help',
                        help='Show this help message and exit.')
    parser.add_argument('-m',
                        '--model_path',
                        required=True,
                        type=str,
                        help='Required. model path')
    parser.add_argument('-l',
                        '--max_sequence_length',
                        default=256,
                        required=False,
                        type=int,
                        help='Optional. maximum length of output')
    parser.add_argument('-d',
                        '--device',
                        default='CPU',
                        required=False,
                        type=str,
                        help='Optional. device for inference')
    args = parser.parse_args()
    model_dir = args.model_path

    # ---- Model / tokenizer setup -----------------------------------------
    # NOTE(review): an empty CACHE_DIR presumably disables OpenVINO's
    # compiled-model cache -- confirm against the OpenVINO documentation.
    ov_config = {"PERFORMANCE_HINT": "LATENCY",
                 "NUM_STREAMS": "1", "CACHE_DIR": ""}

    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, trust_remote_code=True)

    print("====Compiling model====")
    ov_model = OVModelForCausalLM.from_pretrained(
        model_dir,
        device=args.device,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
        trust_remote_code=True,
    )

    # The streamer is filled by generate() on a worker thread and drained by
    # the main loop below; the prompt and special tokens are not echoed.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    # NOTE(review): these ids are presumably GLM-4's end-of-text / role
    # marker tokens -- verify against the tokenizer shipped with the model.
    stop_tokens = [StopOnTokens([151329, 151336, 151338])]

    def convert_history_to_token(history: List[List[str]]):
        """Turn the chat history into model input ids via the chat template.

        Args:
            history: List of ``[user_message, assistant_message]`` pairs. The
                last pair may have an empty assistant message, meaning the
                model is expected to answer that user message now.

        Returns:
            The value returned by ``tokenizer.apply_chat_template`` with
            ``tokenize=True`` and ``return_tensors="pt"``.
        """
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            # Pending turn: only the user side exists, and it is always
            # added (even if empty) so the prompt ends on a user message.
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            # Completed turns: skip whichever side is empty.
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})

        model_inputs = tokenizer.apply_chat_template(messages,
                                                     add_generation_prompt=True,
                                                     tokenize=True,
                                                     return_tensors="pt")
        return model_inputs

    # ---- Interactive chat loop -------------------------------------------
    history = []
    print("====Starting conversation====")
    while True:
        input_text = input("用户: ")
        if input_text.lower() == 'stop':
            break
        if input_text.lower() == 'clear':
            history = []
            print("AI助手: 对话历史已清空")
            continue

        print("GLM-4-9B-OpenVINO:", end=" ")
        # Entries are mutable [user, assistant] lists: the assistant slot is
        # filled in once the streamed reply is complete.
        history.append([input_text, ""])
        model_inputs = convert_history_to_token(history)
        generate_kwargs = dict(
            input_ids=model_inputs,
            max_new_tokens=args.max_sequence_length,
            temperature=0.1,
            do_sample=True,
            top_p=1.0,
            top_k=50,
            repetition_penalty=1.1,
            streamer=streamer,
            stopping_criteria=StoppingCriteriaList(stop_tokens)
        )

        # generate() blocks until completion, so run it on a worker thread
        # and print text as the streamer yields it.
        t1 = Thread(target=ov_model.generate, kwargs=generate_kwargs)
        t1.start()

        chunks = []
        for new_text in streamer:
            print(new_text, end="", flush=True)
            chunks.append(new_text)
        # Make sure the worker has fully finished before the next turn
        # starts another generate() call on the same model and streamer.
        t1.join()
        print("\n")
        history[-1][1] = "".join(chunks)