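Fix combine error: add the `combine` option to FinetuningConfig and pass it through to process_batch / process_batch_eval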

This commit is contained in:
zR 2024-07-25 10:14:51 +08:00
parent 913cb6dc06
commit 46ecf7cca7
1 changed file with 7 additions and 0 deletions

View File

@@ -135,6 +135,7 @@ class FinetuningConfig(object):
max_input_length: int
max_output_length: int
combine: bool
training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
@@ -247,6 +248,7 @@ def process_batch(
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
combine: bool,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
@@ -324,6 +326,7 @@ def process_batch_eval(
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
combine: bool,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
@@ -449,6 +452,7 @@ def main(
Split.TRAIN,
functools.partial(
process_batch,
combine=ft_config.combine, # Not used for now
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
@@ -461,9 +465,11 @@ def main(
Split.VALIDATION,
functools.partial(
process_batch_eval,
combine=ft_config.combine,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
@@ -474,6 +480,7 @@ def main(
Split.TEST,
functools.partial(
process_batch_eval,
combine=ft_config.combine,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,