# glm4/finetune_demo/finetune_vision.py
# -*- coding: utf-8 -*-
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional
from PIL import Image

# For Ascend NPU, please add this
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

app = typer.Typer(pretty_exceptions_show_locals=False)

# Blank 224x224 RGB placeholder: text-only samples reuse it so that every example
# still carries an image for the vision tower.
img = Image.new("L", (224, 224), 0).convert("RGB")


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        # "output_ids" (used at eval time) is not handled by the parent collator,
        # so pad it here to the longest sequence in the batch.
        output_ids = (
            [feature["output_ids"] for feature in features]
            if "output_ids" in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature["output_ids"])
                )
                if isinstance(feature["output_ids"], list):
                    feature["output_ids"] = feature["output_ids"] + remainder
                else:
                    feature["output_ids"] = np.concatenate(
                        [feature["output_ids"], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict,
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with torch.no_grad():
            output_ids = None
            if self.args.predict_with_generate:
                output_ids = inputs.pop("output_ids", None)
                # Labels are not needed when generating; drop them if present.
                inputs.pop("labels", None)
            loss, generated_tokens, labels = super().prediction_step(
                model=model,
                inputs=inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
                **gen_kwargs,
            )
            if generated_tokens is not None:
                # Keep only the newly generated tokens, not the echoed prompt.
                generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1]:]
            if self.args.predict_with_generate:
                labels = output_ids
            del inputs, output_ids
            torch.cuda.empty_cache()
        return loss, generated_tokens, labels


@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = "no"
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> "FinetuningConfig":
        training_args = kwargs.get("training_args", None)
        if training_args is not None and not isinstance(
            training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get("generation_config")
            if not isinstance(gen_config, GenerationConfig):
                training_args["generation_config"] = GenerationConfig(**gen_config)
            kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get("data_config")
        if not isinstance(data_config, DataConfig):
            kwargs["data_config"] = DataConfig(**data_config)

        peft_config = kwargs.get("peft_config", None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
        path = Path(path)
        parser = yaml.YAML(typ="safe", pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)
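

# A minimal sketch of the YAML layout that FinetuningConfig.from_file() expects,
# inferred from the dataclass fields above. The concrete values and the LoRA
# settings below are illustrative assumptions, not the shipped configs/ files:
#
#   data_config:
#     train_file: train.jsonl
#     val_file: dev.jsonl
#     test_file: dev.jsonl
#     num_proc: 1
#   combine: true
#   freezeV: true
#   max_input_length: 512
#   max_output_length: 512
#   training_args:
#     output_dir: ./output
#     max_steps: 3000
#     per_device_train_batch_size: 1
#     generation_config:
#       max_new_tokens: 512
#   peft_config:
#     peft_type: LORA
#     task_type: CAUSAL_LM
#     r: 8
#     lora_alpha: 32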


def _load_datasets(
    data_dir: str,
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    if data_format == ".jsonl":
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct
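

# Each .jsonl file is expected to hold one conversation per line under a
# "messages" key, which process_batch()/process_batch_eval() below consume.
# A sketch (the "content"/"image" fields follow the GLM-4V chat template; the
# path is made up):
#
#   {"messages": [
#       {"role": "user", "content": "Describe this picture", "image": "images/0001.jpg"},
#       {"role": "assistant", "content": "A cat sitting on a windowsill."}
#   ]}
#
# Only the first message may carry an "image" path; text-only conversations fall
# back to the blank placeholder image defined near the top of this file.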


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return
        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
            # These are the defaults of Dataset.map; lower them if preprocessing
            # runs out of memory, see https://github.com/THUDM/GLM-4/issues/277
            writer_batch_size=1000,
            batch_size=1000,
        )


def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_attention_mask = []
    batched_position_ids = []
    batched_labels = []
    batched_images = []

    max_length = max_input_length + max_output_length

    for conv in batched_conv:
        input_ids = [151331, 151333]  # [gMASK], <sop>: GLM-4 sequence-start tokens
        attention_mask = [1, 1]
        position_ids = list(range(len(input_ids)))
        loss_masks = [False, False]
        images = []

        if conv[0].get("image"):
            conv[0]["image"] = Image.open(conv[0]["image"]).convert("RGB")
        else:
            conv[0]["image"] = img

        for message in conv:
            # Only assistant turns contribute to the loss.
            loss_mask_val = (
                False if message["role"] in ("system", "user", "observation") else True
            )
            new_input_ids_all = tokenizer.apply_chat_template(
                [message], tokenize=True, return_dict=True, padding=True
            )
            new_input_ids = new_input_ids_all["input_ids"][0][2:]
            new_attention_mask = new_input_ids_all["attention_mask"][0][2:]
            new_position_ids = list(
                range(position_ids[-1] + 1, position_ids[-1] + 1 + len(new_input_ids))
            )
            if message.get("image"):  # Only one image per conversation is supported
                images.append(new_input_ids_all["images"])

            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            input_ids += new_input_ids
            attention_mask += new_attention_mask
            position_ids += new_position_ids
            loss_masks += new_loss_masks

        input_ids.append(151336)  # EOS
        attention_mask.append(1)
        position_ids.append(len(position_ids))
        loss_masks.append(False)

        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)

        batched_input_ids.append(input_ids[:max_length])
        batched_attention_mask.append(attention_mask[:max_length])
        batched_position_ids.append(position_ids[:max_length])
        batched_labels.append(labels[:max_length])
        batched_images.append(images[0][0])

    del (
        batched_conv,
        conv,
        input_ids,
        attention_mask,
        position_ids,
        loss_masks,
        message,
        new_input_ids,
        new_loss_masks,
        labels,
        input_id,
        mask,
    )
    torch.cuda.empty_cache()

    return {
        "input_ids": batched_input_ids,
        "attention_mask": batched_attention_mask,
        "position_ids": batched_position_ids,
        "labels": batched_labels,
        "images": batched_images,
    }


def process_batch_eval(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_attention_mask = []
    batched_position_ids = []
    batched_output_ids = []
    batched_images = []

    for conv in batched_conv:
        if conv[0].get("image"):
            image = Image.open(conv[0]["image"]).convert("RGB")
        else:
            image = img
        conv[0]["image"] = image

        new_input_ids_all = tokenizer.apply_chat_template(
            conv, tokenize=True, return_dict=True, padding=True
        )
        input_ids = new_input_ids_all["input_ids"][0]
        attention_mask = new_input_ids_all["attention_mask"][0]
        position_ids = list(range(len(input_ids)))

        dialogue_parts = [0]
        for idx, token_id in enumerate(input_ids):
            if token_id == 151337:  # <|assistant|> marks the start of a model reply
                dialogue_parts.append(idx + 1)

        if not dialogue_parts or dialogue_parts[-1] != len(input_ids):
            dialogue_parts.append(len(input_ids))

        # Split the conversation into multiple dialogue segments
        for end_idx in range(1, len(dialogue_parts)):
            input_segment = input_ids[: dialogue_parts[end_idx]]
            attention_segment = attention_mask[: dialogue_parts[end_idx]]
            position_segment = position_ids[: dialogue_parts[end_idx]]
            output_segment = input_ids[
                dialogue_parts[end_idx - 1] : dialogue_parts[end_idx]
            ]
            output_segment.append(151336)  # Add EOS token

            batched_input_ids.append(input_segment[:max_input_length])
            batched_attention_mask.append(attention_segment[:max_input_length])
            batched_position_ids.append(position_segment[:max_input_length])
            batched_output_ids.append(output_segment[:max_output_length])
            batched_images.append(new_input_ids_all["images"][0])

    del (
        batched_conv,
        input_ids,
        attention_mask,
        position_ids,
        new_input_ids_all,
        output_segment,
    )
    torch.cuda.empty_cache()

    return {
        "input_ids": batched_input_ids,
        "attention_mask": batched_attention_mask,
        "position_ids": batched_position_ids,
        "output_ids": batched_output_ids,
        "images": batched_images,
    }


def load_tokenizer_and_model(
    model_dir: str,
    peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, padding_side="left", trust_remote_code=True
    )
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,  # Must use BFloat16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,
        )
    return tokenizer, model


def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    # ROUGE (F1 * 100) and BLEU-4 are computed over jieba-segmented text.
    batched_pred_ids, batched_label_ids = eval_preds
    batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
    batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
    metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v["f"] * 100, 4))
        metrics_dct["bleu-4"].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


@app.command()
def main(
    data_dir: Annotated[str, typer.Argument(help="")],
    model_dir: Annotated[
        str,
        typer.Argument(
            help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
        ),
    ],
    config_file: Annotated[str, typer.Argument(help="")],
    auto_resume_from_checkpoint: str = typer.Argument(
        default="",
        help="If 'yes', resume from the latest saved checkpoint. If a number such as 12, resume from the corresponding checkpoint-12. If empty, start training from scratch.",
    ),
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(
        model_dir, peft_config=ft_config.peft_config
    )
    if ft_config.freezeV:
        # Freeze the vision tower so only the language model (or its LoRA adapters) trains.
        for param in model.transformer.vision.parameters():
            param.requires_grad = False
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            combine=ft_config.combine,  # Not used for now
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print("train_dataset:", train_dataset)

    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            combine=ft_config.combine,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print("val_dataset:", val_dataset)

    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            combine=ft_config.combine,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print("test_dataset:", test_dataset)

    # Generation stop/pad tokens: 151329 <|endoftext|> plus the <|user|>/<|observation|> role tokens.
    ft_config.training_args.generation_config.pad_token_id = 151329
    ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding="longest",
            return_tensors="pt",
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint == "":
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        # Find the highest-numbered checkpoint-* directory already saved in output_dir.
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(
                    output_dir, "checkpoint-" + str(checkpoint_sn)
                )
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(
                        output_dir, "checkpoint-" + str(checkpoint_sn)
                    )
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(
                    auto_resume_from_checkpoint,
                    "The specified checkpoint sn("
                    + auto_resume_from_checkpoint
                    + ") has not been saved. Please search for the correct checkpoint in the model output directory",
                )

    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == "__main__":
    app()
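

# Example invocation (a sketch; the dataset path, config file, and model id below
# are illustrative, not shipped defaults). The positional arguments are the data
# directory, the base model, the fine-tuning YAML config, and an optional
# resume-from-checkpoint flag:
#
#   python finetune_vision.py data/my_vqa_dataset/ THUDM/glm-4v-9b configs/lora.yaml yes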