Merge pull request #538 from zhipuch/ftv
Support plain text fine-tuning
commit 1d78cfb3d7
@@ -447,6 +447,13 @@ def main(
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()

+    ft_config.training_args.generation_config.pad_token_id = (
+        151329
+    )
+    ft_config.training_args.generation_config.eos_token_id = [
+        151329, 151336, 151338
+    ]
+
     trainer = Seq2SeqTrainer(
         model=model,
         args=ft_config.training_args,
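Note: the hunk above pins the generation config that Seq2SeqTrainer uses when generating during evaluation. The hard-coded ids appear to be GLM-4 special tokens (151329 looks like <|endoftext|>, 151336 <|user|>, 151338 <|observation|>). A minimal sketch of resolving them from the tokenizer instead of hard-coding, assuming a transformers AutoTokenizer and a GLM-4 chat checkpoint (the model path below is only an example, not taken from the patch):

# Illustrative only: resolve the pad/eos ids from the tokenizer rather than hard-coding them.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)  # example checkpoint
pad_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
eos_ids = [tokenizer.convert_tokens_to_ids(t) for t in ("<|endoftext|>", "<|user|>", "<|observation|>")]
print(pad_id, eos_ids)  # expected to match the hard-coded 151329 and [151329, 151336, 151338]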
@@ -30,7 +30,7 @@ from typing import Optional
 from PIL import Image

 app = typer.Typer(pretty_exceptions_show_locals=False)
-
+img = Image.new('L', (224, 224), 0).convert('RGB')

 class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
     def __call__(self, features, return_tensors=None):
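Note: the new module-level img is an all-black 224x224 placeholder that stands in for a real image on text-only samples, so the vision preprocessing path still receives an image. A quick sketch of what the placeholder contains (PIL only):

# Sketch: a 224x224 grayscale image filled with 0, then converted to RGB.
from PIL import Image

img = Image.new('L', (224, 224), 0).convert('RGB')
print(img.mode, img.size)    # RGB (224, 224)
print(img.getpixel((0, 0)))  # (0, 0, 0)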
@@ -266,10 +266,12 @@ def process_batch(
         loss_masks = [False, False]
         images = []

+        if conv[0].get('image'):
+            conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
+        else:
+            conv[0]['image'] = img
+
         for message in conv:
-            if message.get('image'):
-                image = Image.open(message['image']).convert('RGB')
-                message['image'] = image

             loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
             new_input_ids_all = tokenizer.apply_chat_template(
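Note: image handling moves from the per-message loop to the first turn of the conversation: if that turn carries an image path it is opened, otherwise the blank placeholder is attached, which is what lets plain-text conversations pass through the same collator. A minimal paraphrase of the added branch (attach_image and placeholder are illustrative names, not part of the patch):

# Illustrative paraphrase of the added branch; conv is a list of chat-turn dicts.
from PIL import Image

def attach_image(conv, placeholder):
    if conv[0].get('image'):
        conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
    else:
        conv[0]['image'] = placeholder
    return conv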
@@ -306,8 +308,7 @@ def process_batch(
         batched_attention_mask.append(attention_mask[:max_length])
         batched_position_ids.append(position_ids[:max_length])
         batched_labels.append(labels[:max_length])
-        if images is not None:
-            batched_images.append(images[0][0])
+        batched_images.append(images[0][0])

     del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
     torch.cuda.empty_cache()
@@ -339,8 +340,10 @@ def process_batch_eval(

         if conv[0].get('image'):
             image = Image.open(conv[0]['image']).convert('RGB')
-            conv[0]['image'] = image
+        else:
+            image = img

+        conv[0]['image'] = image
         new_input_ids_all = tokenizer.apply_chat_template(
             conv,
             tokenize=True,
@@ -493,6 +496,13 @@ def main(
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()

+    ft_config.training_args.generation_config.pad_token_id = (
+        151329
+    )
+    ft_config.training_args.generation_config.eos_token_id = [
+        151329, 151336, 151338
+    ]
+
     trainer = Seq2SeqTrainer(
         model=model,
         args=ft_config.training_args,