update: trans_batch_demo

zhipuch 2024-12-31 16:07:08 +08:00
parent 81ba7e087c
commit 4dc2f76e68
1 changed file with 4 additions and 4 deletions


@@ -46,7 +46,7 @@ def batch(
         padding="max_length",
         truncation=True,
         max_length=max_input_tokens).to(model.device)
     gen_kwargs = {
         "max_new_tokens": max_new_tokens,
         "num_beams": num_beams,
@@ -75,14 +75,14 @@ if __name__ == "__main__":
     ]
     batch_inputs = []
-    max_input_tokens = 1024
+    max_input_tokens = 128
     for i, messages in enumerate(batch_message):
-        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
         max_input_tokens = max(max_input_tokens, len(new_batch_input))
         batch_inputs.append(new_batch_input)
     gen_kwargs = {
         "max_input_tokens": max_input_tokens,
-        "max_new_tokens": 8192,
+        "max_new_tokens": 256,
         "do_sample": True,
         "top_p": 0.8,
         "temperature": 0.8,