From 081f0fb96c4cfa6390797c2e214282e1533c9738 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Thu, 20 Jun 2024 01:24:00 +0800
Subject: [PATCH] fix for ds and zero3 error

---
 finetune_demo/README.md    |  8 ++++----
 finetune_demo/README_en.md |  8 ++++----
 finetune_demo/finetune.py  | 31 ++++++++++++++-----------------
 3 files changed, 22 insertions(+), 25 deletions(-)

diff --git a/finetune_demo/README.md b/finetune_demo/README.md
index 0fe5fb4..64ef3f6 100644
--- a/finetune_demo/README.md
+++ b/finetune_demo/README.md
@@ -22,15 +22,15 @@ Read this in [English](README_en.md)
 | p-tuning v2 (PEFT) | 21381MiB                          | 121M    |
 | SFT (Zero3 method) | 80935MiB<br/>(Each GPU,需要使用8张GPU) | 20G     |
 
-在开始微调之前,请你先安装`basic_demo`中的依赖,同时您需要安装本目录下的依赖项:
-
-> NOTE: NLTK 3.8.1 部分代码可能尚未对 python 3.12
-> 进行适配,该情况下的适配方法可参考[issues #38](https://github.com/THUDM/GLM-4/issues/38)
+在开始微调之前,请你先安装 `basic_demo` 中的依赖,并保证克隆了最新版本的模型仓库,同时您需要安装本目录下的依赖项:
 
 ```bash
 pip install -r requirements.txt
 ```
 
+> NOTE: NLTK 3.8.1 部分代码可能尚未对 python 3.12
+> 进行适配,该情况下的适配方法可参考[issues #38](https://github.com/THUDM/GLM-4/issues/38)
+> 
 ## 多轮对话格式
 
 多轮对话微调示例采用 GLM-4 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。
diff --git a/finetune_demo/README_en.md b/finetune_demo/README_en.md
index 206953f..8e4f82a 100644
--- a/finetune_demo/README_en.md
+++ b/finetune_demo/README_en.md
@@ -24,16 +24,16 @@ Test hardware information:
 | p-tuning v2 (PEFT)   | 21381MiB                                     | 121M                   |
 | SFT (Zero3 method)   | 80935MiB<br/>(Each GPU, 8 GPUs are required) | 20G                    |
 
-Before starting fine-tuning, please install the dependencies in `basic_demo` first. You also need to install the
+Before starting fine-tuning, please install the dependencies in `basic_demo` and clone the latest model repos (Hugging Face) first. You also need to install the
 dependencies in this directory:
 
-> NOTE: Some codes in NLTK 3.8.1 might not yet be compatible with Python 3.12. For adaptation methods in such cases,
-> please refer to [issues #38](https://github.com/THUDM/GLM-4/issues/38).
-
 ```bash
 pip install -r requirements.txt
 ```
 
+> NOTE: Some codes in NLTK 3.8.1 might not yet be compatible with Python 3.12. For adaptation methods in such cases,
+> please refer to [issues #38](https://github.com/THUDM/GLM-4/issues/38).
+
 ## Multi-round dialogue format
 
 The multi-round dialogue fine-tuning example uses the GLM-4 dialogue format convention, adding different `loss_mask` to
diff --git a/finetune_demo/finetune.py b/finetune_demo/finetune.py
index e65eed7..6e67bb3 100644
--- a/finetune_demo/finetune.py
+++ b/finetune_demo/finetune.py
@@ -56,16 +56,22 @@ class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
 
 
 class Seq2SeqTrainer(_Seq2SeqTrainer):
+    # Not Support for apex
     def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
+
         model.train()
         inputs = self._prepare_inputs(inputs)
-        loss = self.compute_loss(model, inputs)
-        if self.args.gradient_accumulation_steps > 1:
-            loss = loss / self.args.gradient_accumulation_steps
-        loss.backward()
+
+        with self.compute_loss_context_manager():
+            loss = self.compute_loss(model, inputs)
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()
+        self.accelerator.backward(loss)
+        detached_loss = loss.detach() / self.args.gradient_accumulation_steps
         del inputs
         torch.cuda.empty_cache()
-        return loss.detach()
+        return detached_loss
 
     def prediction_step(
             self,
@@ -75,6 +81,7 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
             ignore_keys=None,
             **gen_kwargs,
     ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+
         with torch.no_grad():  # Ensure no gradient computation
             if self.args.predict_with_generate:
                 output_ids = inputs.pop('output_ids')
@@ -255,12 +262,7 @@ def process_batch(
             message = process_message(message)
 
             loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-
-            # New Code With Using apply_chat_template in jinjia template in tokenizer_config.json
-            # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
-
-            # Old Code With Using apply_chat_template in tokenization_chatglm.py
-            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
+            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
             new_loss_masks = [loss_mask_val] * len(new_input_ids)
             input_ids += new_input_ids
             loss_masks += new_loss_masks
@@ -299,12 +301,7 @@ def process_batch_eval(
                 break
             else:
                 message = process_message(message)
-
-                # New Code With Using apply_chat_template in jinjia template in tokenizer_config.json
-                # new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)
-
-                # Old Code With Using apply_chat_template in tokenization_chatglm.py
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                 if message['role'] == 'assistant':
                     output_prompt, output_ids = (
                         new_input_ids[:1],