# -*- coding: utf-8 -*-
"""Fine-tune a causal chat model with the Hugging Face Seq2SeqTrainer,
optionally through a PEFT adapter such as LoRA. Configuration is read from a
YAML file; the hard-coded special-token ids below assume a GLM-4-style
tokenizer and chat template."""

import dataclasses as dc
import functools
import json
import os
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union

import jieba
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

app = typer.Typer(pretty_exceptions_show_locals=False)


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    """Collator that additionally pads `output_ids` (the reference answers used
    with predict_with_generate) to a common length before batching."""

    def __call__(self, features, return_tensors=None):
        output_ids = (
            [feature['output_ids'] for feature in features]
            if 'output_ids' in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                # Round the target length up to the next multiple of pad_to_multiple_of.
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):
                    feature['output_ids'] = feature['output_ids'] + remainder
                else:
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.args.predict_with_generate:
            output_ids = inputs.pop('output_ids')
        input_ids = inputs['input_ids']
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
        )
        # Strip the prompt tokens so only the newly generated continuation is scored,
        # and use the padded reference `output_ids` as labels for metric computation.
        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
        labels = output_ids
        return loss, generated_tokens, labels


@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None

    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            # Disable evaluation entirely when it is turned off or no validation file is given.
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = 'no'
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
            training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            # TODO: a bit hacky
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(**gen_config)
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        path = Path(path)
        parser = yaml.YAML(typ='safe', pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)
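

# An illustrative sketch of the YAML layout that `FinetuningConfig.from_file`
# expects, assembled from the dataclass fields above. The file names,
# hyperparameter values, and LoRA settings are placeholders, not recommendations:
#
#   data_config:
#     train_file: train.jsonl
#     val_file: dev.jsonl
#     test_file: dev.jsonl
#     num_proc: 1
#   max_input_length: 512
#   max_output_length: 512
#   training_args:
#     output_dir: ./output
#     max_steps: 3000
#     per_device_train_batch_size: 1
#     learning_rate: 5e-5
#     predict_with_generate: true
#     generation_config:
#       max_new_tokens: 512
#   peft_config:
#     peft_type: LORA
#     task_type: CAUSAL_LM
#     r: 8
#     lora_alpha: 32
#     lora_dropout: 0.1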


def _load_datasets(
    data_dir: str,
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    if data_format == '.jsonl':
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return None

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


def process_message(message):
    # Keep tool definitions only on system messages, dropping parameter properties
    # whose value is None; strip the 'tools' field from every other role.
    if 'tools' in message and message['role'] == 'system':
        for tool in message['tools']:
            parameters = tool['function']['parameters']['properties']
            tool['function']['parameters']['properties'] = {
                k: v for k, v in parameters.items() if v is not None
            }
    elif 'tools' in message:
        del message['tools']
    return message
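

# For reference, each line of the .jsonl data files is expected to hold one
# conversation under a 'messages' key. A minimal illustrative record (all values
# are placeholders):
#
#   {"messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello"},
#       {"role": "assistant", "content": "Hi, how can I help you?"}
#   ]}
#
# System messages may additionally carry a 'tools' list in the format expected by
# the model's chat template; process_message above normalises that field.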


def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []

    for conv in batched_conv:
        # The two hard-coded ids are the tokenizer's prefix special tokens
        # (in GLM-4's tokenizer, [gMASK] and <sop>); they never contribute to the loss.
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        for message in conv:
            message = process_message(message)
            # Tokens from 'system' and 'user' turns are masked out of the loss;
            # all other roles are trained on.
            loss_mask_val = False if message['role'] in ('system', 'user') else True
            # The [0][2:] slice drops the prefix tokens that apply_chat_template
            # prepends on every call, since they are already at the start of input_ids.
            new_input_ids = tokenizer.apply_chat_template(
                [message], tokenize=True, return_dict=False
            )[0][2:]
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            input_ids += new_input_ids
            loss_masks += new_loss_masks
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}


def process_batch_eval(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_output_ids = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        for message in conv:
            if len(input_ids) >= max_input_length:
                break
            message = process_message(message)
            new_input_ids = tokenizer.apply_chat_template(
                [message], tokenize=True, return_dict=False
            )[0][2:]
            if message['role'] == 'assistant':
                # Split the assistant turn into its first token (kept as a generation
                # prompt) and the remaining tokens (the reference output).
                output_prompt, output_ids = (
                    new_input_ids[:1],
                    new_input_ids[1:],
                )
                output_ids.append(tokenizer.eos_token_id)
                batched_input_ids.append(
                    input_ids[:max_input_length] + output_prompt[:1]
                )
                batched_output_ids.append(output_ids[:max_output_length])
            input_ids += new_input_ids
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}


def load_tokenizer_and_model(
    model_dir: str,
    peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        empty_init=False,
        use_cache=False,
        torch_dtype=torch.bfloat16,  # Must use BFloat16
    )
    if peft_config is not None:
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    return tokenizer, model
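

# A minimal usage sketch for the loader above; the model id and LoRA
# hyperparameters are placeholders, not values taken from this script:
#
#   peft_config = get_peft_config(config_dict={
#       'peft_type': 'LORA', 'task_type': 'CAUSAL_LM',
#       'r': 8, 'lora_alpha': 32, 'lora_dropout': 0.1,
#   })
#   tokenizer, model = load_tokenizer_and_model('THUDM/glm-4-9b-chat', peft_config)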


def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    # Decode predictions and references, segment them with jieba, then score
    # Chinese ROUGE-1/2/L and BLEU-4 per sample and average over the batch.
    batched_pred_ids, batched_label_ids = eval_preds
    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        metrics_dct['bleu-4'].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


@app.command()
def main(
    data_dir: Annotated[str, typer.Argument(help='Directory that contains the data files listed in the config.')],
    model_dir: Annotated[
        str,
        typer.Argument(
            help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
        ),
    ],
    config_file: Annotated[str, typer.Argument(help='Path to the fine-tuning YAML configuration file.')],
    auto_resume_from_checkpoint: str = typer.Argument(
        default='',
        help='Pass "yes" to resume from the latest saved checkpoint, a checkpoint number (e.g. 12 or 15) to resume from that specific checkpoint, or leave empty / pass "no" to train from scratch.'
    ),
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        # Evaluate on at most 50 validation samples to keep generation-based eval fast.
        eval_dataset=(
            val_dataset.select(list(range(min(50, len(val_dataset)))))
            if val_dataset is not None
            else None
        ),
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.strip() == "":
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        # Find the highest-numbered saved "checkpoint-<n>" directory, skipping temporary ones.
        for checkpoint_str in dirlist:
            if "checkpoint-" in checkpoint_str and "tmp" not in checkpoint_str:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit() and int(auto_resume_from_checkpoint) > 0:
                checkpoint_sn = int(auto_resume_from_checkpoint)
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(
                    "The specified checkpoint ("
                    + auto_resume_from_checkpoint
                    + ") has not been saved. Please look for the correct checkpoint in the model output directory."
                )

    if test_dataset is not None:
        trainer.predict(test_dataset)
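

# Illustrative invocation (the script name, data directory, model id, and config
# path are placeholders for whatever your setup actually uses):
#
#   python finetune.py data/my_dataset/ THUDM/glm-4-9b-chat configs/lora.yaml yes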


if __name__ == '__main__':
    app()