update finetune

parent 0b80c79b8c
commit 4daff94302
@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
+
+combine: True
 max_input_length: 512
 max_output_length: 512
+
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 128
-max_output_length: 128
+
+combine: True
+max_input_length: 512
+max_output_length: 512
+
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 256
+
+combine: True
+max_input_length: 512
 max_output_length: 512
+
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output

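The three configs above all end up with the same data section: a new top-level combine flag plus 512-token input and output budgets. Below is a minimal sketch of reading that flag back out, assuming PyYAML is installed; the path configs/lora.yaml is only illustrative, not taken from this diff:

import yaml  # PyYAML, assumed available

# Hypothetical path; substitute the actual config file from this repo.
with open('configs/lora.yaml') as f:
    cfg = yaml.safe_load(f)

data_config = cfg['data_config']             # val_file, test_file, num_proc, ...
combine = cfg.get('combine', False)          # flag introduced by this commit
max_input_length = cfg['max_input_length']   # 512 in every config after this change
max_output_length = cfg['max_output_length']
print(combine, max_input_length, max_output_length)
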
@@ -129,6 +129,7 @@ class FinetuningConfig(object):
 
     max_input_length: int
     max_output_length: int
+    combine: bool
 
     training_args: Seq2SeqTrainingArguments = dc.field(
         default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')

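A self-contained sketch of the dataclass shape after this change. The field names mirror the diff; Seq2SeqTrainingArguments is replaced by a plain output_dir string purely to keep the example dependency-free:

import dataclasses as dc

@dc.dataclass
class FinetuningConfigSketch:
    # Non-default fields, as in the diff: the new combine field sits before the
    # defaulted field, so the dataclass definition stays valid.
    max_input_length: int
    max_output_length: int
    combine: bool
    # Stand-in for the Seq2SeqTrainingArguments default_factory in the real class.
    output_dir: str = dc.field(default='./output')

cfg = FinetuningConfigSketch(max_input_length=512, max_output_length=512, combine=True)
print(cfg.combine)  # True
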
@@ -250,21 +251,29 @@ def process_batch(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_labels = []
 
     for conv in batched_conv:
         input_ids = [151331, 151333]
         loss_masks = [False, False]
-        for message in conv:
-            message = process_message(message)
-            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-            new_loss_masks = [loss_mask_val] * len(new_input_ids)
-            input_ids += new_input_ids
-            loss_masks += new_loss_masks
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            loss_masks = [False] * len(input_ids)
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            for j in range(last_assistant_index + 1, len(input_ids)):
+                loss_masks[j] = True
+        else:
+            for message in conv:
+                message = process_message(message)
+                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                input_ids += new_input_ids
+                loss_masks += [loss_mask_val] * len(new_input_ids)
 
         input_ids.append(151336)  # EOS for chat
         loss_masks = [False, *loss_masks]
         labels = []

@@ -277,7 +286,7 @@ def process_batch(
         batched_input_ids.append(input_ids[:max_length])
         batched_labels.append(labels[:max_length])
 
-    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
+    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
     torch.cuda.empty_cache()
 
     return {'input_ids': batched_input_ids, 'labels': batched_labels}

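A toy illustration of the combine branch added to process_batch above: the whole conversation is templated as one sequence and only tokens after the last assistant marker contribute to the loss. The token IDs 151331/151333 (prefix), 151337 (assistant marker) and 151336 (chat EOS) come from the diff; the short sequence and the label-building step are invented for illustration and follow the masking pattern visible around the hunk, not a verbatim excerpt:

# Invented token sequence standing in for tokenizer.apply_chat_template(conv, ...).
input_ids = [151331, 151333, 11, 12, 151337, 21, 22, 151337, 31, 32]

loss_masks = [False] * len(input_ids)
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1  # 7
for j in range(last_assistant_index + 1, len(input_ids)):
    loss_masks[j] = True  # only the final assistant reply is supervised

input_ids.append(151336)           # EOS for chat, as in process_batch
loss_masks = [False, *loss_masks]  # shift the mask to align with next-token labels

# -100 is ignored by the loss; real token ids survive only where the mask is True.
labels = [tok if mask else -100 for tok, mask in zip(input_ids[1:], loss_masks[1:])]
print(labels)  # [-100, -100, -100, -100, -100, -100, -100, -100, 32, 151336]
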
@@ -288,32 +297,47 @@ def process_batch_eval(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_output_ids = []
 
     for conv in batched_conv:
-        input_ids = [151331, 151333]
-        for message in conv:
-            if len(input_ids) >= max_input_length:
-                break
-            else:
-                message = process_message(message)
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-                if message['role'] == 'assistant':
-                    output_prompt, output_ids = (
-                        new_input_ids[:1],
-                        new_input_ids[1:],
-                    )
-                    output_ids.append(151336)
-                    batched_input_ids.append(
-                        input_ids[:max_input_length] + output_prompt[:1]
-                    )
-                    batched_output_ids.append(output_ids[:max_output_length])
-                input_ids += new_input_ids
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            output_prompt, output_ids = (
+                input_ids[:1],
+                input_ids[last_assistant_index:],
+            )
+            output_ids.append(151336)
+            batched_input_ids.append(
+                input_ids[:max_input_length] + output_prompt[:1]
+            )
+            batched_output_ids.append(output_ids[:max_output_length])
+        else:
+            input_ids = [151331, 151333]
+            for message in conv:
+                if len(input_ids) >= max_input_length:
+                    break
+                else:
+                    message = process_message(message)
+                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                    if message['role'] == 'assistant':
+                        output_prompt, output_ids = (
+                            new_input_ids[:1],
+                            new_input_ids[1:],
+                        )
+                        output_ids.append(151336)
+                        batched_input_ids.append(
+                            input_ids[:max_input_length] + output_prompt[:1]
+                        )
+                        batched_output_ids.append(output_ids[:max_output_length])
+                    input_ids += new_input_ids
 
-    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
+    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
     torch.cuda.empty_cache()
 
     return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}

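A toy illustration of the combine branch added to process_batch_eval above: the eval sample keeps the truncated prompt as input and everything from the last assistant marker onward as the expected output. The token IDs come from the diff; the sequence and length limits below are invented for illustration:

# Invented token sequence; 151337 marks assistant turns, 151336 is the chat EOS.
input_ids = [151331, 151333, 11, 12, 151337, 21, 22, 151337, 31, 32]
max_input_length, max_output_length = 8, 8

last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
output_prompt, output_ids = (
    input_ids[:1],                     # first token only, as written in the diff
    input_ids[last_assistant_index:],  # last assistant marker plus its reply
)
output_ids.append(151336)              # close the expected output with the chat EOS

eval_input = input_ids[:max_input_length] + output_prompt[:1]
eval_output = output_ids[:max_output_length]
print(eval_input)   # [151331, 151333, 11, 12, 151337, 21, 22, 151337, 151331]
print(eval_output)  # [151337, 31, 32, 151336]
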
@@ -386,6 +410,7 @@ def main(
         functools.partial(
             process_batch,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),

@@ -397,6 +422,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),

@@ -409,6 +435,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
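For context, a minimal sketch of how such a partial is typically fed to datasets.Dataset.map with batched=True. The toy dataset, the stub process_batch and the literal 512 limits are assumptions for illustration; in the script the callable, the tokenizer and the limits come from ft_config as shown in the hunks above:

import functools
from datasets import Dataset

def process_batch(batch, tokenizer, combine, max_input_length, max_output_length):
    # Stub standing in for the real process_batch; it only echoes the batch size.
    n = len(batch['messages'])
    return {'input_ids': [[151331, 151333]] * n, 'labels': [[-100, -100]] * n}

raw = Dataset.from_dict({'messages': [[{'role': 'user', 'content': 'hi'}]]})
tokenized = raw.map(
    functools.partial(
        process_batch,
        tokenizer=None,   # the real script passes the loaded tokenizer here
        combine=True,     # mirrors ft_config.combine from this commit
        max_input_length=512,
        max_output_length=512,
    ),
    batched=True,
    remove_columns=raw.column_names,  # keep only the tokenized columns
)
print(tokenized)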