Fix combine error: add the `combine` option to FinetuningConfig and pass it to process_batch / process_batch_eval
This commit is contained in:
parent
913cb6dc06
commit
46ecf7cca7
|
@ -135,6 +135,7 @@ class FinetuningConfig(object):
|
||||||
|
|
||||||
max_input_length: int
|
max_input_length: int
|
||||||
max_output_length: int
|
max_output_length: int
|
||||||
|
combine: bool
|
||||||
|
|
||||||
training_args: Seq2SeqTrainingArguments = dc.field(
|
training_args: Seq2SeqTrainingArguments = dc.field(
|
||||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
||||||
|
@ -247,6 +248,7 @@ def process_batch(
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
max_input_length: int,
|
max_input_length: int,
|
||||||
max_output_length: int,
|
max_output_length: int,
|
||||||
|
combine: bool,
|
||||||
) -> dict[str, list]:
|
) -> dict[str, list]:
|
||||||
batched_conv = batch['messages']
|
batched_conv = batch['messages']
|
||||||
batched_input_ids = []
|
batched_input_ids = []
|
||||||
|
@ -324,6 +326,7 @@ def process_batch_eval(
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
max_input_length: int,
|
max_input_length: int,
|
||||||
max_output_length: int,
|
max_output_length: int,
|
||||||
|
combine: bool,
|
||||||
) -> dict[str, list]:
|
) -> dict[str, list]:
|
||||||
batched_conv = batch['messages']
|
batched_conv = batch['messages']
|
||||||
batched_input_ids = []
|
batched_input_ids = []
|
||||||
|
@ -449,6 +452,7 @@ def main(
|
||||||
Split.TRAIN,
|
Split.TRAIN,
|
||||||
functools.partial(
|
functools.partial(
|
||||||
process_batch,
|
process_batch,
|
||||||
|
combine=ft_config.combine, # Not used for now
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_input_length=ft_config.max_input_length,
|
max_input_length=ft_config.max_input_length,
|
||||||
max_output_length=ft_config.max_output_length,
|
max_output_length=ft_config.max_output_length,
|
||||||
|
@ -461,9 +465,11 @@ def main(
|
||||||
Split.VALIDATION,
|
Split.VALIDATION,
|
||||||
functools.partial(
|
functools.partial(
|
||||||
process_batch_eval,
|
process_batch_eval,
|
||||||
|
combine=ft_config.combine,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_input_length=ft_config.max_input_length,
|
max_input_length=ft_config.max_input_length,
|
||||||
max_output_length=ft_config.max_output_length,
|
max_output_length=ft_config.max_output_length,
|
||||||
|
|
||||||
),
|
),
|
||||||
batched=True,
|
batched=True,
|
||||||
)
|
)
|
||||||
|
@ -474,6 +480,7 @@ def main(
|
||||||
Split.TEST,
|
Split.TEST,
|
||||||
functools.partial(
|
functools.partial(
|
||||||
process_batch_eval,
|
process_batch_eval,
|
||||||
|
combine=ft_config.combine,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_input_length=ft_config.max_input_length,
|
max_input_length=ft_config.max_input_length,
|
||||||
max_output_length=ft_config.max_output_length,
|
max_output_length=ft_config.max_output_length,
|
||||||
|
|
Loading…
Reference in New Issue