update:trans_batch_demo
This commit is contained in:
parent
81ba7e087c
commit
4dc2f76e68
|
@ -46,7 +46,7 @@ def batch(
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=max_input_tokens).to(model.device)
|
max_length=max_input_tokens).to(model.device)
|
||||||
|
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
"num_beams": num_beams,
|
"num_beams": num_beams,
|
||||||
|
@ -75,14 +75,14 @@ if __name__ == "__main__":
|
||||||
]
|
]
|
||||||
|
|
||||||
batch_inputs = []
|
batch_inputs = []
|
||||||
max_input_tokens = 1024
|
max_input_tokens = 128
|
||||||
for i, messages in enumerate(batch_message):
|
for i, messages in enumerate(batch_message):
|
||||||
new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)[12:]
|
||||||
max_input_tokens = max(max_input_tokens, len(new_batch_input))
|
max_input_tokens = max(max_input_tokens, len(new_batch_input))
|
||||||
batch_inputs.append(new_batch_input)
|
batch_inputs.append(new_batch_input)
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"max_input_tokens": max_input_tokens,
|
"max_input_tokens": max_input_tokens,
|
||||||
"max_new_tokens": 8192,
|
"max_new_tokens": 256,
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"top_p": 0.8,
|
"top_p": 0.8,
|
||||||
"temperature": 0.8,
|
"temperature": 0.8,
|
||||||
|
|
Loading…
Reference in New Issue