Merge pull request #538 from zhipuch/ftv

Support plain text fine-tuning
This commit is contained in:
zR 2024-09-04 18:11:15 +08:00 committed by GitHub
commit 1d78cfb3d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 8 deletions

View File

@ -447,6 +447,13 @@ def main(
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model.enable_input_require_grads() model.enable_input_require_grads()
ft_config.training_args.generation_config.pad_token_id = (
151329
)
ft_config.training_args.generation_config.eos_token_id = [
151329, 151336, 151338
]
trainer = Seq2SeqTrainer( trainer = Seq2SeqTrainer(
model=model, model=model,
args=ft_config.training_args, args=ft_config.training_args,

View File

@ -30,7 +30,7 @@ from typing import Optional
from PIL import Image from PIL import Image
app = typer.Typer(pretty_exceptions_show_locals=False) app = typer.Typer(pretty_exceptions_show_locals=False)
img = Image.new('L', (224, 224), 0).convert('RGB')
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq): class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None): def __call__(self, features, return_tensors=None):
@ -266,10 +266,12 @@ def process_batch(
loss_masks = [False, False] loss_masks = [False, False]
images = [] images = []
if conv[0].get('image'):
conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
else:
conv[0]['image'] = img
for message in conv: for message in conv:
if message.get('image'):
image = Image.open(message['image']).convert('RGB')
message['image'] = image
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
new_input_ids_all = tokenizer.apply_chat_template( new_input_ids_all = tokenizer.apply_chat_template(
@ -306,8 +308,7 @@ def process_batch(
batched_attention_mask.append(attention_mask[:max_length]) batched_attention_mask.append(attention_mask[:max_length])
batched_position_ids.append(position_ids[:max_length]) batched_position_ids.append(position_ids[:max_length])
batched_labels.append(labels[:max_length]) batched_labels.append(labels[:max_length])
if images is not None: batched_images.append(images[0][0])
batched_images.append(images[0][0])
del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -339,8 +340,10 @@ def process_batch_eval(
if conv[0].get('image'): if conv[0].get('image'):
image = Image.open(conv[0]['image']).convert('RGB') image = Image.open(conv[0]['image']).convert('RGB')
conv[0]['image'] = image else:
image = img
conv[0]['image'] = image
new_input_ids_all = tokenizer.apply_chat_template( new_input_ids_all = tokenizer.apply_chat_template(
conv, conv,
tokenize=True, tokenize=True,
@ -493,6 +496,13 @@ def main(
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model.enable_input_require_grads() model.enable_input_require_grads()
ft_config.training_args.generation_config.pad_token_id = (
151329
)
ft_config.training_args.generation_config.eos_token_id = [
151329, 151336, 151338
]
trainer = Seq2SeqTrainer( trainer = Seq2SeqTrainer(
model=model, model=model,
args=ft_config.training_args, args=ft_config.training_args,