fix combine error
This commit is contained in:
parent
913cb6dc06
commit
46ecf7cca7
|
@ -135,6 +135,7 @@ class FinetuningConfig(object):
|
|||
|
||||
max_input_length: int
|
||||
max_output_length: int
|
||||
combine: bool
|
||||
|
||||
training_args: Seq2SeqTrainingArguments = dc.field(
|
||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
||||
|
@ -247,6 +248,7 @@ def process_batch(
|
|||
tokenizer: PreTrainedTokenizer,
|
||||
max_input_length: int,
|
||||
max_output_length: int,
|
||||
combine: bool,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_input_ids = []
|
||||
|
@ -324,6 +326,7 @@ def process_batch_eval(
|
|||
tokenizer: PreTrainedTokenizer,
|
||||
max_input_length: int,
|
||||
max_output_length: int,
|
||||
combine: bool,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_input_ids = []
|
||||
|
@ -449,6 +452,7 @@ def main(
|
|||
Split.TRAIN,
|
||||
functools.partial(
|
||||
process_batch,
|
||||
combine=ft_config.combine, # Not use now
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
|
@ -461,9 +465,11 @@ def main(
|
|||
Split.VALIDATION,
|
||||
functools.partial(
|
||||
process_batch_eval,
|
||||
combine=ft_config.combine,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
|
||||
),
|
||||
batched=True,
|
||||
)
|
||||
|
@ -474,6 +480,7 @@ def main(
|
|||
Split.TEST,
|
||||
functools.partial(
|
||||
process_batch_eval,
|
||||
combine=ft_config.combine,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
|
|
Loading…
Reference in New Issue