diff --git a/basic_demo/trans_batch_demo.py b/basic_demo/trans_batch_demo.py
index 7deebf9..5ae6817 100644
--- a/basic_demo/trans_batch_demo.py
+++ b/basic_demo/trans_batch_demo.py
@@ -46,7 +46,7 @@ def batch(
         padding="max_length",
         truncation=True,
         max_length=max_input_tokens).to(model.device)
-
+
     gen_kwargs = {
         "max_new_tokens": max_new_tokens,
         "num_beams": num_beams,
@@ -75,14 +75,14 @@ if __name__ == "__main__":
     ]
     batch_inputs = []
-    max_input_tokens = 1024
+    max_input_tokens = 128
     for i, messages in enumerate(batch_message):
-        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
         max_input_tokens = max(max_input_tokens, len(new_batch_input))
         batch_inputs.append(new_batch_input)
 
     gen_kwargs = {
         "max_input_tokens": max_input_tokens,
-        "max_new_tokens": 8192,
+        "max_new_tokens": 256,
         "do_sample": True,
         "top_p": 0.8,
         "temperature": 0.8,
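
For reference, a minimal sketch of what the rewritten loop does, assuming a GLM-4-style tokenizer whose rendered chat template begins with the 12-character `[gMASK]<sop>` prefix (which would explain the `[12:]` slice); the model path and messages below are illustrative placeholders, not taken from the diff.

```python
# Illustrative sketch of the edited loop (not part of the diff). Assumes a
# GLM-4-style tokenizer whose chat template prepends "[gMASK]<sop>" (12 chars).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

batch_message = [
    [{"role": "user", "content": "How are you?"}],
    [{"role": "user", "content": "Tell me a joke."}],
]

batch_inputs = []
max_input_tokens = 128  # the new, lower floor introduced by the diff
for messages in batch_message:
    rendered = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    # [12:] drops the leading "[gMASK]<sop>" so the prefix is not duplicated
    # when the tokenizer re-adds special tokens during batch encoding.
    new_batch_input = rendered[12:]
    # Note: len() counts characters, not tokens, so this only gives a rough
    # upper bound that is later used as the padding length.
    max_input_tokens = max(max_input_tokens, len(new_batch_input))
    batch_inputs.append(new_batch_input)
```

The companion change from `max_new_tokens=8192` to `256` bounds how long each batched generation can run, trading answer length for throughput and KV-cache memory.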