fix for ds and zero3 error
This commit is contained in:
parent
b475ebe8ae
commit
081f0fb96c
@@ -22,15 +22,15 @@ Read this in [English](README_en.md)
 | p-tuning v2 (PEFT) | 21381MiB | 121M |
 | SFT (Zero3 method) | 80935MiB<br/>(per GPU; 8 GPUs required) | 20G |

-Before starting fine-tuning, please first install the dependencies in `basic_demo`; you also need to install the dependencies in this directory:
-
-> NOTE: Some code in NLTK 3.8.1 may not yet be adapted to Python 3.12;
-> for a workaround in that case, see [issues #38](https://github.com/THUDM/GLM-4/issues/38)
+Before starting fine-tuning, please first install the dependencies in `basic_demo` and make sure you have cloned the latest version of the model repository; you also need to install the dependencies in this directory:

 ```bash
 pip install -r requirements.txt
 ```

+> NOTE: Some code in NLTK 3.8.1 may not yet be adapted to Python 3.12;
+> for a workaround in that case, see [issues #38](https://github.com/THUDM/GLM-4/issues/38)
+>

 ## Multi-round dialogue format

 The multi-round dialogue fine-tuning example follows the GLM-4 dialogue format convention: a different `loss_mask` is applied to each role, so the `loss` for all assistant replies in a multi-turn conversation is computed in a single pass.
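To make the convention concrete, here is a minimal, hypothetical sketch of per-role loss masking (the helper names and toy tokenizer are illustrative, not the repo's actual code): tokens from `system`, `user`, and `observation` turns are masked out, so only `assistant` tokens contribute to the loss.

```python
# Illustrative sketch of per-role loss masking; `toy_tokenize` is a stand-in
# for the real tokenizer's apply_chat_template, not an actual GLM-4 helper.
def toy_tokenize(text: str) -> list[int]:
    return [ord(c) % 1000 for c in text]  # fake token ids, one per character

def build_inputs(conversation: list[dict]) -> tuple[list[int], list[bool]]:
    input_ids: list[int] = []
    loss_masks: list[bool] = []
    for message in conversation:
        # Only assistant replies are trained on; all other roles are masked.
        mask = message["role"] not in ("system", "user", "observation")
        token_ids = toy_tokenize(message["content"])
        input_ids += token_ids
        loss_masks += [mask] * len(token_ids)
    return input_ids, loss_masks

ids, masks = build_inputs([
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
])
assert sum(masks) == len("hello")  # only the assistant tokens are unmasked
```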
@@ -24,16 +24,16 @@ Test hardware information:
 | p-tuning v2 (PEFT) | 21381MiB | 121M |
 | SFT (Zero3 method) | 80935MiB<br/>(per GPU; 8 GPUs required) | 20G |

-Before starting fine-tuning, please install the dependencies in `basic_demo` first. You also need to install the
+Before starting fine-tuning, please install the dependencies in `basic_demo` and clone the latest model repos (Hugging Face) first. You also need to install the
 dependencies in this directory:

-> NOTE: Some codes in NLTK 3.8.1 might not yet be compatible with Python 3.12. For adaptation methods in such cases,
-> please refer to [issues #38](https://github.com/THUDM/GLM-4/issues/38).
-
 ```bash
 pip install -r requirements.txt
 ```

+> NOTE: Some code in NLTK 3.8.1 might not yet be compatible with Python 3.12. For a workaround in such cases,
+> please refer to [issues #38](https://github.com/THUDM/GLM-4/issues/38).

 ## Multi-round dialogue format

 The multi-round dialogue fine-tuning example uses the GLM-4 dialogue format convention, adding different `loss_mask` to
@@ -56,16 +56,22 @@ class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):


 class Seq2SeqTrainer(_Seq2SeqTrainer):
     # apex is not supported
     def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:

         model.train()
         inputs = self._prepare_inputs(inputs)
-        loss = self.compute_loss(model, inputs)
-        if self.args.gradient_accumulation_steps > 1:
-            loss = loss / self.args.gradient_accumulation_steps
-        loss.backward()
+
+        with self.compute_loss_context_manager():
+            loss = self.compute_loss(model, inputs)
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # average over data-parallel replicas
+        self.accelerator.backward(loss)  # dispatch backward through accelerate/DeepSpeed
+        detached_loss = loss.detach() / self.args.gradient_accumulation_steps
+        del inputs
+        torch.cuda.empty_cache()
-        return loss.detach()
+        return detached_loss

     def prediction_step(
         self,
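The `training_step` rewrite above is the core of the commit: under DeepSpeed ZeRO-3, parameters and gradients are sharded across ranks, so a raw `loss.backward()` bypasses the engine that `accelerate` wrapped around the model, while `self.accelerator.backward(loss)` routes the backward pass (including loss scaling and gradient-accumulation handling) through that engine. A minimal, self-contained sketch of the pattern, using a toy model rather than the repo's trainer:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # picks up the DeepSpeed/ZeRO-3 config it was launched with
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
data = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=4)
model, optimizer, data = accelerator.prepare(model, optimizer, data)

for x, y in data:
    loss = nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)  # not loss.backward(): lets DeepSpeed handle sharded grads
    optimizer.step()
    optimizer.zero_grad()
```

Note that the explicit division by `gradient_accumulation_steps` is likewise dropped from the backward path and kept only for the detached loss returned for logging.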
@@ -75,6 +81,7 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
         ignore_keys=None,
         **gen_kwargs,
     ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        with torch.no_grad():  # Ensure no gradient computation
         if self.args.predict_with_generate:
             output_ids = inputs.pop('output_ids')
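The only change here is wrapping the prediction path in `torch.no_grad()`, so generation during evaluation does not record an autograd graph. A tiny sketch of the effect (toy module, illustrative only):

```python
import torch
from torch import nn

model = nn.Linear(8, 8)
x = torch.randn(2, 8)
with torch.no_grad():        # forward pass records no graph
    y = model(x)
assert not y.requires_grad   # activations carry no gradient history
```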
@@ -255,12 +262,7 @@ def process_batch(
         message = process_message(message)

         loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-
-        # New code using apply_chat_template with the Jinja template in tokenizer_config.json
-        # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
-
-        # Old code using apply_chat_template from tokenization_chatglm.py
-        new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
+        new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
         new_loss_masks = [loss_mask_val] * len(new_input_ids)
         input_ids += new_input_ids
         loss_masks += new_loss_masks
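The indexing change reflects a change in the return shape of `apply_chat_template`: the removed code path (the remote-code tokenizer in `tokenization_chatglm.py`) apparently returned a nested list, hence `[0][2:]`, while the Jinja template in `tokenizer_config.json` returns a flat list of token ids, hence `[2:]`; both slices drop a two-token prefix. A sketch with made-up ids (not real GLM-4 token ids):

```python
# Illustrative only: fake ids showing why [0][2:] became [2:].
nested = [[1, 2, 30, 40]]  # older return shape: a list containing one sequence
flat = [1, 2, 30, 40]      # newer return shape: the sequence itself

assert nested[0][2:] == flat[2:] == [30, 40]  # both drop the 2-token prefix
```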
@@ -299,12 +301,7 @@ def process_batch_eval(
                 break
             else:
                 message = process_message(message)
-
-                # New code using apply_chat_template with the Jinja template in tokenizer_config.json
-                # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
-
-                # Old code using apply_chat_template from tokenization_chatglm.py
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                 if message['role'] == 'assistant':
                     output_prompt, output_ids = (
                         new_input_ids[:1],