update: trans_batch_demo
parent 81ba7e087c
commit 4dc2f76e68

basic_demo
@@ -46,7 +46,7 @@ def batch(
         padding="max_length",
         truncation=True,
         max_length=max_input_tokens).to(model.device)
-
+
     gen_kwargs = {
         "max_new_tokens": max_new_tokens,
         "num_beams": num_beams,
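For orientation, a minimal sketch of the batch() helper this hunk sits in, assuming a standard transformers tokenizer/model pair. Only the pieces visible in the diff (the padding/truncation/max_length call and the max_new_tokens/num_beams entries of gen_kwargs) come from the source; the signature, the sampling arguments, and the decode step are assumptions.

import torch

def batch(model, tokenizer, prompts, max_input_tokens=128, max_new_tokens=256,
          num_beams=1, do_sample=True, top_p=0.8, temperature=0.8):
    # Pad/truncate every prompt to one fixed length so the batch stacks into a
    # single tensor; decoder-only models generally want left padding here.
    batched_inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_input_tokens).to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
    }
    with torch.inference_mode():
        outputs = model.generate(**batched_inputs, **gen_kwargs)
    # Drop the prompt tokens so only the newly generated text is decoded.
    outputs = outputs[:, batched_inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)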
@@ -75,14 +75,14 @@ if __name__ == "__main__":
     ]

     batch_inputs = []
-    max_input_tokens = 1024
+    max_input_tokens = 128
     for i, messages in enumerate(batch_message):
-        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
         max_input_tokens = max(max_input_tokens, len(new_batch_input))
         batch_inputs.append(new_batch_input)
     gen_kwargs = {
         "max_input_tokens": max_input_tokens,
-        "max_new_tokens": 8192,
+        "max_new_tokens": 256,
         "do_sample": True,
         "top_p": 0.8,
         "temperature": 0.8,
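Three changes land in this hunk: the max_input_tokens floor drops from 1024 to 128, the rendered chat template loses its first 12 characters (presumably a fixed special-token prefix that the later tokenizer() call would add again), and max_new_tokens drops from 8192 to 256. Note that max_input_tokens grows with len(new_batch_input), a character count rather than a token count, so it is only a generous upper bound for the max_length used in batch(). Below is a hedged sketch of how the collected inputs and gen_kwargs are presumably consumed, reusing the batch() sketch above; the call shape is an assumption, and the "max_input_tokens" key is taken to match batch()'s parameter of the same name.

# Hypothetical continuation: feed the collected prompts and the tuned
# generation settings into the batch() helper sketched earlier.
gen_kwargs = {
    "max_input_tokens": max_input_tokens,
    "max_new_tokens": 256,
    "do_sample": True,
    "top_p": 0.8,
    "temperature": 0.8,
}
responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
for response in responses:
    print(response.strip())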