update finetune demo

parent bab384d193
commit b475ebe8ae
@@ -19,6 +19,7 @@ from sse_starlette.sse import EventSourceResponse
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
+import os
 
 MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
 MAX_MODEL_LENGTH = 8192
 
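Reading MODEL_PATH from the environment makes the server configurable without editing the source. A quick sketch of the override semantics (the local path below is a made-up example):

    import os

    # An exported MODEL_PATH wins; otherwise the Hugging Face repo id is used.
    os.environ['MODEL_PATH'] = '/data/models/glm-4-9b-chat'  # hypothetical local path
    print(os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat'))  # /data/models/glm-4-9b-chat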
@@ -444,7 +445,7 @@ async def predict_stream(model_id, gen_params):
     function_name = None
     response_id = generate_id('chatcmpl-', 29)
     system_fingerprint = generate_id('fp_', 9)
-    tools = {tool['function']['name'] for tool in gen_params['tools']}
+    tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else None
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
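The guard matters because a chat request without tools leaves gen_params['tools'] empty, and iterating None raises TypeError. A minimal sketch with a stand-in payload:

    # gen_params stands in for a real request payload with no tools attached.
    gen_params = {'tools': None}
    tools = ({tool['function']['name'] for tool in gen_params['tools']}
             if gen_params['tools'] else None)
    print(tools)  # None, where the unguarded comprehension would raise TypeError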
@@ -1,17 +1,16 @@
 # -*- coding: utf-8 -*-
 import json
 import os
 import jieba
 import dataclasses as dc
 import functools
 from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import Annotated, Any, Optional, Union
+from typing import Annotated, Any, Union
 import numpy as np
 import ruamel.yaml as yaml
 import torch
 import typer
-from datasets import Dataset, NamedSplit, Split
+from datasets import Dataset, Split
 from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 from peft import PeftConfig, get_peft_config, get_peft_model
 from rouge_chinese import Rouge
@@ -26,6 +25,8 @@ from transformers import (
 )
 from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
 from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
+from datasets import load_dataset, DatasetDict, NamedSplit
+from typing import Optional
 
 app = typer.Typer(pretty_exceptions_show_locals=False)
 
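Note that these two additions are the new module-level home for the same imports a later hunk (@@ -153,10 +171,6) deletes from the middle of the file; the commit consolidates them rather than adding new dependencies.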
@@ -55,6 +56,17 @@ class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
 
 
 class Seq2SeqTrainer(_Seq2SeqTrainer):
+    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
+        model.train()
+        inputs = self._prepare_inputs(inputs)
+        loss = self.compute_loss(model, inputs)
+        if self.args.gradient_accumulation_steps > 1:
+            loss = loss / self.args.gradient_accumulation_steps
+        loss.backward()
+        del inputs
+        torch.cuda.empty_cache()
+        return loss.detach()
+
     def prediction_step(
         self,
         model: nn.Module,
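The new training_step pairs del inputs with torch.cuda.empty_cache() on every step. A standalone sketch (illustrative, not from the commit) of what empty_cache does to the caching allocator:

    import torch

    # del frees the tensor, but PyTorch's caching allocator keeps the block
    # reserved; empty_cache() returns the cached block to the driver.
    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device='cuda')
        del x
        print(torch.cuda.memory_reserved())  # block still cached
        torch.cuda.empty_cache()
        print(torch.cuda.memory_reserved())  # cache released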
@@ -63,14 +75,21 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
         ignore_keys=None,
         **gen_kwargs,
     ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        if self.args.predict_with_generate:
-            output_ids = inputs.pop('output_ids')
-        input_ids = inputs['input_ids']
-        loss, generated_tokens, labels = super().prediction_step(
-            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
-        )
-        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
-        labels = output_ids
+        with torch.no_grad():  # Ensure no gradient computation
+            if self.args.predict_with_generate:
+                output_ids = inputs.pop('output_ids')
+            input_ids = inputs['input_ids']
+
+            loss, generated_tokens, labels = super().prediction_step(
+                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
+            )
+
+            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
+            labels = output_ids
+
+            del inputs, input_ids, output_ids
+            torch.cuda.empty_cache()
+
+            return loss, generated_tokens, labels
 
 
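Wrapping the whole body in torch.no_grad() means evaluation records no autograd graph, so no activations are held for a backward pass. A self-contained sketch of that behaviour (illustrative only):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    x = torch.randn(8, 4)
    with torch.no_grad():   # forward pass records no autograd graph
        y = model(x)
    print(y.requires_grad)  # False: nothing is retained for backward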
@@ -127,7 +146,6 @@ class FinetuningConfig(object):
                 training_args, Seq2SeqTrainingArguments
         ):
             gen_config = training_args.get('generation_config')
-            # TODO: a bit hacky
             if not isinstance(gen_config, GenerationConfig):
                 training_args['generation_config'] = GenerationConfig(
                     **gen_config
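The coercion this block performs is simple to demonstrate in isolation (the dict values below are assumed for illustration):

    from transformers import GenerationConfig

    training_args = {'generation_config': {'max_new_tokens': 512}}  # assumed example value
    gen_config = training_args.get('generation_config')
    if not isinstance(gen_config, GenerationConfig):
        training_args['generation_config'] = GenerationConfig(**gen_config)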
@@ -153,10 +171,6 @@ class FinetuningConfig(object):
         return cls.from_dict(**kwargs)
 
 
-from datasets import load_dataset, DatasetDict, NamedSplit
-from typing import Optional
-
-
 def _load_datasets(
         data_dir: str,
         data_format: str,
@@ -239,7 +253,13 @@ def process_batch(
         loss_masks = [False, False]
         for message in conv:
             message = process_message(message)
+
             loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
+
+            # New code: use apply_chat_template via the Jinja template in tokenizer_config.json
+            # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
+
+            # Old code: use apply_chat_template from tokenization_chatglm.py
             new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
             new_loss_masks = [loss_mask_val] * len(new_input_ids)
             input_ids += new_input_ids
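The masking rule is the heart of the batch builder: only assistant turns contribute to the loss. A minimal sketch with a placeholder message and token ids:

    # Placeholder ids; real ones come from tokenizer.apply_chat_template.
    message = {'role': 'user', 'content': 'hello'}
    loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
    new_input_ids = [1001, 1002, 1003]
    new_loss_masks = [loss_mask_val] * len(new_input_ids)
    print(new_loss_masks)  # [False, False, False]: user tokens are masked out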
@@ -255,6 +275,10 @@ def process_batch(
         max_length = max_input_length + max_output_length + 1
         batched_input_ids.append(input_ids[:max_length])
         batched_labels.append(labels[:max_length])
+
+    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
+    torch.cuda.empty_cache()
+
     return {'input_ids': batched_input_ids, 'labels': batched_labels}
 
@@ -269,13 +293,17 @@ def process_batch_eval(
     batched_output_ids = []
 
     for conv in batched_conv:
-
         input_ids = [151331, 151333]
         for message in conv:
             if len(input_ids) >= max_input_length:
                 break
             else:
                 message = process_message(message)
+
+                # New code: use apply_chat_template via the Jinja template in tokenizer_config.json
+                # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
+
+                # Old code: use apply_chat_template from tokenization_chatglm.py
                 new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
                 if message['role'] == 'assistant':
                     output_prompt, output_ids = (
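The [0][2:] slice trims the two special prefix tokens the template prepends, because input_ids is already seeded with them; 151331 and 151333 appear to be the [gMASK] and <sop> ids in the GLM-4 vocabulary. A sketch with placeholder ids:

    # apply_chat_template output starts with the same two special tokens the
    # eval builder seeded, so each message's encoding is trimmed before appending.
    templated = [[151331, 151333, 1001, 1002]]  # placeholder template output
    input_ids = [151331, 151333]
    input_ids += templated[0][2:]
    print(input_ids)  # [151331, 151333, 1001, 1002], no duplicated prefix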
@@ -288,6 +316,10 @@ def process_batch_eval(
                     )
                     batched_output_ids.append(output_ids[:max_output_length])
                 input_ids += new_input_ids
+
+    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
+    torch.cuda.empty_cache()
+
     return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
 
@@ -348,7 +380,6 @@ def main(
             default='',
             help='If "yes", automatically resume from the latest saved checkpoint; if a checkpoint number such as 12 or 15, resume from that checkpoint; if "no", restart training.'
         ),
-
 ):
     ft_config = FinetuningConfig.from_file(config_file)
     tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
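A simplified sketch of the resume policy the option describes (function name and structure assumed; the real logic lives further down in main):

    import os

    def resolve_resume(auto_resume_from_checkpoint: str, output_dir: str):
        # '' or 'no' -> fresh run; 'yes' -> newest checkpoint; digits -> that checkpoint
        if not auto_resume_from_checkpoint or auto_resume_from_checkpoint.lower() == 'no':
            return None
        if auto_resume_from_checkpoint.lower() == 'yes':
            dirs = [d for d in os.listdir(output_dir) if d.startswith('checkpoint-')]
            if not dirs:
                return None  # nothing saved yet
            latest = max(dirs, key=lambda d: int(d.split('-')[1]))
            return os.path.join(output_dir, latest)
        return os.path.join(output_dir, 'checkpoint-' + auto_resume_from_checkpoint)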
@@ -422,7 +453,7 @@ def main(
                 model.gradient_checkpointing_enable()
                 model.enable_input_require_grads()
                 checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
-                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
+                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                 trainer.train(resume_from_checkpoint=checkpoint_directory)
             else:
                 trainer.train()
@@ -433,7 +464,7 @@ def main(
                 model.gradient_checkpointing_enable()
                 model.enable_input_require_grads()
                 checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
-                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
+                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                 trainer.train(resume_from_checkpoint=checkpoint_directory)
             else:
                 print(auto_resume_from_checkpoint,