update finetune demo

zR 2024-06-19 11:22:31 +08:00
parent bab384d193
commit b475ebe8ae
2 changed files with 53 additions and 21 deletions

View File

@@ -19,6 +19,7 @@ from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
import os
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
MAX_MODEL_LENGTH = 8192
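For reference, the server resolves its model path from the environment at import time, so a local checkpoint can be swapped in without editing the file. A minimal sketch of the pattern (the server filename in the comment is an assumption; the variable names come from the hunk above):

import os

# Resolve configuration from the environment, falling back to the hosted default.
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
MAX_MODEL_LENGTH = 8192

# Example override (filename assumed): MODEL_PATH=/data/glm-4-9b-chat python openai_api_server.py
print(f"serving {MODEL_PATH} with a {MAX_MODEL_LENGTH}-token context window")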
@@ -444,7 +445,7 @@ async def predict_stream(model_id, gen_params):
    function_name = None
    response_id = generate_id('chatcmpl-', 29)
    system_fingerprint = generate_id('fp_', 9)
    tools = {tool['function']['name'] for tool in gen_params['tools']}
    tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else None
    async for new_response in generate_stream_glm4(gen_params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(output):]
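The one-line change above guards against requests that carry no tools: iterating over None raises TypeError, while the new expression falls back to None so later `if tools:` checks short-circuit. A standalone sketch of the before/after behavior (the gen_params payload here is illustrative):

gen_params = {'tools': None}  # illustrative request with no tools attached

# Old form: a set comprehension over None raises TypeError.
# tools = {tool['function']['name'] for tool in gen_params['tools']}

# New form: fall back to None when no tools are supplied.
tools = (
    {tool['function']['name'] for tool in gen_params['tools']}
    if gen_params['tools'] else None
)
assert tools is None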

View File

@@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-
import json
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, NamedSplit, Split
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
@@ -26,6 +25,8 @@ from transformers import (
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional
app = typer.Typer(pretty_exceptions_show_locals=False)
@@ -55,6 +56,17 @@ class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
class Seq2SeqTrainer(_Seq2SeqTrainer):
    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)
        loss = self.compute_loss(model, inputs)
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
        loss.backward()
        del inputs
        torch.cuda.empty_cache()
        return loss.detach()
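The new training_step scales the loss by gradient_accumulation_steps so the summed backward passes average, rather than sum, the micro-batch gradients, then frees the batch and empties the CUDA cache. A self-contained sketch of the scaling logic with a toy model (not the Trainer API):

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for micro_step in range(accumulation_steps):
    inputs = torch.randn(8, 4)
    loss = model(inputs).pow(2).mean()
    # Scale so the accumulated gradients equal the average micro-batch gradient.
    (loss / accumulation_steps).backward()

optimizer.step()       # one optimizer update per accumulated batch
optimizer.zero_grad()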
    def prediction_step(
            self,
            model: nn.Module,
@@ -63,14 +75,21 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.args.predict_with_generate:
            output_ids = inputs.pop('output_ids')
        input_ids = inputs['input_ids']
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
        )
        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
        labels = output_ids
        with torch.no_grad():  # Ensure no gradient computation
            if self.args.predict_with_generate:
                output_ids = inputs.pop('output_ids')
            input_ids = inputs['input_ids']
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )
            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
            labels = output_ids
            del inputs, input_ids, output_ids
            torch.cuda.empty_cache()
        return loss, generated_tokens, labels
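Wrapping the whole prediction step in torch.no_grad() keeps autograd from building a graph while the model generates, which is where eval-time memory spikes come from. A minimal demonstration of the effect on toy tensors:

import torch

x = torch.randn(2, 3, requires_grad=True)

y = (x * 2).sum()
print(y.requires_grad)  # True: the op was recorded and its graph is retained

with torch.no_grad():
    z = (x * 2).sum()
print(z.requires_grad)  # False: no graph is built, so nothing extra stays alive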
@@ -127,7 +146,6 @@ class FinetuningConfig(object):
            training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            # TODO: a bit hacky
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
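The branch above promotes a plain dict parsed from the YAML config into a transformers GenerationConfig. A small sketch of the same coercion (the keys shown are illustrative):

from transformers import GenerationConfig

gen_config = {'max_new_tokens': 512, 'do_sample': True, 'top_p': 0.8}

# Promote a plain mapping from the config file to a GenerationConfig object.
if not isinstance(gen_config, GenerationConfig):
    gen_config = GenerationConfig(**gen_config)

print(gen_config.max_new_tokens)  # 512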
@@ -153,10 +171,6 @@ class FinetuningConfig(object):
        return cls.from_dict(**kwargs)

from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional

def _load_datasets(
        data_dir: str,
        data_format: str,
@@ -239,7 +253,13 @@ def process_batch(
        loss_masks = [False, False]
        for message in conv:
            message = process_message(message)
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            # New code: apply_chat_template via the Jinja template in tokenizer_config.json
            # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
            # Old code: apply_chat_template from tokenization_chatglm.py
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            input_ids += new_input_ids
@@ -255,6 +275,10 @@ def process_batch(
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
    torch.cuda.empty_cache()
    return {'input_ids': batched_input_ids, 'labels': batched_labels}
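The [0][2:] slice drops the two special prefix tokens that every apply_chat_template call prepends, since the batch is already seeded with ids 151331 and 151333 ([gMASK] and <sop> in the GLM-4 vocabulary, matching the literals above). A hedged sketch of the per-message accumulation; the [0] indexing follows the repo's custom tokenizer, which returns a batch, and loading the tokenizer may require trust_remote_code:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('THUDM/glm-4-9b-chat', trust_remote_code=True)

conv = [
    {'role': 'user', 'content': 'What is 2 + 2?'},
    {'role': 'assistant', 'content': '4'},
]

input_ids, loss_masks = [151331, 151333], [False, False]  # pre-seeded prefix tokens
for message in conv:
    # Only assistant turns contribute to the loss.
    loss_mask_val = message['role'] not in ('system', 'user', 'observation')
    # Drop the per-call prefix; it is already at the head of input_ids.
    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
    input_ids += new_input_ids
    loss_masks += [loss_mask_val] * len(new_input_ids)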
@@ -269,13 +293,17 @@ def process_batch_eval(
    batched_output_ids = []
    for conv in batched_conv:
        input_ids = [151331, 151333]
        for message in conv:
            if len(input_ids) >= max_input_length:
                break
            else:
                message = process_message(message)
                # New code: apply_chat_template via the Jinja template in tokenizer_config.json
                # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
                # Old code: apply_chat_template from tokenization_chatglm.py
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
                if message['role'] == 'assistant':
                    output_prompt, output_ids = (
@@ -288,6 +316,10 @@ def process_batch_eval(
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids
    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
    torch.cuda.empty_cache()
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
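In the eval path, each assistant turn becomes a generation target: everything accumulated so far is the prompt and the assistant tokens become output_ids, with the loop breaking once max_input_length is reached. A toy sketch of that pairing (token ids and lengths are illustrative, not real vocabulary):

max_input_length, max_output_length = 8, 4

conv_token_ids = [  # (role, already-tokenized message), illustrative ids
    ('user', [11, 12, 13]),
    ('assistant', [21, 22]),
    ('user', [31, 32, 33, 34]),
]

batched_input_ids, batched_output_ids = [], []
input_ids = [151331, 151333]  # same pre-seeded prefix as training
for role, new_input_ids in conv_token_ids:
    if len(input_ids) >= max_input_length:
        break
    if role == 'assistant':
        # Prompt is everything seen so far; target is the assistant reply.
        batched_input_ids.append(input_ids[:max_input_length])
        batched_output_ids.append(new_input_ids[:max_output_length])
    input_ids += new_input_ids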
@@ -348,7 +380,6 @@ def main(
            default='',
            help='If entered as "yes", automatically resume from the latest saved checkpoint. If a number such as 12 or 15 is entered, resume from the checkpoint with that number. If entered as "no", restart training from scratch.'
        ),
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
@@ -422,7 +453,7 @@ def main(
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
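The "yes" branch above first has to resolve the newest checkpoint-N directory under output_dir. A minimal sketch of that resolution, assuming the checkpoint-<step> naming used in the diff:

import os
from typing import Optional

def latest_checkpoint(output_dir: str) -> Optional[str]:
    # Return the highest-numbered checkpoint-N directory, or None if none exist.
    checkpoint_sn = 0
    for name in os.listdir(output_dir):
        suffix = name.replace("checkpoint-", "")
        if name.startswith("checkpoint-") and suffix.isdigit():
            checkpoint_sn = max(checkpoint_sn, int(suffix))
    if checkpoint_sn == 0:
        return None
    return os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))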
@@ -433,7 +464,7 @@ def main(
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(auto_resume_from_checkpoint,