import argparse
from threading import Thread
from typing import List, Tuple

import torch
from optimum.intel.openvino import OVModelForCausalLM
from transformers import (AutoTokenizer, AutoConfig, TextIteratorStreamer,
                          StoppingCriteriaList, StoppingCriteria)
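

# Stopping criterion: generation ends as soon as the last generated token id
# matches one of the configured stop token ids.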
class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        for stop_id in self.token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('-h',
                        '--help',
                        action='help',
                        help='Show this help message and exit.')
    parser.add_argument('-m',
                        '--model_path',
                        required=True,
                        type=str,
                        help='Required. Path to the model directory.')
    parser.add_argument('-l',
                        '--max_sequence_length',
                        default=256,
                        required=False,
                        type=int,
                        help='Optional. Maximum length of the generated output.')
    parser.add_argument('-d',
                        '--device',
                        default='CPU',
                        required=False,
                        type=str,
                        help='Optional. Device used for inference.')
    args = parser.parse_args()
    model_dir = args.model_path
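
    # OpenVINO compile options: latency-oriented performance hint, a single
    # inference stream, and no model cache directory.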
    ov_config = {"PERFORMANCE_HINT": "LATENCY",
                 "NUM_STREAMS": "1", "CACHE_DIR": ""}

    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, trust_remote_code=True)

    print("====Compiling model====")
    ov_model = OVModelForCausalLM.from_pretrained(
        model_dir,
        device=args.device,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
        trust_remote_code=True,
    )
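
    # Stream the reply piece by piece instead of waiting for the full generation.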
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
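    # Token ids that should end generation (GLM-4 special tokens).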
    stop_tokens = [StopOnTokens([151329, 151336, 151338])]
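
    # Convert the (user, assistant) history into chat-template token ids,
    # adding the generation prompt for the next assistant turn.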
    def convert_history_to_token(history: List[Tuple[str, str]]):
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})

        model_inputs = tokenizer.apply_chat_template(messages,
                                                     add_generation_prompt=True,
                                                     tokenize=True,
                                                     return_tensors="pt")
        return model_inputs
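
    # Console chat loop: type 'stop' to exit or 'clear' to reset the history.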
    history = []
    print("====Starting conversation====")
    while True:
        input_text = input("用户: ")  # "User: "
        if input_text.lower() == 'stop':
            break

        if input_text.lower() == 'clear':
            history = []
            print("AI助手: 对话历史已清空")  # "AI assistant: conversation history cleared"
            continue

        print("GLM-4-9B-OpenVINO:", end=" ")
        history = history + [[input_text, ""]]
        model_inputs = convert_history_to_token(history)
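        # Per-turn generation settings: tokens stream through `streamer` and
        # generation halts on the stop tokens defined above.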
        generate_kwargs = dict(
            input_ids=model_inputs,
            max_new_tokens=args.max_sequence_length,
            temperature=0.1,
            do_sample=True,
            top_p=1.0,
            top_k=50,
            repetition_penalty=1.1,
            streamer=streamer,
            stopping_criteria=StoppingCriteriaList(stop_tokens),
        )
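
        # Run generation in a background thread so the streamed text can be
        # printed here as it is produced.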
        t1 = Thread(target=ov_model.generate, kwargs=generate_kwargs)
        t1.start()

        partial_text = ""
        for new_text in streamer:
            print(new_text, end="", flush=True)
            partial_text += new_text
        print("\n")
        history[-1][1] = partial_text