fix combine error

This commit is contained in:
zR 2024-07-25 10:14:51 +08:00
parent 913cb6dc06
commit 46ecf7cca7
1 changed files with 7 additions and 0 deletions

View File

@ -135,6 +135,7 @@ class FinetuningConfig(object):
max_input_length: int max_input_length: int
max_output_length: int max_output_length: int
combine: bool
training_args: Seq2SeqTrainingArguments = dc.field( training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output') default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
@ -247,6 +248,7 @@ def process_batch(
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
max_input_length: int, max_input_length: int,
max_output_length: int, max_output_length: int,
combine: bool,
) -> dict[str, list]: ) -> dict[str, list]:
batched_conv = batch['messages'] batched_conv = batch['messages']
batched_input_ids = [] batched_input_ids = []
@ -324,6 +326,7 @@ def process_batch_eval(
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
max_input_length: int, max_input_length: int,
max_output_length: int, max_output_length: int,
combine: bool,
) -> dict[str, list]: ) -> dict[str, list]:
batched_conv = batch['messages'] batched_conv = batch['messages']
batched_input_ids = [] batched_input_ids = []
@ -449,6 +452,7 @@ def main(
Split.TRAIN, Split.TRAIN,
functools.partial( functools.partial(
process_batch, process_batch,
combine=ft_config.combine, # Not use now
tokenizer=tokenizer, tokenizer=tokenizer,
max_input_length=ft_config.max_input_length, max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length, max_output_length=ft_config.max_output_length,
@ -461,9 +465,11 @@ def main(
Split.VALIDATION, Split.VALIDATION,
functools.partial( functools.partial(
process_batch_eval, process_batch_eval,
combine=ft_config.combine,
tokenizer=tokenizer, tokenizer=tokenizer,
max_input_length=ft_config.max_input_length, max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length, max_output_length=ft_config.max_output_length,
), ),
batched=True, batched=True,
) )
@ -474,6 +480,7 @@ def main(
Split.TEST, Split.TEST,
functools.partial( functools.partial(
process_batch_eval, process_batch_eval,
combine=ft_config.combine,
tokenizer=tokenizer, tokenizer=tokenizer,
max_input_length=ft_config.max_input_length, max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length, max_output_length=ft_config.max_output_length,