Support plain text fine-tuning

zhipuch 2024-09-03 09:52:02 +00:00
parent f0d67ff4a4
commit cbe73627ff
2 changed files with 25 additions and 8 deletions


@@ -446,6 +446,13 @@ def main(
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    ft_config.training_args.generation_config.pad_token_id = (
        151329
    )
    ft_config.training_args.generation_config.eos_token_id = [
        151329, 151336, 151338
    ]
    trainer = Seq2SeqTrainer(
        model=model,

@@ -30,7 +30,7 @@ from typing import Optional
from PIL import Image
app = typer.Typer(pretty_exceptions_show_locals=False)
img = Image.new('L', (224, 224), 0).convert('RGB')
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
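The second changed file (from its process_batch hunks, evidently the vision fine-tuning script) gains a module-level placeholder image: a 224x224 all-black grayscale canvas promoted to RGB. Conversations that carry no image can then reuse the multimodal code path by substituting this blank image. A small sketch of the placeholder on its own (mode and size taken from the diff):

from PIL import Image

# Blank stand-in for samples without an image, matching the RGB mode and
# the 224x224 size used in the diff.
placeholder = Image.new('L', (224, 224), 0).convert('RGB')
print(placeholder.mode, placeholder.size)  # -> RGB (224, 224)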
@@ -265,11 +265,13 @@ def process_batch(
        position_ids = list(range(len(input_ids)))
        loss_masks = [False, False]
        images = []
        if conv[0].get('image'):
            conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
        else:
            conv[0]['image'] = img
        for message in conv:
            if message.get('image'):
                image = Image.open(message['image']).convert('RGB')
                message['image'] = image
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            new_input_ids_all = tokenizer.apply_chat_template(
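In process_batch the image decision now happens once per conversation instead of once per message: if the first turn names an image file it is loaded, otherwise the blank placeholder is attached; the per-message loading shown under the for loop appears to be the code being replaced. A hypothetical helper mirroring the new logic, using the same names as the diff (conv, img):

# Hypothetical helper, not part of the diff; `img` is the module-level
# blank placeholder defined above.
def attach_image(conv, img):
    if conv[0].get('image'):
        conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
    else:
        conv[0]['image'] = img
    return conv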
@@ -306,8 +308,7 @@ def process_batch(
        batched_attention_mask.append(attention_mask[:max_length])
        batched_position_ids.append(position_ids[:max_length])
        batched_labels.append(labels[:max_length])
        if images is not None:
            batched_images.append(images[0][0])
        batched_images.append(images[0][0])
    del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
    torch.cuda.empty_cache()
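With every conversation now guaranteed to carry an image (real or placeholder), the guard becomes unnecessary: the two lines under `if images is not None:` are the old version, and the bare `batched_images.append(images[0][0])` is its unconditional replacement.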
@@ -339,8 +340,10 @@ def process_batch_eval(
        if conv[0].get('image'):
            image = Image.open(conv[0]['image']).convert('RGB')
            conv[0]['image'] = image
        else:
            image = img
            conv[0]['image'] = image
        new_input_ids_all = tokenizer.apply_chat_template(
            conv,
            tokenize=True,
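process_batch_eval gets the same first-turn fallback, so evaluation batches also always contain an image. The upshot of the commit is that both kinds of record below should now flow through the vision script; the field layout here is an assumption for illustration, not taken from the repository's documented format:

# Illustrative (assumed) records: one multimodal, one plain-text.
multimodal_sample = {
    "messages": [
        {"role": "user", "content": "What is in the picture?", "image": "data/cat.jpg"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ]
}
plain_text_sample = {
    "messages": [
        {"role": "user", "content": "Summarize the plot of Hamlet."},
        {"role": "assistant", "content": "A Danish prince avenges his father's murder."},
    ]
}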
@@ -492,6 +495,13 @@ def main(
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    ft_config.training_args.generation_config.pad_token_id = (
        151329
    )
    ft_config.training_args.generation_config.eos_token_id = [
        151329, 151336, 151338
    ]
    trainer = Seq2SeqTrainer(
        model=model,