update finetune

parent 0b80c79b8c
commit 4daff94302
@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
 
+combine: True
 max_input_length: 512
 max_output_length: 512
 
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output
@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 128
-max_output_length: 128
+combine: True
+max_input_length: 512
+max_output_length: 512
 
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output
@@ -3,8 +3,11 @@ data_config:
   val_file: dev.jsonl
   test_file: dev.jsonl
   num_proc: 1
-max_input_length: 256
+combine: True
+max_input_length: 512
 max_output_length: 512
 
 training_args:
   # see `transformers.Seq2SeqTrainingArguments`
   output_dir: ./output
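Taken together, the three config hunks above add a top-level `combine` switch and standardize all three files on 512-token input and output budgets. A minimal sketch of how the resulting YAML reads once loaded (PyYAML assumed; the keys are copied from the hunks above, not from the full files):

import yaml  # assumes PyYAML is installed

# Hypothetical snippet mirroring the post-commit shape of the configs above.
config_text = """
data_config:
  val_file: dev.jsonl
  test_file: dev.jsonl
  num_proc: 1

combine: True
max_input_length: 512
max_output_length: 512
"""

cfg = yaml.safe_load(config_text)
assert cfg['combine'] is True          # YAML's `True` parses to a Python bool
assert cfg['max_input_length'] == 512  # all three configs now agree on 512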
@@ -129,6 +129,7 @@ class FinetuningConfig(object):
 
     max_input_length: int
     max_output_length: int
+    combine: bool
 
     training_args: Seq2SeqTrainingArguments = dc.field(
         default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
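One detail worth noting in the hunk above: `combine: bool` carries no default, so it must be declared before the defaulted `training_args` field. A standalone sketch of that dataclass ordering rule (class and field names here are illustrative, not the repo's):

import dataclasses as dc

@dc.dataclass
class DemoConfig:
    max_input_length: int
    max_output_length: int
    combine: bool                 # no default, so it must precede defaulted fields
    output_dir: str = './output'  # stands in for the defaulted training_args

cfg = DemoConfig(max_input_length=512, max_output_length=512, combine=True)
print(cfg.combine)  # True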
@@ -250,21 +251,29 @@ def process_batch(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_labels = []
 
     for conv in batched_conv:
         input_ids = [151331, 151333]
         loss_masks = [False, False]
-        for message in conv:
-            message = process_message(message)
-            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-            new_loss_masks = [loss_mask_val] * len(new_input_ids)
-            input_ids += new_input_ids
-            loss_masks += new_loss_masks
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            loss_masks = [False] * len(input_ids)
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            for j in range(last_assistant_index + 1, len(input_ids)):
+                loss_masks[j] = True
+        else:
+            for message in conv:
+                message = process_message(message)
+                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                input_ids += new_input_ids
+                loss_masks += [loss_mask_val] * len(new_input_ids)
 
         input_ids.append(151336)  # EOS for chat
         loss_masks = [False, *loss_masks]
         labels = []
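The heart of the new `combine` path is the loss-mask computation: tokenize the whole conversation in one pass, locate the last assistant marker (token 151337), and supervise only what follows it. A toy sketch of that logic with made-up token IDs (7 stands in for 151337):

# Toy reproduction of the masking logic in the hunk above.
input_ids = [1, 3, 5, 7, 9, 2, 7, 4, 8]
ASSISTANT = 7  # stand-in for the real assistant role token 151337

# Reverse the list, find the first hit, map the index back: last occurrence.
last_assistant_index = len(input_ids) - input_ids[::-1].index(ASSISTANT) - 1
assert last_assistant_index == 6

loss_masks = [False] * len(input_ids)
for j in range(last_assistant_index + 1, len(input_ids)):
    loss_masks[j] = True
# Only tokens after the final assistant marker are supervised:
assert loss_masks == [False] * 7 + [True, True]

Note the behavioral difference this implies: the per-message `else` path supervises every assistant turn, while the `combine` path supervises only the final one.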
@@ -277,7 +286,7 @@ def process_batch(
         batched_input_ids.append(input_ids[:max_length])
         batched_labels.append(labels[:max_length])
 
-    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
+    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
     torch.cuda.empty_cache()
 
     return {'input_ids': batched_input_ids, 'labels': batched_labels}
@@ -288,32 +297,47 @@ def process_batch_eval(
         tokenizer: PreTrainedTokenizer,
         max_input_length: int,
         max_output_length: int,
+        combine: bool,
 ) -> dict[str, list]:
     batched_conv = batch['messages']
     batched_input_ids = []
     batched_output_ids = []
 
     for conv in batched_conv:
-        input_ids = [151331, 151333]
-        for message in conv:
-            if len(input_ids) >= max_input_length:
-                break
-            else:
-                message = process_message(message)
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
-                if message['role'] == 'assistant':
-                    output_prompt, output_ids = (
-                        new_input_ids[:1],
-                        new_input_ids[1:],
-                    )
-                    output_ids.append(151336)
-                    batched_input_ids.append(
-                        input_ids[:max_input_length] + output_prompt[:1]
-                    )
-                    batched_output_ids.append(output_ids[:max_output_length])
-                input_ids += new_input_ids
+        if combine:
+            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
+            input_ids = new_input_ids
+            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
+            output_prompt, output_ids = (
+                input_ids[:1],
+                input_ids[last_assistant_index:],
+            )
+            output_ids.append(151336)
+            batched_input_ids.append(
+                input_ids[:max_input_length] + output_prompt[:1]
+            )
+            batched_output_ids.append(output_ids[:max_output_length])
+        else:
+            input_ids = [151331, 151333]
+            for message in conv:
+                if len(input_ids) >= max_input_length:
+                    break
+                else:
+                    message = process_message(message)
+                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                    if message['role'] == 'assistant':
+                        output_prompt, output_ids = (
+                            new_input_ids[:1],
+                            new_input_ids[1:],
+                        )
+                        output_ids.append(151336)
+                        batched_input_ids.append(
+                            input_ids[:max_input_length] + output_prompt[:1]
+                        )
+                        batched_output_ids.append(output_ids[:max_output_length])
+                    input_ids += new_input_ids
 
-    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
+    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
     torch.cuda.empty_cache()
 
     return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
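On the eval side, the same marker search splits the combined conversation into prompt and reference: everything from the last assistant token onward becomes `output_ids`, closed with EOS (151336). A toy sketch with the same stand-in IDs (7 for 151337, 6 for 151336; the literal 512 stands in for max_input_length):

# Toy reproduction of the eval-side split in the hunk above.
input_ids = [1, 3, 5, 7, 9, 2, 7, 4, 8]
last_assistant_index = len(input_ids) - input_ids[::-1].index(7) - 1

output_prompt, output_ids = (
    input_ids[:1],                     # just the leading token, as in the diff
    input_ids[last_assistant_index:],  # assistant marker plus the reference answer
)
output_ids.append(6)                   # close the reference with EOS
assert output_ids == [7, 4, 8, 6]

# Mirroring the diff: the model input is the (truncated) combined ids
# plus the first prompt token.
model_input = input_ids[:512] + output_prompt[:1]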
@@ -386,6 +410,7 @@ def main(
         functools.partial(
             process_batch,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
@@ -397,6 +422,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
@@ -409,6 +435,7 @@ def main(
         functools.partial(
             process_batch_eval,
             tokenizer=tokenizer,
+            combine=ft_config.combine,
             max_input_length=ft_config.max_input_length,
             max_output_length=ft_config.max_output_length,
         ),
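The three `main()` hunks above make the same one-line change at each dataset-mapping call site: `functools.partial` pre-binds the new `combine` keyword so the mapping machinery only has to supply the batch itself. A self-contained sketch of that pattern (the function body is a stand-in, not the repo's):

import functools

def demo_process(batch, tokenizer=None, combine=False):
    # Stand-in for process_batch / process_batch_eval.
    return f'{batch} processed with combine={combine}'

bound = functools.partial(demo_process, tokenizer=None, combine=True)
print(bound('batch-0'))  # batch-0 processed with combine=True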