update finetune

zR 2024-07-19 22:04:53 +08:00
parent 0b80c79b8c
commit 4daff94302
4 changed files with 67 additions and 31 deletions

View File

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
+combine: True
 max_input_length: 512
 max_output_length: 512
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

View File

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 128
-max_output_length: 128
+combine: True
+max_input_length: 512
+max_output_length: 512
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

View File

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 256
+combine: True
+max_input_length: 512
 max_output_length: 512
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output
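All three configs now share the same shape, with the new combine flag sitting alongside the 512-token limits at the top level. A minimal sketch of reading it (PyYAML assumed, toy values inlined rather than a real config file):

import yaml  # PyYAML assumed

cfg = yaml.safe_load("""
data_config:
  val_file: dev.jsonl
  test_file: dev.jsonl
  num_proc: 1
combine: True
max_input_length: 512
max_output_length: 512
""")
print(cfg['combine'])            # True -> template the whole conversation at once
print(cfg['max_input_length'])   # 512
print(cfg['max_output_length'])  # 512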

View File

@@ -129,6 +129,7 @@ class FinetuningConfig(object):
     max_input_length: int
     max_output_length: int
+    combine: bool
     training_args: Seq2SeqTrainingArguments = dc.field(
         default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
@@ -250,21 +251,29 @@ def process_batch(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_labels = []
     for conv in batched_conv:
         input_ids = [151331, 151333]
         loss_masks = [False, False]
-        for message in conv:
-            message = process_message(message)
-            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-            new_loss_masks = [loss_mask_val] * len(new_input_ids)
-            input_ids += new_input_ids
-            loss_masks += new_loss_masks
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            loss_masks = [False] * len(input_ids)
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            for j in range(last_assistant_index + 1, len(input_ids)):
+                loss_masks[j] = True
+        else:
+            for message in conv:
+                message = process_message(message)
+                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                input_ids += new_input_ids
+                loss_masks += [loss_mask_val] * len(new_input_ids)
         input_ids.append(151336)  # EOS for chat
         loss_masks = [False, *loss_masks]
         labels = []
@@ -277,7 +286,7 @@ def process_batch(
         batched_input_ids.append(input_ids[:max_length])
         batched_labels.append(labels[:max_length])
-    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
+    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
     torch.cuda.empty_cache()
     return {'input_ids': batched_input_ids, 'labels': batched_labels}
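In combine mode the whole conversation is templated in a single apply_chat_template call, and only the tokens after the last assistant marker are unmasked for the loss. A toy sketch of that index arithmetic (fake token ids, 151337 standing in for the assistant role marker, -100 as the usual ignore index):

ids = [151331, 151333, 11, 12, 151337, 21, 22, 151337, 31, 32]  # toy token ids

# position of the final assistant marker, found by scanning from the right
last_assistant_index = len(ids) - ids[::-1].index(151337) - 1   # -> 7

loss_masks = [False] * len(ids)
for j in range(last_assistant_index + 1, len(ids)):
    loss_masks[j] = True                                        # train only on the final reply

labels = [tok if keep else -100 for tok, keep in zip(ids, loss_masks)]
print(labels)  # [-100, -100, -100, -100, -100, -100, -100, -100, 31, 32]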
@@ -288,32 +297,47 @@ def process_batch_eval(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_output_ids = []
     for conv in batched_conv:
-        input_ids = [151331, 151333]
-        for message in conv:
-            if len(input_ids) >= max_input_length:
-                break
-            else:
-                message = process_message(message)
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-                if message['role'] == 'assistant':
-                    output_prompt, output_ids = (
-                        new_input_ids[:1],
-                        new_input_ids[1:],
-                    )
-                    output_ids.append(151336)
-                    batched_input_ids.append(
-                        input_ids[:max_input_length] + output_prompt[:1]
-                    )
-                    batched_output_ids.append(output_ids[:max_output_length])
-                input_ids += new_input_ids
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            output_prompt, output_ids = (
+                input_ids[:1],
+                input_ids[last_assistant_index:],
+            )
+            output_ids.append(151336)
+            batched_input_ids.append(
+                input_ids[:max_input_length] + output_prompt[:1]
+            )
+            batched_output_ids.append(output_ids[:max_output_length])
+        else:
+            input_ids = [151331, 151333]
+            for message in conv:
+                if len(input_ids) >= max_input_length:
+                    break
+                else:
+                    message = process_message(message)
+                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                    if message['role'] == 'assistant':
+                        output_prompt, output_ids = (
+                            new_input_ids[:1],
+                            new_input_ids[1:],
+                        )
+                        output_ids.append(151336)
+                        batched_input_ids.append(
+                            input_ids[:max_input_length] + output_prompt[:1]
+                        )
+                        batched_output_ids.append(output_ids[:max_output_length])
+                    input_ids += new_input_ids
-    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
+    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
    torch.cuda.empty_cache()
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
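The eval path makes the analogous split: in combine mode the reference output is the slice from the last assistant marker onward with the chat EOS appended, and the model input is the templated conversation (truncated to max_input_length) plus the first token as a generation prompt. A toy sketch mirroring that logic with fake ids:

ids = [151331, 151333, 11, 12, 151337, 21, 22]   # toy token ids, one assistant turn
max_input_length, max_output_length = 512, 512

last_assistant_index = len(ids) - ids[::-1].index(151337) - 1
output_prompt, output_ids = ids[:1], ids[last_assistant_index:]
output_ids.append(151336)                        # chat EOS on the reference

model_input = ids[:max_input_length] + output_prompt[:1]
reference = output_ids[:max_output_length]
print(reference)  # [151337, 21, 22, 151336]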
@@ -386,6 +410,7 @@ def main(
         functools.partial(
             process_batch,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
@@ -397,6 +422,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
@@ -409,6 +435,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
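In main(), the new flag is threaded into all three dataset map calls by binding it as a keyword argument. A minimal standalone sketch of that functools.partial pattern (stub function and dummy arguments, not the real script):

import functools

def process_batch(batch, tokenizer, max_input_length, max_output_length, combine):
    # stub standing in for the real finetune.py function
    mode = 'combined' if combine else 'per-message'
    return {'mode': [mode] * len(batch['messages'])}

bound = functools.partial(
    process_batch,
    tokenizer=None,          # the script passes its loaded PreTrainedTokenizer here
    combine=True,            # ft_config.combine in the real call
    max_input_length=512,
    max_output_length=512,
)
print(bound({'messages': [['turn 1'], ['turn 2']]}))  # {'mode': ['combined', 'combined']}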