# -*- coding: utf-8 -*-
import dataclasses as dc
import functools
import os
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union

import jieba
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

# For Ascend NPU, please add this:
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

app = typer.Typer(pretty_exceptions_show_locals=False)
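

# In addition to the usual padding of input_ids/labels, this collator
# right-pads each sample's "output_ids" (the reference answers built by
# process_batch_eval below) so they can be stacked into a single tensor.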
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = (
            [feature["output_ids"] for feature in features]
            if "output_ids" in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature["output_ids"])
                )
                if isinstance(feature["output_ids"], list):
                    feature["output_ids"] = feature["output_ids"] + remainder
                else:
                    feature["output_ids"] = np.concatenate(
                        [feature["output_ids"], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)
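

# Trainer whose prediction_step understands the (prompt, output_ids) layout:
# the prompt tokens are stripped from the generated sequence and the padded
# reference "output_ids" are returned as labels for compute_metrics.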
class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with torch.no_grad():  # ensure no gradient computation
            if self.args.predict_with_generate:
                output_ids = inputs.pop("output_ids")
            input_ids = inputs["input_ids"]
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )
            if self.args.predict_with_generate:
                # Keep only the newly generated tokens and score them against
                # the reference answers instead of the shifted input labels.
                generated_tokens = generated_tokens[:, input_ids.size()[1]:]
                labels = output_ids
                del output_ids
            del inputs, input_ids
            torch.cuda.empty_cache()
        return loss, generated_tokens, labels


@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = "no"
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> "FinetuningConfig":
        training_args = kwargs.get("training_args", None)
        if training_args is not None and not isinstance(
            training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get("generation_config")
            if not isinstance(gen_config, GenerationConfig):
                training_args["generation_config"] = GenerationConfig(**gen_config)
            kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get("data_config")
        if not isinstance(data_config, DataConfig):
            kwargs["data_config"] = DataConfig(**data_config)

        peft_config = kwargs.get("peft_config", None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs["peft_config"] = get_peft_config(config_dict=peft_config)

        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
        path = Path(path)
        parser = yaml.YAML(typ="safe", pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)
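

# A minimal illustrative YAML config for FinetuningConfig.from_file (the field
# names come from the dataclasses above; the concrete values are assumptions):
#
#   data_config:
#     train_file: train.jsonl
#     val_file: dev.jsonl
#     test_file: test.jsonl
#     num_proc: 1
#   max_input_length: 512
#   max_output_length: 512
#   combine: false
#   freezeV: true
#   training_args:
#     output_dir: ./output
#     per_device_train_batch_size: 1
#     generation_config:   # required when training_args is given as a dict
#       max_new_tokens: 512
#   peft_config:           # optional; parsed by peft.get_peft_config
#     peft_type: LORA
#     task_type: CAUSAL_LM
#     r: 8
#     lora_alpha: 32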


def _load_datasets(
    data_dir: str,
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    if data_format == ".jsonl":
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc
        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return
        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )
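

# Each line of a .jsonl data file is expected to hold one conversation under a
# "messages" key, e.g. (illustrative):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}
# A system message may additionally carry a "tools" list, which
# process_message() below cleans up.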
def process_message(message):
    if "tools" in message and message["role"] == "system":
        for tool in message["tools"]:
            parameters = tool["function"]["parameters"]["properties"]
            tool["function"]["parameters"]["properties"] = {
                k: v for k, v in parameters.items() if v is not None
            }
    elif "tools" in message:
        del message["tools"]
    return message
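

# The hard-coded token IDs below correspond to GLM-4 special tokens:
# 151331 "[gMASK]", 151333 "<sop>", 151336 "<|user|>" (also used as the chat
# EOS), and 151337 "<|assistant|>". They would need to change for any other
# tokenizer.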
def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_labels = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        if combine:
            new_input_ids = tokenizer.apply_chat_template(
                conv, tokenize=True, return_dict=False
            )
            input_ids = new_input_ids
            loss_masks = [False] * len(input_ids)
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            for j in range(last_assistant_index + 1, len(input_ids)):
                loss_masks[j] = True
        else:
            for message in conv:
                message = process_message(message)
                loss_mask_val = (
                    False
                    if message["role"] in ("system", "user", "observation")
                    else True
                )
                new_input_ids = tokenizer.apply_chat_template(
                    [message], tokenize=True, return_dict=False
                )[2:]
                input_ids += new_input_ids
                loss_masks += [loss_mask_val] * len(new_input_ids)

        input_ids.append(151336)  # EOS for chat
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])

    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
    torch.cuda.empty_cache()

    return {"input_ids": batched_input_ids, "labels": batched_labels}
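

# Eval preprocessing differs from training: instead of one sequence with
# masked labels, it returns the prompt ("input_ids") and the reference answer
# ("output_ids") separately, so the trainer can generate from the prompt and
# score the generation against the reference.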
2024-06-05 10:22:16 +08:00
def process_batch_eval (
2024-12-09 15:58:10 +08:00
batch : Mapping [ str , Sequence ] ,
tokenizer : PreTrainedTokenizer ,
max_input_length : int ,
max_output_length : int ,
combine : bool ,
2024-06-05 10:22:16 +08:00
) - > dict [ str , list ] :
2024-12-09 15:58:10 +08:00
batched_conv = batch [ " messages " ]
2024-06-05 10:22:16 +08:00
batched_input_ids = [ ]
batched_output_ids = [ ]
for conv in batched_conv :
2024-07-19 22:04:53 +08:00
if combine :
2024-12-09 15:58:10 +08:00
new_input_ids = tokenizer . apply_chat_template (
conv , tokenize = True , return_dict = False
)
2024-07-19 22:04:53 +08:00
input_ids = new_input_ids
last_assistant_index = len ( input_ids ) - input_ids [ : : - 1 ] . index ( 151337 ) - 1
output_prompt , output_ids = (
input_ids [ : 1 ] ,
input_ids [ last_assistant_index : ] ,
)
output_ids . append ( 151336 )
2024-12-09 15:58:10 +08:00
batched_input_ids . append ( input_ids [ : max_input_length ] + output_prompt [ : 1 ] )
2024-07-19 22:04:53 +08:00
batched_output_ids . append ( output_ids [ : max_output_length ] )
else :
input_ids = [ 151331 , 151333 ]
for message in conv :
if len ( input_ids ) > = max_input_length :
break
else :
message = process_message ( message )
2024-12-09 15:58:10 +08:00
new_input_ids = tokenizer . apply_chat_template (
[ message ] , tokenize = True , return_dict = False
) [ 2 : ]
if message [ " role " ] == " assistant " :
2024-07-19 22:04:53 +08:00
output_prompt , output_ids = (
new_input_ids [ : 1 ] ,
new_input_ids [ 1 : ] ,
)
output_ids . append ( 151336 )
batched_input_ids . append (
input_ids [ : max_input_length ] + output_prompt [ : 1 ]
)
batched_output_ids . append ( output_ids [ : max_output_length ] )
input_ids + = new_input_ids
del batched_conv , conv , input_ids , new_input_ids , output_prompt , output_ids
2024-06-19 11:22:31 +08:00
torch . cuda . empty_cache ( )
2024-12-09 15:58:10 +08:00
return { " input_ids " : batched_input_ids , " output_ids " : batched_output_ids }
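

# padding_side="left" keeps prompts right-aligned within a batch, so
# generation starts immediately after the prompt for every row.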
def load_tokenizer_and_model(
    model_dir: str,
    peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, padding_side="left", trust_remote_code=True
    )
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,  # must use bfloat16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,
        )
    return tokenizer, model
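

# Metrics are computed on decoded text: prediction and reference are
# segmented with jieba (the references here are Chinese), then scored with
# ROUGE-1/2/L F1 and sentence-level BLEU-4.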
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    batched_pred_ids, batched_label_ids = eval_preds
    batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
    batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
    metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v["f"] * 100, 4))
        metrics_dct["bleu-4"].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


@app.command()
def main(
    data_dir: Annotated[str, typer.Argument(help="Directory containing the data files.")],
    model_dir: Annotated[
        str,
        typer.Argument(
            help="A string that specifies the model ID of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
        ),
    ],
    config_file: Annotated[str, typer.Argument(help="Path to the fine-tuning YAML config file.")],
    auto_resume_from_checkpoint: str = typer.Argument(
        default="",
        help="If 'yes', automatically resume from the latest saved checkpoint. If a number (e.g. 12 or 15), resume from the corresponding saved checkpoint. If 'no' or empty, start training from scratch.",
    ),
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(
        model_dir, peft_config=ft_config.peft_config
    )
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print("train_dataset:", train_dataset)
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print("val_dataset:", val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print("test_dataset:", test_dataset)
    # GLM-4 special tokens: pad with <|endoftext|> (151329) and stop on
    # <|endoftext|>, <|user|> (151336), or <|observation|> (151338).
    ft_config.training_args.generation_config.pad_token_id = 151329
    ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding="longest",
            return_tensors="pt",
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() in ("", "NO"):
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        # Find the highest-numbered "checkpoint-N" directory (skip tmp dirs).
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(
                    output_dir, "checkpoint-" + str(checkpoint_sn)
                )
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(
                        output_dir, "checkpoint-" + str(checkpoint_sn)
                    )
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(
                    "The specified checkpoint sn ("
                    + auto_resume_from_checkpoint
                    + ") has not been saved. Please search for the correct checkpoint in the model output directory."
                )

    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == "__main__":
    app()
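
# Example invocation (script name, data path, and model ID are illustrative):
#   python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes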