support NPU 910B
parent e14f187090
commit 52cd14ce8c
@@ -11,7 +11,8 @@ Read this in [English](README_en.md)

## Project Updates

- 🔥🔥 **News**: ```2024/11/01```: The dependencies of this repository have been upgraded. Please update the dependencies in `requirements.txt` to keep the models running correctly. [glm-4-9b-chat-hf](https://huggingface.co/THUDM/glm-4-9b-chat-hf) is the model weight adapted to `transformers>=4.46.2`, implemented with the `GlmModel` class from the `transformers` library.
- 🔥🔥 **News**: ```2024/12/10```: The fine-tuning code in this repository now supports fine-tuning on `Ascend NPU`. Please update the fine-tuning code and check the comments inside it.
- 🔥 **News**: ```2024/11/01```: The dependencies of this repository have been upgraded. Please update the dependencies in `requirements.txt` to keep the models running correctly. [glm-4-9b-chat-hf](https://huggingface.co/THUDM/glm-4-9b-chat-hf) is the model weight adapted to `transformers>=4.46.2`, implemented with the `GlmModel` class from the `transformers` library.

  At the same time, `tokenzier_chatglm.py` in [glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) and [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b) has been updated to work with the latest version of the `transformers` library. Please go to HuggingFace and update the file.

- 🔥 **News**: ```2024/10/27```: We open-sourced [LongReward](https://github.com/THUDM/LongReward), which uses AI feedback to improve long-context large language models.
- 🔥 **News**: ```2024/10/25```: We open-sourced [GLM-4-Voice](https://github.com/THUDM/GLM-4-Voice), an end-to-end Chinese-English speech dialogue model.

@@ -28,51 +28,42 @@ from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional

# For Ascend NPU, please add this
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

app = typer.Typer(pretty_exceptions_show_locals=False)
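The commented-out imports above are the hook this commit adds for running the fine-tuning script on an Ascend NPU such as the 910B. Below is a minimal sketch of how that hook can be made optional; the `torch_npu` / `transfer_to_npu` names come from the comments above, while the try/except guard and the `NPU_AVAILABLE` flag are illustrative assumptions, not part of this commit.

```python
# Optional Ascend NPU support; falls back to CUDA/CPU when torch_npu is not installed.
# The guard and the NPU_AVAILABLE flag are an illustrative sketch, not part of this commit.
try:
    import torch_npu  # noqa: F401
    from torch_npu.contrib import transfer_to_npu  # noqa: F401  # transparently maps torch.cuda.* calls onto the NPU
    NPU_AVAILABLE = True
except ImportError:
    NPU_AVAILABLE = False
```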
|
||||
|
||||
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
|
||||
output_ids = (
|
||||
[feature["output_ids"] for feature in features]
|
||||
if "output_ids" in features[0].keys()
|
||||
else None
|
||||
)
|
||||
if output_ids is not None:
|
||||
max_output_length = max(len(out) for out in output_ids)
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_output_length = (
|
||||
(
|
||||
max_output_length + self.pad_to_multiple_of - 1) //
|
||||
self.pad_to_multiple_of * self.pad_to_multiple_of
|
||||
(max_output_length + self.pad_to_multiple_of - 1)
|
||||
// self.pad_to_multiple_of
|
||||
* self.pad_to_multiple_of
|
||||
)
|
||||
for feature in features:
|
||||
remainder = [self.tokenizer.pad_token_id] * (
|
||||
max_output_length - len(feature['output_ids'])
|
||||
max_output_length - len(feature["output_ids"])
|
||||
)
|
||||
if isinstance(feature['output_ids'], list):
|
||||
feature['output_ids'] = feature['output_ids'] + remainder
|
||||
if isinstance(feature["output_ids"], list):
|
||||
feature["output_ids"] = feature["output_ids"] + remainder
|
||||
else:
|
||||
feature['output_ids'] = np.concatenate(
|
||||
[feature['output_ids'], remainder]
|
||||
feature["output_ids"] = np.concatenate(
|
||||
[feature["output_ids"], remainder]
|
||||
).astype(np.int64)
|
||||
return super().__call__(features, return_tensors)
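The expression above is the standard ceiling-division trick: it rounds `max_output_length` up to the nearest multiple of `pad_to_multiple_of` before padding the output ids. A small worked example with made-up values:

```python
pad_to_multiple_of = 8
max_output_length = 21
# (21 + 8 - 1) // 8 * 8  ->  28 // 8 * 8  ->  3 * 8  ->  24
rounded = (max_output_length + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of
assert rounded == 24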
|
||||
|
||||
|
||||
class Seq2SeqTrainer(_Seq2SeqTrainer):
|
||||
# Apex is not supported. transformers>=4.46 requires the additional argument num_items_in_batch.
|
||||
def training_step(self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch=None) -> torch.Tensor:
|
||||
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean()
|
||||
self.accelerator.backward(loss)
|
||||
detached_loss = loss.detach() / self.args.gradient_accumulation_steps
|
||||
del inputs
|
||||
torch.cuda.empty_cache()
|
||||
return detached_loss
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
|
@ -81,17 +72,16 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
|
|||
ignore_keys=None,
|
||||
**gen_kwargs,
|
||||
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad(): # Ensure no gradient computation
|
||||
if self.args.predict_with_generate:
|
||||
output_ids = inputs.pop('output_ids')
|
||||
input_ids = inputs['input_ids']
|
||||
output_ids = inputs.pop("output_ids")
|
||||
input_ids = inputs["input_ids"]
|
||||
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
|
||||
)
|
||||
|
||||
generated_tokens = generated_tokens[:, input_ids.size()[1]:]
|
||||
generated_tokens = generated_tokens[:, input_ids.size()[1] :]
|
||||
labels = output_ids
|
||||
|
||||
del inputs, input_ids, output_ids
|
||||
|
@ -133,14 +123,14 @@ class FinetuningConfig(object):
|
|||
freezeV: bool
|
||||
|
||||
training_args: Seq2SeqTrainingArguments = dc.field(
|
||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
|
||||
)
|
||||
peft_config: Optional[PeftConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.training_args.do_eval or self.data_config.val_file is None:
|
||||
self.training_args.do_eval = False
|
||||
self.training_args.evaluation_strategy = 'no'
|
||||
self.training_args.evaluation_strategy = "no"
|
||||
self.data_config.val_file = None
|
||||
else:
|
||||
self.training_args.per_device_eval_batch_size = (
|
||||
|
@ -149,31 +139,29 @@ class FinetuningConfig(object):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
|
||||
training_args = kwargs.get('training_args', None)
|
||||
def from_dict(cls, **kwargs) -> "FinetuningConfig":
|
||||
training_args = kwargs.get("training_args", None)
|
||||
if training_args is not None and not isinstance(
|
||||
training_args, Seq2SeqTrainingArguments
|
||||
):
|
||||
gen_config = training_args.get('generation_config')
|
||||
gen_config = training_args.get("generation_config")
|
||||
if not isinstance(gen_config, GenerationConfig):
|
||||
training_args['generation_config'] = GenerationConfig(
|
||||
**gen_config
|
||||
)
|
||||
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
|
||||
training_args["generation_config"] = GenerationConfig(**gen_config)
|
||||
kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)
|
||||
|
||||
data_config = kwargs.get('data_config')
|
||||
data_config = kwargs.get("data_config")
|
||||
if not isinstance(data_config, DataConfig):
|
||||
kwargs['data_config'] = DataConfig(**data_config)
|
||||
kwargs["data_config"] = DataConfig(**data_config)
|
||||
|
||||
peft_config = kwargs.get('peft_config', None)
|
||||
peft_config = kwargs.get("peft_config", None)
|
||||
if peft_config is not None and not isinstance(peft_config, PeftConfig):
|
||||
kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
|
||||
kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
|
||||
return cls(**kwargs)
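`from_dict` above accepts a plain mapping for `peft_config` and converts it with `get_peft_config`. A minimal example of the kind of mapping it expects for a LoRA run; the field values are illustrative and not taken from this repository's config files:

```python
from peft import get_peft_config

peft_config_dict = {
    "peft_type": "LORA",      # selects LoraConfig
    "task_type": "CAUSAL_LM",
    "r": 8,                   # example rank, not this repo's default
    "lora_alpha": 32,
    "lora_dropout": 0.1,
}
lora_config = get_peft_config(config_dict=peft_config_dict)
```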
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
|
||||
def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
|
||||
path = Path(path)
|
||||
parser = yaml.YAML(typ='safe', pure=True)
|
||||
parser = yaml.YAML(typ="safe", pure=True)
|
||||
parser.indent(mapping=2, offset=2, sequence=4)
|
||||
parser.default_flow_style = False
|
||||
kwargs = parser.load(path)
|
||||
|
@ -186,7 +174,7 @@ def _load_datasets(
|
|||
data_files: dict[NamedSplit, str],
|
||||
num_proc: Optional[int],
|
||||
) -> DatasetDict:
|
||||
if data_format == '.jsonl':
|
||||
if data_format == ".jsonl":
|
||||
dataset_dct = load_dataset(
|
||||
data_dir,
|
||||
data_files=data_files,
|
||||
|
@ -236,14 +224,14 @@ class DataManager(object):
|
|||
|
||||
|
||||
def process_message(message):
|
||||
if 'tools' in message and message['role'] == 'system':
|
||||
for tool in message['tools']:
|
||||
parameters = tool['function']['parameters']['properties']
|
||||
tool['function']['parameters']['properties'] = \
|
||||
{k: v for k, v in parameters.items() if
|
||||
v is not None}
|
||||
elif 'tools' in message:
|
||||
del message['tools']
|
||||
if "tools" in message and message["role"] == "system":
|
||||
for tool in message["tools"]:
|
||||
parameters = tool["function"]["parameters"]["properties"]
|
||||
tool["function"]["parameters"]["properties"] = {
|
||||
k: v for k, v in parameters.items() if v is not None
|
||||
}
|
||||
elif "tools" in message:
|
||||
del message["tools"]
|
||||
return message
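`process_message` above prunes tool definitions on system messages (dropping properties whose value is `None`) and strips the `tools` field from any other role. A short illustration with a made-up tool definition:

```python
msg = {
    "role": "system",
    "tools": [{
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "parameters": {"properties": {"city": {"type": "string"}, "unit": None}},
        }
    }],
}
cleaned = process_message(msg)
# The None-valued "unit" property is dropped; "city" is kept.
assert cleaned["tools"][0]["function"]["parameters"]["properties"] == {"city": {"type": "string"}}
```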
|
||||
|
||||
|
||||
|
@ -254,14 +242,16 @@ def process_batch(
|
|||
max_output_length: int,
|
||||
combine: bool,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_conv = batch["messages"]
|
||||
batched_input_ids = []
|
||||
batched_labels = []
|
||||
for conv in batched_conv:
|
||||
input_ids = [151331, 151333]
|
||||
loss_masks = [False, False]
|
||||
if combine:
|
||||
new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
|
||||
new_input_ids = tokenizer.apply_chat_template(
|
||||
conv, tokenize=True, return_dict=False
|
||||
)
|
||||
input_ids = new_input_ids
|
||||
loss_masks = [False] * len(input_ids)
|
||||
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
|
||||
|
@ -270,8 +260,14 @@ def process_batch(
|
|||
else:
|
||||
for message in conv:
|
||||
message = process_message(message)
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
loss_mask_val = (
|
||||
False
|
||||
if message["role"] in ("system", "user", "observation")
|
||||
else True
|
||||
)
|
||||
new_input_ids = tokenizer.apply_chat_template(
|
||||
[message], tokenize=True, return_dict=False
|
||||
)[2:]
|
||||
input_ids += new_input_ids
|
||||
loss_masks += [loss_mask_val] * len(new_input_ids)
|
||||
|
||||
|
@ -290,7 +286,7 @@ def process_batch(
|
|||
del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {'input_ids': batched_input_ids, 'labels': batched_labels}
|
||||
return {"input_ids": batched_input_ids, "labels": batched_labels}
|
||||
|
||||
|
||||
def process_batch_eval(
|
||||
|
@ -300,13 +296,15 @@ def process_batch_eval(
|
|||
max_output_length: int,
|
||||
combine: bool,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_conv = batch["messages"]
|
||||
batched_input_ids = []
|
||||
batched_output_ids = []
|
||||
|
||||
for conv in batched_conv:
|
||||
if combine:
|
||||
new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
|
||||
new_input_ids = tokenizer.apply_chat_template(
|
||||
conv, tokenize=True, return_dict=False
|
||||
)
|
||||
input_ids = new_input_ids
|
||||
last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
|
||||
output_prompt, output_ids = (
|
||||
|
@ -314,9 +312,7 @@ def process_batch_eval(
|
|||
input_ids[last_assistant_index:],
|
||||
)
|
||||
output_ids.append(151336)
|
||||
batched_input_ids.append(
|
||||
input_ids[:max_input_length] + output_prompt[:1]
|
||||
)
|
||||
batched_input_ids.append(input_ids[:max_input_length] + output_prompt[:1])
|
||||
batched_output_ids.append(output_ids[:max_output_length])
|
||||
else:
|
||||
input_ids = [151331, 151333]
|
||||
|
@ -325,8 +321,10 @@ def process_batch_eval(
|
|||
break
|
||||
else:
|
||||
message = process_message(message)
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
if message['role'] == 'assistant':
|
||||
new_input_ids = tokenizer.apply_chat_template(
|
||||
[message], tokenize=True, return_dict=False
|
||||
)[2:]
|
||||
if message["role"] == "assistant":
|
||||
output_prompt, output_ids = (
|
||||
new_input_ids[:1],
|
||||
new_input_ids[1:],
|
||||
|
@ -341,20 +339,22 @@ def process_batch_eval(
|
|||
del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
|
||||
return {"input_ids": batched_input_ids, "output_ids": batched_output_ids}
|
||||
|
||||
|
||||
def load_tokenizer_and_model(
|
||||
model_dir: str,
|
||||
peft_config: Optional[PeftConfig] = None,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left', trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_dir, padding_side="left", trust_remote_code=True
|
||||
)
|
||||
if peft_config is not None:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=True,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.bfloat16 # Must use BFloat 16
|
||||
torch_dtype=torch.bfloat16, # Must use BFloat 16
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
@ -363,47 +363,54 @@ def load_tokenizer_and_model(
|
|||
model_dir,
|
||||
trust_remote_code=True,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.bfloat16
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
|
||||
batched_pred_ids, batched_label_ids = eval_preds
|
||||
batched_pred_ids[batched_pred_ids==-100] = tokenizer.pad_token_id
|
||||
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
|
||||
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
||||
batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
|
||||
batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
|
||||
metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
||||
pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
|
||||
label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
|
||||
pred_txt = tokenizer.decode(pred_ids).strip()
|
||||
label_txt = tokenizer.decode(label_ids).strip()
|
||||
pred_tokens = list(jieba.cut(pred_txt))
|
||||
label_tokens = list(jieba.cut(label_txt))
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
|
||||
scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
|
||||
for k, v in scores[0].items():
|
||||
metrics_dct[k].append(round(v['f'] * 100, 4))
|
||||
metrics_dct['bleu-4'].append(
|
||||
sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
|
||||
metrics_dct[k].append(round(v["f"] * 100, 4))
|
||||
metrics_dct["bleu-4"].append(
|
||||
sentence_bleu(
|
||||
[label_tokens],
|
||||
pred_tokens,
|
||||
smoothing_function=SmoothingFunction().method3,
|
||||
)
|
||||
)
|
||||
return {k: np.mean(v) for k, v in metrics_dct.items()}
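`compute_metrics` above segments the decoded text with jieba and then scores word-level ROUGE and BLEU-4 with smoothing. A rough standalone sketch of the same scoring path on a toy sentence pair; the `Rouge` import is assumed to come from the `rouge-chinese` package these fine-tuning scripts typically use, and the scores are not asserted since they depend on the installed versions:

```python
import jieba
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge  # assumed package; adjust to whichever Rouge the script imports

pred_tokens = list(jieba.cut("今天天气很好"))
label_tokens = list(jieba.cut("今天天气不错"))

rouge_scores = Rouge().get_scores(" ".join(pred_tokens), " ".join(label_tokens))[0]
bleu4 = sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3)
print(rouge_scores["rouge-l"]["f"], bleu4)
```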
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Annotated[str, typer.Argument(help='')],
|
||||
data_dir: Annotated[str, typer.Argument(help="")],
|
||||
model_dir: Annotated[
|
||||
str,
|
||||
typer.Argument(
|
||||
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
|
||||
help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
|
||||
),
|
||||
],
|
||||
config_file: Annotated[str, typer.Argument(help='')],
|
||||
config_file: Annotated[str, typer.Argument(help="")],
|
||||
auto_resume_from_checkpoint: str = typer.Argument(
|
||||
default='',
|
||||
help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
|
||||
default="",
|
||||
help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training",
|
||||
),
|
||||
):
|
||||
ft_config = FinetuningConfig.from_file(config_file)
|
||||
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
|
||||
tokenizer, model = load_tokenizer_and_model(
|
||||
model_dir, peft_config=ft_config.peft_config
|
||||
)
|
||||
data_manager = DataManager(data_dir, ft_config.data_config)
|
||||
|
||||
train_dataset = data_manager.get_dataset(
|
||||
|
@ -417,7 +424,7 @@ def main(
|
|||
),
|
||||
batched=True,
|
||||
)
|
||||
print('train_dataset:', train_dataset)
|
||||
print("train_dataset:", train_dataset)
|
||||
val_dataset = data_manager.get_dataset(
|
||||
Split.VALIDATION,
|
||||
functools.partial(
|
||||
|
@ -430,7 +437,7 @@ def main(
|
|||
batched=True,
|
||||
)
|
||||
if val_dataset is not None:
|
||||
print('val_dataset:', val_dataset)
|
||||
print("val_dataset:", val_dataset)
|
||||
test_dataset = data_manager.get_dataset(
|
||||
Split.TEST,
|
||||
functools.partial(
|
||||
|
@ -443,25 +450,18 @@ def main(
|
|||
batched=True,
|
||||
)
|
||||
if test_dataset is not None:
|
||||
print('test_dataset:', test_dataset)
|
||||
print("test_dataset:", test_dataset)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
|
||||
ft_config.training_args.generation_config.pad_token_id = (
|
||||
151329
|
||||
)
|
||||
ft_config.training_args.generation_config.eos_token_id = [
|
||||
151329, 151336, 151338
|
||||
]
|
||||
ft_config.training_args.generation_config.pad_token_id = 151329
|
||||
ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=ft_config.training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
padding='longest',
|
||||
return_tensors='pt',
|
||||
padding="longest",
|
||||
return_tensors="pt",
|
||||
),
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
|
@ -483,7 +483,9 @@ def main(
|
|||
if checkpoint_sn > 0:
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
|
||||
checkpoint_directory = os.path.join(
|
||||
output_dir, "checkpoint-" + str(checkpoint_sn)
|
||||
)
|
||||
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
|
||||
trainer.train(resume_from_checkpoint=checkpoint_directory)
|
||||
else:
|
||||
|
@ -494,16 +496,22 @@ def main(
|
|||
checkpoint_sn = int(auto_resume_from_checkpoint)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
|
||||
checkpoint_directory = os.path.join(
|
||||
output_dir, "checkpoint-" + str(checkpoint_sn)
|
||||
)
|
||||
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
|
||||
trainer.train(resume_from_checkpoint=checkpoint_directory)
|
||||
else:
|
||||
print(auto_resume_from_checkpoint,
|
||||
"The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
|
||||
print(
|
||||
auto_resume_from_checkpoint,
|
||||
"The specified checkpoint sn("
|
||||
+ auto_resume_from_checkpoint
|
||||
+ ") has not been saved. Please search for the correct checkpoint in the model output directory",
|
||||
)
|
||||
|
||||
if test_dataset is not None:
|
||||
trainer.predict(test_dataset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
|
|
@ -29,51 +29,43 @@ from datasets import load_dataset, DatasetDict, NamedSplit
|
|||
from typing import Optional
|
||||
from PIL import Image
|
||||
|
||||
# For Ascend NPU, please add this
|
||||
# import torch_npu
|
||||
# from torch_npu.contrib import transfer_to_npu
|
||||
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
img = Image.new('L', (224, 224), 0).convert('RGB')
|
||||
img = Image.new("L", (224, 224), 0).convert("RGB")
|
||||
|
||||
|
||||
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
|
||||
output_ids = (
|
||||
[feature["output_ids"] for feature in features]
|
||||
if "output_ids" in features[0].keys()
|
||||
else None
|
||||
)
|
||||
if output_ids is not None:
|
||||
max_output_length = max(len(out) for out in output_ids)
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_output_length = (
|
||||
(
|
||||
max_output_length + self.pad_to_multiple_of - 1) //
|
||||
self.pad_to_multiple_of * self.pad_to_multiple_of
|
||||
(max_output_length + self.pad_to_multiple_of - 1)
|
||||
// self.pad_to_multiple_of
|
||||
* self.pad_to_multiple_of
|
||||
)
|
||||
for feature in features:
|
||||
remainder = [self.tokenizer.pad_token_id] * (
|
||||
max_output_length - len(feature['output_ids'])
|
||||
max_output_length - len(feature["output_ids"])
|
||||
)
|
||||
if isinstance(feature['output_ids'], list):
|
||||
feature['output_ids'] = feature['output_ids'] + remainder
|
||||
if isinstance(feature["output_ids"], list):
|
||||
feature["output_ids"] = feature["output_ids"] + remainder
|
||||
else:
|
||||
feature['output_ids'] = np.concatenate(
|
||||
[feature['output_ids'], remainder]
|
||||
feature["output_ids"] = np.concatenate(
|
||||
[feature["output_ids"], remainder]
|
||||
).astype(np.int64)
|
||||
return super().__call__(features, return_tensors)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(_Seq2SeqTrainer):
|
||||
# Apex is not supported. transformers>=4.46 requires the additional argument num_items_in_batch.
|
||||
def training_step(self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch=None) -> torch.Tensor:
|
||||
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean()
|
||||
self.accelerator.backward(loss)
|
||||
detached_loss = loss.detach() / self.args.gradient_accumulation_steps
|
||||
del inputs
|
||||
torch.cuda.empty_cache()
|
||||
return detached_loss
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
|
@ -82,20 +74,19 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
|
|||
ignore_keys=None,
|
||||
**gen_kwargs,
|
||||
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.predict_with_generate:
|
||||
output_ids = inputs.pop('output_ids', None)
|
||||
output_ids = inputs.pop("output_ids", None)
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
prediction_loss_only=prediction_loss_only,
|
||||
ignore_keys=ignore_keys,
|
||||
**gen_kwargs
|
||||
**gen_kwargs,
|
||||
)
|
||||
|
||||
if generated_tokens is not None:
|
||||
generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1]:]
|
||||
generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1] :]
|
||||
|
||||
if self.args.predict_with_generate:
|
||||
labels = output_ids
|
||||
|
@ -139,14 +130,14 @@ class FinetuningConfig(object):
|
|||
freezeV: bool
|
||||
|
||||
training_args: Seq2SeqTrainingArguments = dc.field(
|
||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
|
||||
)
|
||||
peft_config: Optional[PeftConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.training_args.do_eval or self.data_config.val_file is None:
|
||||
self.training_args.do_eval = False
|
||||
self.training_args.evaluation_strategy = 'no'
|
||||
self.training_args.evaluation_strategy = "no"
|
||||
self.data_config.val_file = None
|
||||
else:
|
||||
self.training_args.per_device_eval_batch_size = (
|
||||
|
@ -155,31 +146,29 @@ class FinetuningConfig(object):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
|
||||
training_args = kwargs.get('training_args', None)
|
||||
def from_dict(cls, **kwargs) -> "FinetuningConfig":
|
||||
training_args = kwargs.get("training_args", None)
|
||||
if training_args is not None and not isinstance(
|
||||
training_args, Seq2SeqTrainingArguments
|
||||
):
|
||||
gen_config = training_args.get('generation_config')
|
||||
gen_config = training_args.get("generation_config")
|
||||
if not isinstance(gen_config, GenerationConfig):
|
||||
training_args['generation_config'] = GenerationConfig(
|
||||
**gen_config
|
||||
)
|
||||
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
|
||||
training_args["generation_config"] = GenerationConfig(**gen_config)
|
||||
kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)
|
||||
|
||||
data_config = kwargs.get('data_config')
|
||||
data_config = kwargs.get("data_config")
|
||||
if not isinstance(data_config, DataConfig):
|
||||
kwargs['data_config'] = DataConfig(**data_config)
|
||||
kwargs["data_config"] = DataConfig(**data_config)
|
||||
|
||||
peft_config = kwargs.get('peft_config', None)
|
||||
peft_config = kwargs.get("peft_config", None)
|
||||
if peft_config is not None and not isinstance(peft_config, PeftConfig):
|
||||
kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
|
||||
kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
|
||||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
|
||||
def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
|
||||
path = Path(path)
|
||||
parser = yaml.YAML(typ='safe', pure=True)
|
||||
parser = yaml.YAML(typ="safe", pure=True)
|
||||
parser.indent(mapping=2, offset=2, sequence=4)
|
||||
parser.default_flow_style = False
|
||||
kwargs = parser.load(path)
|
||||
|
@ -192,7 +181,7 @@ def _load_datasets(
|
|||
data_files: dict[NamedSplit, str],
|
||||
num_proc: Optional[int],
|
||||
) -> DatasetDict:
|
||||
if data_format == '.jsonl':
|
||||
if data_format == ".jsonl":
|
||||
dataset_dct = load_dataset(
|
||||
data_dir,
|
||||
data_files=data_files,
|
||||
|
@ -251,7 +240,7 @@ def process_batch(
|
|||
max_output_length: int,
|
||||
combine: bool,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_conv = batch["messages"]
|
||||
batched_input_ids = []
|
||||
batched_attention_mask = []
|
||||
batched_position_ids = []
|
||||
|
@ -267,25 +256,25 @@ def process_batch(
|
|||
loss_masks = [False, False]
|
||||
images = []
|
||||
|
||||
if conv[0].get('image'):
|
||||
conv[0]['image'] = Image.open(conv[0]['image']).convert('RGB')
|
||||
if conv[0].get("image"):
|
||||
conv[0]["image"] = Image.open(conv[0]["image"]).convert("RGB")
|
||||
else:
|
||||
conv[0]['image'] = img
|
||||
conv[0]["image"] = img
|
||||
|
||||
for message in conv:
|
||||
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
[message],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
padding=True
|
||||
loss_mask_val = (
|
||||
False if message["role"] in ("system", "user", "observation") else True
|
||||
)
|
||||
new_input_ids = new_input_ids_all['input_ids'][0][2:]
|
||||
new_attention_mask = new_input_ids_all['attention_mask'][0][2:]
|
||||
new_position_ids = list(range(position_ids[-1] + 1, position_ids[-1] + 1 + len(new_input_ids)))
|
||||
if message.get('image'): # Only One Image
|
||||
images.append(new_input_ids_all['images'])
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
[message], tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
new_input_ids = new_input_ids_all["input_ids"][0][2:]
|
||||
new_attention_mask = new_input_ids_all["attention_mask"][0][2:]
|
||||
new_position_ids = list(
|
||||
range(position_ids[-1] + 1, position_ids[-1] + 1 + len(new_input_ids))
|
||||
)
|
||||
if message.get("image"): # Only One Image
|
||||
images.append(new_input_ids_all["images"])
|
||||
|
||||
new_loss_masks = [loss_mask_val] * len(new_input_ids)
|
||||
input_ids += new_input_ids
|
||||
|
@ -311,15 +300,28 @@ def process_batch(
|
|||
batched_labels.append(labels[:max_length])
|
||||
batched_images.append(images[0][0])
|
||||
|
||||
del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
|
||||
del (
|
||||
batched_conv,
|
||||
conv,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
loss_masks,
|
||||
message,
|
||||
new_input_ids,
|
||||
new_loss_masks,
|
||||
labels,
|
||||
input_id,
|
||||
mask,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {
|
||||
'input_ids': batched_input_ids,
|
||||
'attention_mask': batched_attention_mask,
|
||||
'position_ids': batched_position_ids,
|
||||
'labels': batched_labels,
|
||||
'images': batched_images
|
||||
"input_ids": batched_input_ids,
|
||||
"attention_mask": batched_attention_mask,
|
||||
"position_ids": batched_position_ids,
|
||||
"labels": batched_labels,
|
||||
"images": batched_images,
|
||||
}
|
||||
|
||||
|
||||
|
@ -330,7 +332,7 @@ def process_batch_eval(
|
|||
max_output_length: int,
|
||||
combine: bool,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_conv = batch["messages"]
|
||||
batched_input_ids = []
|
||||
batched_attention_mask = []
|
||||
batched_position_ids = []
|
||||
|
@ -338,22 +340,18 @@ def process_batch_eval(
|
|||
batched_images = []
|
||||
|
||||
for conv in batched_conv:
|
||||
|
||||
if conv[0].get('image'):
|
||||
image = Image.open(conv[0]['image']).convert('RGB')
|
||||
if conv[0].get("image"):
|
||||
image = Image.open(conv[0]["image"]).convert("RGB")
|
||||
else:
|
||||
image = img
|
||||
|
||||
conv[0]['image'] = image
|
||||
conv[0]["image"] = image
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
conv,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
padding=True
|
||||
conv, tokenize=True, return_dict=True, padding=True
|
||||
)
|
||||
|
||||
input_ids = new_input_ids_all['input_ids'][0]
|
||||
attention_mask = new_input_ids_all['attention_mask'][0]
|
||||
input_ids = new_input_ids_all["input_ids"][0]
|
||||
attention_mask = new_input_ids_all["attention_mask"][0]
|
||||
position_ids = list(range(len(input_ids)))
|
||||
|
||||
dialogue_parts = [0]
|
||||
|
@ -366,27 +364,36 @@ def process_batch_eval(
|
|||
|
||||
# Split the conversation into multiple dialogue segments
|
||||
for end_idx in range(1, len(dialogue_parts)):
|
||||
input_segment = input_ids[:dialogue_parts[end_idx]]
|
||||
attention_segment = attention_mask[:dialogue_parts[end_idx]]
|
||||
position_segment = position_ids[:dialogue_parts[end_idx]]
|
||||
output_segment = input_ids[dialogue_parts[end_idx - 1]:dialogue_parts[end_idx]]
|
||||
input_segment = input_ids[: dialogue_parts[end_idx]]
|
||||
attention_segment = attention_mask[: dialogue_parts[end_idx]]
|
||||
position_segment = position_ids[: dialogue_parts[end_idx]]
|
||||
output_segment = input_ids[
|
||||
dialogue_parts[end_idx - 1] : dialogue_parts[end_idx]
|
||||
]
|
||||
output_segment.append(151336) # Add EOS token
|
||||
|
||||
batched_input_ids.append(input_segment[:max_input_length])
|
||||
batched_attention_mask.append(attention_segment[:max_input_length])
|
||||
batched_position_ids.append(position_segment[:max_input_length])
|
||||
batched_output_ids.append(output_segment[:max_output_length])
|
||||
batched_images.append(new_input_ids_all['images'][0])
|
||||
batched_images.append(new_input_ids_all["images"][0])
|
||||
|
||||
del batched_conv, input_ids, attention_mask, position_ids, new_input_ids_all, output_segment
|
||||
del (
|
||||
batched_conv,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
new_input_ids_all,
|
||||
output_segment,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {
|
||||
'input_ids': batched_input_ids,
|
||||
'attention_mask': batched_attention_mask,
|
||||
'position_ids': batched_position_ids,
|
||||
'output_ids': batched_output_ids,
|
||||
'images': batched_images
|
||||
"input_ids": batched_input_ids,
|
||||
"attention_mask": batched_attention_mask,
|
||||
"position_ids": batched_position_ids,
|
||||
"output_ids": batched_output_ids,
|
||||
"images": batched_images,
|
||||
}
|
||||
|
||||
|
||||
|
@ -394,13 +401,15 @@ def load_tokenizer_and_model(
|
|||
model_dir: str,
|
||||
peft_config: Optional[PeftConfig] = None,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left', trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_dir, padding_side="left", trust_remote_code=True
|
||||
)
|
||||
if peft_config is not None:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=True,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.bfloat16 # Must use BFloat 16
|
||||
torch_dtype=torch.bfloat16, # Must use BFloat 16
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
@ -409,47 +418,54 @@ def load_tokenizer_and_model(
|
|||
model_dir,
|
||||
trust_remote_code=True,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.bfloat16
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
|
||||
batched_pred_ids, batched_label_ids = eval_preds
|
||||
batched_pred_ids[batched_pred_ids==-100] = tokenizer.pad_token_id
|
||||
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
|
||||
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
||||
batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
|
||||
batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
|
||||
metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
||||
pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
|
||||
label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
|
||||
pred_txt = tokenizer.decode(pred_ids).strip()
|
||||
label_txt = tokenizer.decode(label_ids).strip()
|
||||
pred_tokens = list(jieba.cut(pred_txt))
|
||||
label_tokens = list(jieba.cut(label_txt))
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
|
||||
scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
|
||||
for k, v in scores[0].items():
|
||||
metrics_dct[k].append(round(v['f'] * 100, 4))
|
||||
metrics_dct['bleu-4'].append(
|
||||
sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
|
||||
metrics_dct[k].append(round(v["f"] * 100, 4))
|
||||
metrics_dct["bleu-4"].append(
|
||||
sentence_bleu(
|
||||
[label_tokens],
|
||||
pred_tokens,
|
||||
smoothing_function=SmoothingFunction().method3,
|
||||
)
|
||||
)
|
||||
return {k: np.mean(v) for k, v in metrics_dct.items()}
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Annotated[str, typer.Argument(help='')],
|
||||
data_dir: Annotated[str, typer.Argument(help="")],
|
||||
model_dir: Annotated[
|
||||
str,
|
||||
typer.Argument(
|
||||
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
|
||||
help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
|
||||
),
|
||||
],
|
||||
config_file: Annotated[str, typer.Argument(help='')],
|
||||
config_file: Annotated[str, typer.Argument(help="")],
|
||||
auto_resume_from_checkpoint: str = typer.Argument(
|
||||
default='',
|
||||
help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
|
||||
default="",
|
||||
help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training",
|
||||
),
|
||||
):
|
||||
ft_config = FinetuningConfig.from_file(config_file)
|
||||
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
|
||||
tokenizer, model = load_tokenizer_and_model(
|
||||
model_dir, peft_config=ft_config.peft_config
|
||||
)
|
||||
|
||||
if ft_config.freezeV:
|
||||
for param in model.transformer.vision.parameters():
|
||||
|
@ -467,7 +483,7 @@ def main(
|
|||
),
|
||||
batched=True,
|
||||
)
|
||||
print('train_dataset:', train_dataset)
|
||||
print("train_dataset:", train_dataset)
|
||||
|
||||
val_dataset = data_manager.get_dataset(
|
||||
Split.VALIDATION,
|
||||
|
@ -477,13 +493,12 @@ def main(
|
|||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
|
||||
),
|
||||
batched=True,
|
||||
)
|
||||
|
||||
if val_dataset is not None:
|
||||
print('val_dataset:', val_dataset)
|
||||
print("val_dataset:", val_dataset)
|
||||
test_dataset = data_manager.get_dataset(
|
||||
Split.TEST,
|
||||
functools.partial(
|
||||
|
@ -496,25 +511,18 @@ def main(
|
|||
batched=True,
|
||||
)
|
||||
if test_dataset is not None:
|
||||
print('test_dataset:', test_dataset)
|
||||
print("test_dataset:", test_dataset)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
|
||||
ft_config.training_args.generation_config.pad_token_id = (
|
||||
151329
|
||||
)
|
||||
ft_config.training_args.generation_config.eos_token_id = [
|
||||
151329, 151336, 151338
|
||||
]
|
||||
ft_config.training_args.generation_config.pad_token_id = 151329
|
||||
ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=ft_config.training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
padding='longest',
|
||||
return_tensors='pt',
|
||||
padding="longest",
|
||||
return_tensors="pt",
|
||||
),
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
|
@ -536,7 +544,9 @@ def main(
|
|||
if checkpoint_sn > 0:
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
|
||||
checkpoint_directory = os.path.join(
|
||||
output_dir, "checkpoint-" + str(checkpoint_sn)
|
||||
)
|
||||
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
|
||||
trainer.train(resume_from_checkpoint=checkpoint_directory)
|
||||
else:
|
||||
|
@ -547,16 +557,22 @@ def main(
|
|||
checkpoint_sn = int(auto_resume_from_checkpoint)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
|
||||
checkpoint_directory = os.path.join(
|
||||
output_dir, "checkpoint-" + str(checkpoint_sn)
|
||||
)
|
||||
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
|
||||
trainer.train(resume_from_checkpoint=checkpoint_directory)
|
||||
else:
|
||||
print(auto_resume_from_checkpoint,
|
||||
"The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
|
||||
print(
|
||||
auto_resume_from_checkpoint,
|
||||
"The specified checkpoint sn("
|
||||
+ auto_resume_from_checkpoint
|
||||
+ ") has not been saved. Please search for the correct checkpoint in the model output directory",
|
||||
)
|
||||
|
||||
if test_dataset is not None:
|
||||
trainer.predict(test_dataset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
|
|
@ -6,7 +6,6 @@ from transformers import (
|
|||
AutoModel,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
|
@ -16,47 +15,49 @@ def load_model_and_tokenizer(
|
|||
model_dir: Union[str, Path], trust_remote_code: bool = True
|
||||
):
|
||||
model_dir = Path(model_dir).expanduser().resolve()
|
||||
if (model_dir / 'adapter_config.json').exists():
|
||||
if (model_dir / "adapter_config.json").exists():
|
||||
import json
|
||||
with open(model_dir / 'adapter_config.json', 'r', encoding='utf-8') as file:
|
||||
|
||||
with open(model_dir / "adapter_config.json", "r", encoding="utf-8") as file:
|
||||
config = json.load(file)
|
||||
model = AutoModel.from_pretrained(
|
||||
config.get('base_model_name_or_path'),
|
||||
config.get("base_model_name_or_path"),
|
||||
trust_remote_code=trust_remote_code,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.bfloat16
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
model = PeftModelForCausalLM.from_pretrained(
|
||||
model=model,
|
||||
model_id=model_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
|
||||
tokenizer_dir = model.peft_config["default"].base_model_name_or_path
|
||||
else:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.bfloat16
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
tokenizer_dir = model_dir
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
encode_special_tokens=True,
|
||||
use_fast=False
|
||||
use_fast=False,
|
||||
)
|
||||
return model, tokenizer
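`load_model_and_tokenizer` above transparently handles both a plain model directory and a LoRA output directory (detected through `adapter_config.json`). A minimal usage sketch with a placeholder path:

```python
# "./output/checkpoint-3000" is a placeholder path, not a path from this repo.
model, tokenizer = load_model_and_tokenizer("./output/checkpoint-3000")
print(type(model).__name__)  # PeftModelForCausalLM for an adapter directory, otherwise the base model class
```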
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
model_dir: Annotated[str, typer.Argument(help='')],
|
||||
model_dir: Annotated[str, typer.Argument(help="")],
|
||||
):
|
||||
# For GLM-4 Finetune Without Tools
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": "#裙子#夏天",
|
||||
"role": "user",
|
||||
"content": "#裙子#夏天",
|
||||
}
|
||||
]
|
||||
|
||||
|
@ -119,7 +120,7 @@ def main(
|
|||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
generate_kwargs = {
|
||||
"max_new_tokens": 1024,
|
||||
|
@ -130,10 +131,12 @@ def main(
|
|||
"eos_token_id": model.config.eos_token_id,
|
||||
}
|
||||
outputs = model.generate(**inputs, **generate_kwargs)
|
||||
response = tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True).strip()
|
||||
response = tokenizer.decode(
|
||||
outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
).strip()
|
||||
print("=========")
|
||||
print(response)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
app()