Support plain text fine-tuning
This commit is contained in:
parent
f0d67ff4a4
commit
cbe73627ff
|
@ -446,6 +446,13 @@ def main(
|
|||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
|
||||
ft_config.training_args.generation_config.pad_token_id = (
|
||||
151329
|
||||
)
|
||||
ft_config.training_args.generation_config.eos_token_id = [
|
||||
151329, 151336, 151338
|
||||
]
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
|
|
|
@ -30,7 +30,7 @@ from typing import Optional
|
|||
from PIL import Image
|
||||
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
|
||||
img = Image.new('L', (224, 224), 0).convert('RGB')
|
||||
|
||||
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
|
@ -265,11 +265,13 @@ def process_batch(
|
|||
position_ids = list(range(len(input_ids)))
|
||||
loss_masks = [False, False]
|
||||
images = []
|
||||
|
||||
if conv[0].get('image'):
|
||||
conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
|
||||
else:
|
||||
conv[0]['image'] = img
|
||||
|
||||
for message in conv:
|
||||
if message.get('image'):
|
||||
image = Image.open(message['image']).convert('RGB')
|
||||
message['image'] = image
|
||||
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
|
@ -306,8 +308,7 @@ def process_batch(
|
|||
batched_attention_mask.append(attention_mask[:max_length])
|
||||
batched_position_ids.append(position_ids[:max_length])
|
||||
batched_labels.append(labels[:max_length])
|
||||
if images is not None:
|
||||
batched_images.append(images[0][0])
|
||||
batched_images.append(images[0][0])
|
||||
|
||||
del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -339,8 +340,10 @@ def process_batch_eval(
|
|||
|
||||
if conv[0].get('image'):
|
||||
image = Image.open(conv[0]['image']).convert('RGB')
|
||||
conv[0]['image'] = image
|
||||
|
||||
else:
|
||||
image = img
|
||||
|
||||
conv[0]['image'] = image
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
conv,
|
||||
tokenize=True,
|
||||
|
@ -492,6 +495,13 @@ def main(
|
|||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
|
||||
ft_config.training_args.generation_config.pad_token_id = (
|
||||
151329
|
||||
)
|
||||
ft_config.training_args.generation_config.eos_token_id = [
|
||||
151329, 151336, 151338
|
||||
]
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
|
|
Loading…
Reference in New Issue