# -*- coding: utf-8 -*-
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union

import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

# For Ascend NPU, please add this
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

app = typer.Typer(pretty_exceptions_show_locals=False)


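# The collator below extends the stock HF DataCollatorForSeq2Seq, which only pads
# input_ids/labels: it additionally pads the evaluation-only "output_ids" feature to a
# common length (optionally rounded up to pad_to_multiple_of) before delegating to the
# parent implementation.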
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = (
            [feature["output_ids"] for feature in features]
            if "output_ids" in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature["output_ids"])
                )
                if isinstance(feature["output_ids"], list):
                    feature["output_ids"] = feature["output_ids"] + remainder
                else:
                    feature["output_ids"] = np.concatenate(
                        [feature["output_ids"], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


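# The trainer below overrides prediction_step for generation-based evaluation: it pops
# the reference "output_ids" out of the model inputs, slices the prompt off the
# generated sequence so only newly generated tokens remain, and hands the references
# back as labels for compute_metrics. It assumes predict_with_generate is enabled.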
class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with torch.no_grad():  # Ensure no gradient computation
            if self.args.predict_with_generate:
                output_ids = inputs.pop("output_ids")
            input_ids = inputs["input_ids"]

            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )

            # Keep only the newly generated tokens; score them against the references.
            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
            labels = output_ids

            del inputs, input_ids, output_ids
            torch.cuda.empty_cache()

        return loss, generated_tokens, labels


@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = "no"
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> "FinetuningConfig":
        training_args = kwargs.get("training_args", None)
        if training_args is not None and not isinstance(
            training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get("generation_config")
            if not isinstance(gen_config, GenerationConfig):
                training_args["generation_config"] = GenerationConfig(**gen_config)
            kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get("data_config")
        if not isinstance(data_config, DataConfig):
            kwargs["data_config"] = DataConfig(**data_config)

        peft_config = kwargs.get("peft_config", None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs["peft_config"] = get_peft_config(config_dict=peft_config)

        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
        path = Path(path)
        parser = yaml.YAML(typ="safe", pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)


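# A rough sketch of the YAML layout FinetuningConfig.from_file expects, inferred from the
# dataclass fields above rather than taken from a shipped config file -- field names
# follow the dataclass, the concrete values are placeholders:
#
#   data_config:
#     train_file: train.jsonl
#     val_file: dev.jsonl
#     test_file: dev.jsonl
#     num_proc: 1
#   combine: true
#   freezeV: true
#   max_input_length: 512
#   max_output_length: 512
#   training_args:
#     output_dir: ./output
#     # any other Seq2SeqTrainingArguments fields, including generation_config
#   peft_config:
#     peft_type: LORA
#     task_type: CAUSAL_LM
#     # remaining fields as accepted by peft.get_peft_config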
def _load_datasets(
    data_dir: str,
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    if data_format == ".jsonl":
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


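# DataManager is a thin wrapper around the loaded DatasetDict: it looks up a split and
# applies the given preprocessing function via Dataset.map, dropping the original
# columns by default so only the tokenized features remain.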
class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return None

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


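# The dataset files are .jsonl conversations: every line carries a "messages" list whose
# entries have a "role" and "content", and a system message may additionally carry a
# "tools" list. A sketch of one line, inferred from process_message/process_batch below
# (the concrete text is a placeholder):
#
#   {"messages": [{"role": "system", "content": "...", "tools": [...]},
#                 {"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}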
def process_message(message):
    if "tools" in message and message["role"] == "system":
        for tool in message["tools"]:
            parameters = tool["function"]["parameters"]["properties"]
            tool["function"]["parameters"]["properties"] = {
                k: v for k, v in parameters.items() if v is not None
            }
    elif "tools" in message:
        del message["tools"]

    return message


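# process_batch turns each conversation into (input_ids, labels): labels copy input_ids
# for assistant tokens and are -100 everywhere else, so loss is only computed on replies.
# With combine=True, only the final assistant turn is supervised. The hard-coded ids
# (151331, 151333, 151336, 151337) are GLM-4 special tokens; per my reading of the
# tokenizer config they correspond to [gMASK], <sop>, <|user|> and <|assistant|> --
# treat that mapping as an assumption if you port this script to another tokenizer.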
def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_labels = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        if combine:
            new_input_ids = tokenizer.apply_chat_template(
                conv, tokenize=True, return_dict=False
            )
            input_ids = new_input_ids
            loss_masks = [False] * len(input_ids)
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            for j in range(last_assistant_index + 1, len(input_ids)):
                loss_masks[j] = True
        else:
            for message in conv:
                message = process_message(message)
                loss_mask_val = (
                    False
                    if message["role"] in ("system", "user", "observation")
                    else True
                )
                new_input_ids = tokenizer.apply_chat_template(
                    [message], tokenize=True, return_dict=False
                )[2:]
                input_ids += new_input_ids
                loss_masks += [loss_mask_val] * len(new_input_ids)

        input_ids.append(151336)  # EOS for chat
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])

    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
    torch.cuda.empty_cache()

    return {"input_ids": batched_input_ids, "labels": batched_labels}


def process_batch_eval(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_output_ids = []

    for conv in batched_conv:
        if combine:
            new_input_ids = tokenizer.apply_chat_template(
                conv, tokenize=True, return_dict=False
            )
            input_ids = new_input_ids
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            output_prompt, output_ids = (
                input_ids[:1],
                input_ids[last_assistant_index:],
            )
            output_ids.append(151336)
            batched_input_ids.append(input_ids[:max_input_length] + output_prompt[:1])
            batched_output_ids.append(output_ids[:max_output_length])
        else:
            input_ids = [151331, 151333]
            for message in conv:
                if len(input_ids) >= max_input_length:
                    break
                else:
                    message = process_message(message)
                    new_input_ids = tokenizer.apply_chat_template(
                        [message], tokenize=True, return_dict=False
                    )[2:]
                    if message["role"] == "assistant":
                        output_prompt, output_ids = (
                            new_input_ids[:1],
                            new_input_ids[1:],
                        )
                        output_ids.append(151336)
                        batched_input_ids.append(
                            input_ids[:max_input_length] + output_prompt[:1]
                        )
                        batched_output_ids.append(output_ids[:max_output_length])
                    input_ids += new_input_ids

    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
    torch.cuda.empty_cache()

    return {"input_ids": batched_input_ids, "output_ids": batched_output_ids}


def load_tokenizer_and_model(
    model_dir: str,
    peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, padding_side="left", trust_remote_code=True
    )
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,  # Must use BFloat16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,
        )
    return tokenizer, model


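# compute_metrics decodes predictions and references, segments both with jieba (the
# metrics target Chinese text), and reports ROUGE-1/2/L F-scores (scaled to 0-100) plus
# smoothed BLEU-4, averaged over the evaluation set.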
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    batched_pred_ids, batched_label_ids = eval_preds
    batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
    batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
    metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v["f"] * 100, 4))
        metrics_dct["bleu-4"].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


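# Illustrative invocation (paths and model id are placeholders, not prescribed values):
#   python finetune.py <data_dir> <model_dir_or_hf_model_id> <config.yaml> yes
# The trailing argument controls checkpoint resumption as described in its help text.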
@app.command()
def main(
    data_dir: Annotated[str, typer.Argument(help="")],
    model_dir: Annotated[
        str,
        typer.Argument(
            help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
        ),
    ],
    config_file: Annotated[str, typer.Argument(help="")],
    auto_resume_from_checkpoint: str = typer.Argument(
        default="",
        help="If 'yes', automatically resume from the latest saved checkpoint. If a number such as 12 or 15, resume from the corresponding saved checkpoint. If 'no' or empty, start training from scratch.",
    ),
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(
        model_dir, peft_config=ft_config.peft_config
    )
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print("train_dataset:", train_dataset)

    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print("val_dataset:", val_dataset)

    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print("test_dataset:", test_dataset)

    ft_config.training_args.generation_config.pad_token_id = 151329
    ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding="longest",
            return_tensors="pt",
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        # Find the highest-numbered "checkpoint-*" directory, ignoring temporary ones.
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(
                    output_dir, "checkpoint-" + str(checkpoint_sn)
                )
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(
                        output_dir, "checkpoint-" + str(checkpoint_sn)
                    )
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(
                    f"The specified checkpoint ({auto_resume_from_checkpoint}) has not been saved. "
                    "Please look up the correct checkpoint in the model output directory."
                )

    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == "__main__":
    app()