"""
|
|
|
|
Here is an example of using batch request glm-4-9b,
|
|
here you need to build the conversation format yourself and then call the batch function to make batch requests.
|
|
Please note that in this demo, the memory consumption is significantly higher.
|
|
|
|
Note:
|
|
Using with glm-4-9b-chat-hf will require `transformers>=4.46.0".
|
|
|
|
"""
|
|
|
|
from typing import Union
from transformers import AutoTokenizer, LogitsProcessorList, AutoModelForCausalLM

MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto").eval()


def process_model_outputs(inputs, outputs, tokenizer):
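    """Strip the padded prompt tokens from each output sequence and decode only the newly generated text."""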
    responses = []
    for input_ids, output_ids in zip(inputs.input_ids, outputs):
        response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
        responses.append(response)
    return responses


def batch(
        model,
        tokenizer,
        messages: Union[str, list[str]],
        max_input_tokens: int = 8192,
        max_new_tokens: int = 8192,
        num_beams: int = 1,
        do_sample: bool = True,
        top_p: float = 0.8,
        temperature: float = 0.8,
        logits_processor=None,
):
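    """Tokenize a batch of prompt strings, pad/truncate them to max_input_tokens,
    and generate a response for each prompt in a single call."""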
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    messages = [messages] if isinstance(messages, str) else messages
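    # Every prompt is padded to max_input_tokens, which is the main reason this
    # demo's memory consumption is noticeably higher than single-prompt inference.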
    batched_inputs = tokenizer(
        messages,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_input_tokens).to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        "eos_token_id": model.config.eos_token_id
    }
    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
    batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
    return batched_response


if __name__ == "__main__":
    batch_message = [
        [
            {"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
            {"role": "assistant", "content": "因为他们结婚时你还没有出生"},
            {"role": "user", "content": "我刚才的提问是"}
        ],
        [
            {"role": "user", "content": "你好,你是谁"}
        ]
    ]

    batch_inputs = []
    max_input_tokens = 128
    for messages in batch_message:
        # apply_chat_template prepends "[gMASK]<sop>"; the [12:] slice drops that
        # 12-character prefix so the special tokens are not duplicated when the
        # prompt is tokenized again inside batch().
        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
        # Use the prompt length in characters as a rough upper bound for the
        # padding/truncation length passed to the tokenizer.
        max_input_tokens = max(max_input_tokens, len(new_batch_input))
        batch_inputs.append(new_batch_input)
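    # Generation settings forwarded to batch(); max_input_tokens becomes the
    # tokenizer's padding/truncation length.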
    gen_kwargs = {
        "max_input_tokens": max_input_tokens,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "num_beams": 1,
    }

    batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
    for response in batch_responses:
        print("=" * 10)
        print(response)