update finetune

zR 2024-07-19 22:04:53 +08:00
parent 0b80c79b8c
commit 4daff94302
4 changed files with 67 additions and 31 deletions

View File

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
+combine: True
 max_input_length: 512
 max_output_length: 512
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

View File

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 128
-max_output_length: 128
+combine: True
+max_input_length: 512
+max_output_length: 512
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

View File

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 256
+combine: True
+max_input_length: 512
 max_output_length: 512
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

View File

@@ -129,6 +129,7 @@ class FinetuningConfig(object):
     max_input_length: int
     max_output_length: int
+    combine: bool
     training_args: Seq2SeqTrainingArguments = dc.field(
         default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
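The combine switch added to the configs above surfaces here as a plain bool field on FinetuningConfig, so it is parsed from YAML together with the other data-prep options. A minimal sketch of how such a top-level key can land on a dataclass field, assuming PyYAML and a hypothetical load_config helper rather than the script's actual config loader:

import dataclasses as dc
import yaml  # assumes PyYAML is available

@dc.dataclass
class DataPrepConfig:
    # hypothetical stand-in for the relevant FinetuningConfig fields
    max_input_length: int
    max_output_length: int
    combine: bool = False

def load_config(path: str) -> DataPrepConfig:
    with open(path) as f:
        raw = yaml.safe_load(f)
    known = {f.name for f in dc.fields(DataPrepConfig)}
    return DataPrepConfig(**{k: v for k, v in raw.items() if k in known})

# cfg = load_config('path/to/config.yaml'); cfg.combine then selects the batching path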
@@ -250,21 +251,29 @@ def process_batch(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_labels = []
     for conv in batched_conv:
         input_ids = [151331, 151333]
         loss_masks = [False, False]
-        for message in conv:
-            message = process_message(message)
-            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-            new_loss_masks = [loss_mask_val] * len(new_input_ids)
-            input_ids += new_input_ids
-            loss_masks += new_loss_masks
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            loss_masks = [False] * len(input_ids)
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            for j in range(last_assistant_index + 1, len(input_ids)):
+                loss_masks[j] = True
+        else:
+            for message in conv:
+                message = process_message(message)
+                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                input_ids += new_input_ids
+                loss_masks += [loss_mask_val] * len(new_input_ids)
         input_ids.append(151336) # EOS for chat
         loss_masks = [False, *loss_masks]
         labels = []
@@ -277,7 +286,7 @@ def process_batch(
         batched_input_ids.append(input_ids[:max_length])
         batched_labels.append(labels[:max_length])
-    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
+    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
     torch.cuda.empty_cache()
     return {'input_ids': batched_input_ids, 'labels': batched_labels}
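When combine is true, process_batch tokenizes the whole conversation with one apply_chat_template call and unmasks only the tokens after the last assistant marker (token id 151337); the per-message masking loop is kept as the fallback. A self-contained sketch of that masking step on toy token ids, assuming 151337 opens an assistant turn and -100 is the ignored-label value:

# toy token ids; 151337 stands in for the assistant-turn marker used in the diff
input_ids = [151331, 151333, 11, 12, 151337, 21, 22, 151337, 31, 32]

# position of the last assistant marker, found by searching the reversed list
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1

# only tokens after that marker contribute to the loss
loss_masks = [False] * len(input_ids)
for j in range(last_assistant_index + 1, len(input_ids)):
    loss_masks[j] = True

labels = [tok if keep else -100 for tok, keep in zip(input_ids, loss_masks)]
print(last_assistant_index)  # 7
print(labels)  # [-100, -100, -100, -100, -100, -100, -100, -100, 31, 32]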
@@ -288,32 +297,47 @@ def process_batch_eval(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_output_ids = []
     for conv in batched_conv:
-        input_ids = [151331, 151333]
-        for message in conv:
-            if len(input_ids) >= max_input_length:
-                break
-            else:
-                message = process_message(message)
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-                if message['role'] == 'assistant':
-                    output_prompt, output_ids = (
-                        new_input_ids[:1],
-                        new_input_ids[1:],
-                    )
-                    output_ids.append(151336)
-                    batched_input_ids.append(
-                        input_ids[:max_input_length] + output_prompt[:1]
-                    )
-                    batched_output_ids.append(output_ids[:max_output_length])
-                input_ids += new_input_ids
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            output_prompt, output_ids = (
+                input_ids[:1],
+                input_ids[last_assistant_index:],
+            )
+            output_ids.append(151336)
+            batched_input_ids.append(
+                input_ids[:max_input_length] + output_prompt[:1]
+            )
+            batched_output_ids.append(output_ids[:max_output_length])
+        else:
+            input_ids = [151331, 151333]
+            for message in conv:
+                if len(input_ids) >= max_input_length:
+                    break
+                else:
+                    message = process_message(message)
+                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                    if message['role'] == 'assistant':
+                        output_prompt, output_ids = (
+                            new_input_ids[:1],
+                            new_input_ids[1:],
+                        )
+                        output_ids.append(151336)
+                        batched_input_ids.append(
+                            input_ids[:max_input_length] + output_prompt[:1]
+                        )
+                        batched_output_ids.append(output_ids[:max_output_length])
+                    input_ids += new_input_ids
-    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
+    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
     torch.cuda.empty_cache()
     return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
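process_batch_eval mirrors the same idea for evaluation: with combine enabled it emits a single (input, output) pair per conversation, cutting the prompt at max_input_length and taking everything from the last assistant marker as the reference, instead of one pair per assistant message. A toy sketch of that split, under the same token-id assumptions as above (151336 as the end-of-turn token):

input_ids = [151331, 151333, 11, 12, 151337, 21, 22, 151337, 31, 32]
max_input_length, max_output_length = 8, 4

last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1

# mirror the combine=True branch: prompt prefix plus reference span
output_prompt = input_ids[:1]
output_ids = input_ids[last_assistant_index:] + [151336]  # close the turn

eval_input_ids = input_ids[:max_input_length] + output_prompt[:1]
eval_output_ids = output_ids[:max_output_length]

print(eval_input_ids)   # [151331, 151333, 11, 12, 151337, 21, 22, 151337, 151331]
print(eval_output_ids)  # [151337, 31, 32, 151336]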
@@ -386,6 +410,7 @@ def main(
         functools.partial(
             process_batch,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
@@ -397,6 +422,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
@@ -409,6 +435,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
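All three dataset mappings in main() receive the flag the same way: functools.partial pre-binds tokenizer, combine, and the length limits onto the batch-processing function, so the dataset map call only supplies the batch itself. A minimal sketch of that wiring with a hypothetical stand-in for process_batch:

import functools

def process_batch(batch, *, tokenizer, combine, max_input_length, max_output_length):
    # stand-in that only reports which path would run
    mode = 'combined' if combine else 'per-message'
    return {'mode': [mode] * len(batch['messages'])}

map_fn = functools.partial(
    process_batch,
    tokenizer=None,  # a PreTrainedTokenizer in the real script
    combine=True,    # ft_config.combine in the real script
    max_input_length=512,
    max_output_length=512,
)

print(map_fn({'messages': [[], []]}))  # {'mode': ['combined', 'combined']}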