# glm4/basic_demo/trans_batch_demo.py
#
# 96 lines
# 3.1 KiB
# Python

"""
This is an example of sending batched requests to glm-4-9b.
You need to build the conversation format yourself and then call the `batch` function to make batch requests.
Please note that memory consumption is significantly higher in this demo.
Note:
Using glm-4-9b-chat-hf requires `transformers>=4.46.0`.
"""
from typing import Union
from transformers import AutoTokenizer, LogitsProcessorList, AutoModelForCausalLM
# Model identifier shared by the tokenizer and the model loaded below.
MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'

# Both objects are created at import time and used by the demo in this file.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
# Put the loaded model in eval mode before it is used for generation.
model = model.eval()
def process_model_outputs(inputs, outputs, tokenizer):
    """Decode only the newly generated part of each output sequence.

    Args:
        inputs: Object exposing ``input_ids``, an iterable with one prompt
            token sequence per batch item.
        outputs: Iterable of generated token sequences, in the same order as
            ``inputs.input_ids``.
        tokenizer: Object providing
            ``decode(token_ids, skip_special_tokens=...)`` returning a string.

    Returns:
        A list with one decoded, whitespace-stripped response string per
        (input, output) pair.
    """
    # The first len(input_ids) positions of each output are skipped so that
    # only what follows the prompt is decoded.
    # NOTE(review): this assumes every output sequence starts with its own
    # prompt tokens (including any padding) -- confirm for the model in use.
    return [
        tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
        for input_ids, output_ids in zip(inputs.input_ids, outputs)
    ]
def batch(
    model,
    tokenizer,
    messages: Union[str, list[str]],
    max_input_tokens: int = 8192,
    max_new_tokens: int = 8192,
    num_beams: int = 1,
    do_sample: bool = True,
    top_p: float = 0.8,
    temperature: float = 0.8,
    logits_processor=None,
):
    """Run one batched ``generate`` call and return one response per prompt.

    Args:
        model: Model object; its ``device``, ``config.eos_token_id`` and
            ``generate`` are used.
        tokenizer: Tokenizer used both to encode the prompts and to decode
            the generated tokens.
        messages: A single prompt string or a list of prompt strings. The
            prompts are passed to the tokenizer as-is, so the caller must
            already have rendered them into the conversation format.
        max_input_tokens: ``max_length`` passed to the tokenizer together
            with ``padding="max_length"`` and ``truncation=True``.
        max_new_tokens: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        logits_processor: Optional logits processors forwarded to
            ``model.generate``; an empty ``LogitsProcessorList`` is used
            when omitted.

    Returns:
        A list of decoded response strings, in the same order as
        ``messages``. An empty ``messages`` list yields an empty list.
    """
    messages = [messages] if isinstance(messages, str) else messages
    if not messages:
        # Nothing to tokenize or generate for an empty batch.
        return []

    if logits_processor is None:
        logits_processor = LogitsProcessorList()

    # padding="max_length" pads every prompt up to max_input_tokens rather
    # than to the longest prompt in the batch; the module docstring warns
    # about the resulting memory consumption.
    batched_inputs = tokenizer(
        messages,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_input_tokens,
    ).to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        "eos_token_id": model.config.eos_token_id,
    }
    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
    return process_model_outputs(batched_inputs, batched_outputs, tokenizer)
if __name__ == "__main__":
    # Two independent conversations that are answered in a single batch.
    batch_message = [
        [
            {"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
            {"role": "assistant", "content": "因为他们结婚时你还没有出生"},
            {"role": "user", "content": "我刚才的提问是"}
        ],
        [
            {"role": "user", "content": "你好,你是谁"}
        ]
    ]

    batch_inputs = []
    max_input_tokens = 128
    for messages in batch_message:
        # Render the conversation to a plain string (tokenize=False); the
        # tokenization itself happens later inside batch().
        # NOTE(review): [12:] drops the first 12 characters of the rendered
        # template -- presumably a special-token prefix that the tokenizer
        # call in batch() would otherwise add a second time; confirm against
        # the model's chat template.
        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
        # NOTE(review): len() here counts characters of the rendered string,
        # not tokens, and is used as the padding/truncation length (never
        # below 128).
        max_input_tokens = max(max_input_tokens, len(new_batch_input))
        batch_inputs.append(new_batch_input)

    gen_kwargs = {
        "max_input_tokens": max_input_tokens,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "num_beams": 1,
    }

    batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
    for response in batch_responses:
        print("=" * 10)
        print(response)