Preliminary fine-tuning code for GLM-4V-9B
This commit is contained in: parent aa6c942c19 · commit b9d6d3863f
**`.gitignore`**

@@ -1,3 +1,4 @@
*venv
*.DS_Store
*.idea/
dataset
**`README.md`**

@@ -11,9 +11,8 @@ Read this in [English](README_en.md)

## Project Updates

- 🔥 **News**: ``2024/6/28``: We have worked with the Intel technical team to improve the ITREX and OpenVINO deployment
  tutorials for GLM-4-9B-Chat. You can use Intel CPU/GPU devices to efficiently deploy the GLM-4-9B open source model.
  Welcome to [view](intel_device_demo).
- 🔥 **News**: ``2024/7/1``: We have updated the multimodal fine-tuning of GLM-4V-9B. You need to update the run files and configuration files from our model repository to support this feature. For more fine-tuning details (such as dataset format and GPU memory requirements), please go to [view](finetune_demo).
- 🔥 **News**: ``2024/6/28``: We have worked with the Intel technical team to improve the ITREX and OpenVINO deployment tutorials for GLM-4-9B-Chat. You can use Intel CPU/GPU devices to efficiently deploy the GLM-4-9B open source model. Welcome to [view](intel_device_demo).
- 🔥 **News**: ``2024/6/24``: We have updated the run files and configuration files in the model repository to support Flash Attention 2. Please update the model configuration file and refer to the sample code in `basic_demo/trans_cli_demo.py`.
- 🔥 **News**: ``2024/6/19``: We have updated the run files and configuration files in the model repository and fixed some known model inference issues. Everyone is welcome to clone the latest model repository.
**`README_en.md`**

@@ -8,10 +8,11 @@

</p>

## Update

- 🔥 **News**: ``2024/6/28``: We have updated the running files and configuration files of the model repository to support Flash Attention 2,
- 🔥 **News**: ``2024/6/24``: We have updated the running files and configuration files of the model repository to support Flash Attention 2.
  Please update the model configuration file and refer to the sample code in `basic_demo/trans_cli_demo.py`.
- 🔥🔥 **News**: ``2024/6/19``: We updated the running files and configuration files of the model repository and fixed some model inference issues. Welcome to clone the latest model repository.

- 🔥 **News**: ``2024/7/1``: We have updated the multimodal fine-tuning of GLM-4V-9B. You need to update the run file and configuration file of our model repository to support this feature. For more fine-tuning details (such as dataset format and GPU memory requirements), please go to [view](finetune_demo).
- 🔥 **News**: ``2024/6/28``: We have worked with the Intel technical team to improve the ITREX and OpenVINO deployment tutorials for GLM-4-9B-Chat. You can use Intel CPU/GPU devices to efficiently deploy the GLM-4-9B open source model. Welcome to [view](intel_device_demo).
- 🔥 **News**: ``2024/6/24``: We have updated the running files and configuration files of the model repository to support Flash Attention 2. Please update the model configuration file and refer to the sample code in `basic_demo/trans_cli_demo.py`.
- 🔥 **News**: ``2024/6/19``: We updated the running files and configuration files of the model repository and fixed some model inference issues. Welcome to clone the latest model repository.
- 🔥 **News**: ``2024/6/18``: We released a [technical report](https://arxiv.org/pdf/2406.12793), welcome to check it out.
- 🔥 **News**: ``2024/6/05``: We released the GLM-4-9B series of open source models.
**`finetune_demo/README.md`**

@@ -6,7 +6,8 @@ Read this in [English](README_en.md)

## Hardware Check

**All data in this document were measured in the hardware environment below. Actual requirements and the GPU memory used at runtime may differ slightly; please take your actual environment as the reference.**
**All data in this document were measured in the hardware environment below. Actual requirements and the GPU memory used at runtime may differ slightly; please take your actual environment as the reference. The resource usage of fine-tuning follows the settings in the configuration files under the configs folder.**

Test hardware information:

+ OS: Ubuntu 22.04
@@ -16,11 +17,15 @@ Read this in [English](README_en.md)

+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8

| Fine-tuning solution | GPU memory usage                        | Weight save point size |
|----------------------|-----------------------------------------|-------------------------|
| lora (PEFT)          | 21531MiB                                | 17M                     |
| p-tuning v2 (PEFT)   | 21381MiB                                | 121M                    |
| SFT (Zero3 method)   | 80935MiB<br/>(per GPU, 8 GPUs required) | 20G                     |

| Fine-tuning model | Fine-tuning solution                | GPU memory usage               | Weight save point size |
|-------------------|-------------------------------------|--------------------------------|-------------------------|
| GLM-4-9B-Chat     | lora (PEFT)                         | 22G                            | 17M                     |
| GLM-4-9B-Chat     | p-tuning v2 (PEFT)                  | 21G                            | 121M                    |
| GLM-4-9B-Chat     | SFT (Zero3 method)                  | 80G (per GPU, 8 GPUs required) | 20G                     |
| GLM-4V-9B         | lora (PEFT), includes vision module | 75G                            | 37M                     |
| GLM-4V-9B         | SFT                                 | Not supported by this code     | 28G                     |

**GLM-4V-9B fine-tuning cannot work properly with deepspeed; the official fine-tuning script only provides the most basic fine-tuning scheme, and further optimizations are left for developers to explore on their own.**

Before you start fine-tuning, please install the dependencies in `basic_demo`, make sure you have cloned the latest version of the model repository, and install the dependencies in this directory:

@@ -30,7 +35,8 @@ pip install -r requirements.txt

> NOTE: Some code in NLTK 3.8.1 may not yet be adapted to Python 3.12.
> In that case, see [issues #38](https://github.com/THUDM/GLM-4/issues/38) for a workaround.
>
>

## Multi-turn dialogue format

The multi-turn dialogue fine-tuning example follows the GLM-4 conversation format convention: a different `loss_mask` is applied to each role, so that the `loss` of all assistant replies in a conversation is computed in a single pass.
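As a rough illustration of this convention, here is a minimal sketch (with hypothetical helper names and made-up token ids, not the repo's actual training code) of how a per-token `loss_mask` is turned into `labels`, replacing every non-assistant token with `-100` so it is ignored by the loss:

```python
# Minimal sketch of the loss_mask convention, assuming a flat list of token ids per conversation.
def build_labels(input_ids: list[int], loss_masks: list[bool]) -> list[int]:
    # Tokens coming from 'system' / 'user' / 'observation' turns are masked with -100,
    # so only the assistant tokens contribute to the cross-entropy loss.
    return [tok if keep else -100 for tok, keep in zip(input_ids, loss_masks)]

input_ids = [151331, 151333, 1001, 1002, 2001, 2002]   # prompt tokens + reply tokens (made-up ids)
loss_masks = [False, False, False, False, True, True]  # True only for the assistant reply
print(build_labels(input_ids, loss_masks))             # [-100, -100, -100, -100, 2001, 2002]
```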
|
@@ -97,14 +103,98 @@ pip install -r requirements.txt

Here is an example without tools:

```
{"messages": [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}, {"role": "assistant", "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"}]}
```

```json
{
  "messages": [
    {
      "role": "user",
      "content": "类型#裤*材质#牛仔布*风格#性感"
    },
    {
      "role": "assistant",
      "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"
    }
  ]
}
```

Here is an example with a tool call:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "",
      "tools": [
        {
          "type": "function",
          "function": {
            "name": "get_recommended_books",
            "description": "Get recommended books based on user's interests",
            "parameters": {
              "type": "object",
              "properties": {
                "interests": {
                  "type": "array",
                  "items": {
                    "type": "string"
                  },
                  "description": "The interests to recommend books for"
                }
              },
              "required": [
                "interests"
              ]
            }
          }
        }
      ]
    },
    {
      "role": "user",
      "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."
    },
    {
      "role": "assistant",
      "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"
    },
    {
      "role": "observation",
      "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"
    },
    {
      "role": "assistant",
      "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."
    }
  ]
}
```

```
{"messages": [{"role": "system", "content": "", "tools": [{"type": "function", "function": {"name": "get_recommended_books", "description": "Get recommended books based on user's interests", "parameters": {"type": "object", "properties": {"interests": {"type": "array", "items": {"type": "string"}, "description": "The interests to recommend books for"}}, "required": ["interests"]}}}]}, {"role": "user", "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."}, {"role": "assistant", "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"}, {"role": "observation", "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"}, {"role": "assistant", "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."}]}
```

Here is an example for visual VQA fine-tuning:

```json
{
  "messages": [
    {
      "role": "user",
      "content": "图片中的动物是什么?",
      "image": "/root/images/0001.jpg"
    },
    {
      "role": "assistant",
      "content": "图片中有一只猫。"
    },
    {
      "role": "user",
      "content": "图片中的猫在做什么?"
    },
    {
      "role": "assistant",
      "content": "这只猫坐在或站在桌子上,桌上有很多食物。"
    }
  ]
}
```

- The `system` role is optional, but if a `system` role exists, it must appear before the `user`
|
@@ -112,6 +202,8 @@ pip install -r requirements.txt

- The `tools` field is optional; if a `tools` field exists, it must appear after the `system` role, and the `tools` field may appear only once in a complete conversation (whether single-turn or multi-turn). When the `tools` field is present, the `system` role must exist and its `content` field must be empty.
- `GLM-4V-9B` does not support the `tools` field or the `system` field, and `image` must be placed in the first message. The `image` field must contain the `absolute path` of the image.
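To make the expected on-disk format concrete, here is a small sketch of a hypothetical helper (not part of this repo) that appends one VQA training sample per line to a `.jsonl` file and checks the rules above (image only in the first message, given as an absolute path); the output path is only a placeholder:

```python
import json
from pathlib import Path

def write_vqa_record(fp, messages: list[dict]) -> None:
    """Append one fine-tuning sample as a single JSON line (jsonl)."""
    # Rules from above: the image goes only in the first message and must be an absolute path.
    assert "image" in messages[0], "the first message must carry the image"
    assert Path(messages[0]["image"]).is_absolute(), "the image path must be absolute"
    assert all("image" not in m for m in messages[1:]), "only the first message may contain an image"
    fp.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")

with open("data/CogVLM-311K/train.jsonl", "a", encoding="utf-8") as fp:  # placeholder path
    write_vqa_record(fp, [
        {"role": "user", "content": "图片中的动物是什么?", "image": "/root/images/0001.jpg"},
        {"role": "assistant", "content": "图片中有一只猫。"},
    ])
```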
## Configuration file

@@ -158,16 +250,18 @@ pip install -r requirements.txt

## Start fine-tuning

Run a **single-node multi-GPU / multi-node multi-GPU** job with the commands below; this uses `deepspeed` as the acceleration solution, so you need to install `deepspeed`.
Run a **single-node multi-GPU / multi-node multi-GPU** job with the commands below; this uses `deepspeed` as the acceleration solution, so you need to install `deepspeed`. Then run:

```shell
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b configs/lora.yaml
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml # For Chat Fine-tune
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune
```

Run a **single-node single-GPU** job with the commands below.

```shell
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml # For Chat Fine-tune
python finetune.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune
```

## Fine-tune from a saved point
**`finetune_demo/README_en.md`**

@@ -6,8 +6,10 @@ not supported). Please strictly follow the steps in the document to avoid unnece

## Hardware check

**The data in this document are tested in the following hardware environment. The actual operating environment
requirements and the GPU memory occupied by the operation are slightly different. Please refer to the actual operating
environment.**
requirements and the GPU memory occupied by the operation are slightly different. Please refer to the actual operating
environment. The fine-tuning resource usage is set according to the configuration files in the
configs folder.**

Test hardware information:

+ OS: Ubuntu 22.04

@@ -18,14 +20,19 @@ Test hardware information:

+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8

| Fine-tuning solution | GPU memory usage                             | Weight save point size |
|----------------------|----------------------------------------------|-------------------------|
| lora (PEFT)          | 21531MiB                                     | 17M                     |
| p-tuning v2 (PEFT)   | 21381MiB                                     | 121M                    |
| SFT (Zero3 method)   | 80935MiB<br/>(Each GPU, 8 GPUs are required) | 20G                     |

| Fine-tuning Model | Fine-tuning solution                | GPU memory usage              | Weight save point size |
|-------------------|-------------------------------------|-------------------------------|-------------------------|
| GLM-4-9B-Chat     | lora (PEFT)                         | 22G                           | 17M                     |
| GLM-4-9B-Chat     | p-tuning v2 (PEFT)                  | 21G                           | 121M                    |
| GLM-4-9B-Chat     | SFT (Zero3 method)                  | 80G (each GPU, 8 GPUs needed) | 20G                     |
| GLM-4V-9B         | lora (PEFT), includes EVA2CLIPModel | 75G                           | 37M                     |
| GLM-4V-9B         | SFT                                 | Not supported in this code    | 28G                     |

Before starting fine-tuning, please install the dependencies in `basic_demo` and clone the latest model repos (Hugging Face) first. You also need to install the
dependencies in this directory:

**GLM-4V-9B fine-tuning cannot work properly with deepspeed; the official fine-tuning script only provides the most basic
fine-tuning solution, and more optimizations require developers to explore on their own.**

Before starting fine-tuning, please install the dependencies in `basic_demo` and clone the latest model repos (Hugging
Face) first. You also need to install the dependencies in this directory:

```bash
pip install -r requirements.txt
```
@@ -99,21 +106,107 @@ For data files, the sample uses the following format:

This is a sample without tools:

```
{"messages": [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}, {"role": "assistant", "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"}]}
```

```json
{
  "messages": [
    {
      "role": "user",
      "content": "类型#裤*材质#牛仔布*风格#性感"
    },
    {
      "role": "assistant",
      "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质,其柔然的手感和细腻的质地,在穿着舒适的同时,透露着清纯甜美的个性气质。除此之外,流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致,不失为一款随性出街的必备单品。"
    }
  ]
}
```

This is a sample with tools:

```
{"messages": [{"role": "system", "content": "", "tools": [{"type": "function", "function": {"name": "get_recommended_books", "description": "Get recommended books based on user's interests", "parameters": {"type": "object", "properties": {"interests": {"type": "array", "items": {"type": "string"}, "description": "The interests to recommend books for"}}, "required": ["interests"]}}}]}, {"role": "user", "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."}, {"role": "assistant", "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"}, {"role": "observation", "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"}, {"role": "assistant", "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."}]}
```

```json
{
  "messages": [
    {
      "role": "system",
      "content": "",
      "tools": [
        {
          "type": "function",
          "function": {
            "name": "get_recommended_books",
            "description": "Get recommended books based on user's interests",
            "parameters": {
              "type": "object",
              "properties": {
                "interests": {
                  "type": "array",
                  "items": {
                    "type": "string"
                  },
                  "description": "The interests to recommend books for"
                }
              },
              "required": [
                "interests"
              ]
            }
          }
        }
      ]
    },
    {
      "role": "user",
      "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."
    },
    {
      "role": "assistant",
      "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"
    },
    {
      "role": "observation",
      "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"
    },
    {
      "role": "assistant",
      "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."
    }
  ]
}
```

- The `system` role is optional, but if it exists, it must appear before the `user` role, and a complete conversation
  data (whether single-round or multi-round conversation) can only have one `system` role.
- The `tools` field is optional. If it exists, it must appear after the `system` role, and a complete conversation
  data (whether single-round or multi-round conversation) can only have one `tools` field. When the `tools` field
  exists, the `system` role must exist and the `content` field is empty.
This is a sample of a VQA task:

```json
{
  "messages": [
    {
      "role": "user",
      "content": "图片中的动物是什么?",
      "image": "/root/images/0001.jpg"
    },
    {
      "role": "assistant",
      "content": "图片中有一只猫。"
    },
    {
      "role": "user",
      "content": "图片中的猫在做什么?"
    },
    {
      "role": "assistant",
      "content": "这只猫坐在或站在桌子上,桌上有很多食物。"
    }
  ]
}
```

- The `system` role is optional, but if it exists, it must appear before the `user` role, and the `system` role can only
  appear once in a complete conversation (whether it is a single round or a multi-round conversation).
- The `tools` field is optional, but if it exists, it must appear after the `system` role, and the `tools` field can
  only appear once in a complete conversation (whether it is a single round or a multi-round conversation). When
  the `tools` field exists, the `system` role must exist and the `content` field is empty.
- `GLM-4V-9B` does not support the `tools` field and the `system` field, and `image` must be placed in the first
  message. The `image` field needs to contain the `absolute path` of the image.
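As a quick sanity check before launching training, a short script along these lines (a hypothetical snippet, not shipped with the repo) can verify that every line of a fine-tuning `.jsonl` file parses as JSON and respects the role rules above; the file name is only an example:

```python
import json

VALID_ROLES = {"system", "user", "assistant", "observation"}

def check_jsonl(path: str) -> None:
    """Validate a fine-tuning .jsonl file: one JSON object per line, known roles, at most one system role."""
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            sample = json.loads(line)  # raises if the line is not valid JSON
            roles = [m["role"] for m in sample["messages"]]
            assert set(roles) <= VALID_ROLES, f"line {lineno}: unknown role"
            assert roles.count("system") <= 1, f"line {lineno}: more than one system role"

check_jsonl("data/AdvertiseGen/train.jsonl")  # example file name
```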
## Configuration file

@@ -123,9 +216,8 @@ The fine-tuning configuration file is located in the `config` directory, includi

2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: Configuration files for different modes of models, including model parameters, optimizer
   parameters, training parameters, etc. Some important parameters are explained as follows: + data_config section
   parameters, training parameters, etc. Some important parameters are explained as follows:

   + data_config section
     + train_file: File path of training dataset.
     + val_file: File path of validation dataset.
     + test_file: File path of test dataset.

@@ -156,8 +248,7 @@ The fine-tuning configuration file is located in the `config` directory, includi

  + r: rank of LoRA.
  + lora_alpha: scaling factor of LoRA.
  + lora_dropout: dropout probability to use in LoRA layer.
+ P-TuningV2 parameters:
  + num_virtual_tokens: the number of virtual tokens.
+ P-TuningV2 parameters: + num_virtual_tokens: the number of virtual tokens.
  + num_attention_heads: 2: the number of attention heads of P-TuningV2 (do not change).
  + token_dim: 256: the token dimension of P-TuningV2 (do not change).

@@ -167,13 +258,15 @@ Execute **single machine multi-card/multi-machine multi-card** run through the f

the acceleration solution, and you need to install `deepspeed`.

```shell
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b configs/lora.yaml
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml # For Chat Fine-tune
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_vision.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune
```

Execute a **single machine single card** run through the following code.

```shell
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml # For Chat Fine-tune
python finetune.py data/CogVLM-311K/ THUDM/glm-4v-9b configs/lora.yaml # For VQA Fine-tune
```

## Fine-tune from a saved point
**`finetune_demo/configs/lora.yaml`**

@@ -35,10 +35,11 @@ training_args:
  generation_config:
    max_new_tokens: 512
  # set your absolute deepspeed path here
  #deepspeed: ds_zero_2.json
  # deepspeed: configs/ds_zero_3.json
peft_config:
  peft_type: LORA
  task_type: CAUSAL_LM
  r: 8
  lora_alpha: 32
  lora_dropout: 0.1
  target_modules: ["query_key_value"]
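For reference, the `peft_config` block above is consumed roughly as follows; this is a trimmed sketch based on the `FinetuningConfig` helpers in `finetune_vision.py`, with a placeholder config path:

```python
from pathlib import Path

import ruamel.yaml as yaml
from peft import get_peft_config

# Parse the fine-tuning YAML and turn its peft_config block into a PEFT config object,
# mirroring what FinetuningConfig.from_file / from_dict do in finetune_vision.py.
parser = yaml.YAML(typ='safe', pure=True)
config = parser.load(Path('finetune_demo/configs/lora.yaml'))  # placeholder path
peft_config = get_peft_config(config_dict=config['peft_config'])
print(type(peft_config).__name__, peft_config.r, peft_config.lora_alpha)  # e.g. LoraConfig 8 32
```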
**`finetune_demo/finetune.py`**

@@ -265,7 +265,7 @@ def process_batch(
        new_loss_masks = [loss_mask_val] * len(new_input_ids)
        input_ids += new_input_ids
        loss_masks += new_loss_masks
    input_ids.append(tokenizer.eos_token_id)
    input_ids.append(151336)  # EOS for chat
    loss_masks = [False, *loss_masks]
    labels = []
    for input_id, mask in zip(input_ids, loss_masks):
@@ -306,7 +306,7 @@ def process_batch_eval(
            new_input_ids[:1],
            new_input_ids[1:],
        )
        output_ids.append(tokenizer.eos_token_id)
        output_ids.append(151336)
        batched_input_ids.append(
            input_ids[:max_input_length] + output_prompt[:1]
        )
@@ -429,7 +429,7 @@ def main(
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))),
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )
**`finetune_demo/finetune_vision.py`** (new file)

@@ -0,0 +1,535 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import jieba
|
||||
import dataclasses as dc
|
||||
import functools
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Union
|
||||
import numpy as np
|
||||
import ruamel.yaml as yaml
|
||||
import torch
|
||||
import typer
|
||||
from datasets import Dataset, Split
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
from peft import PeftConfig, get_peft_config, get_peft_model
|
||||
from rouge_chinese import Rouge
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
EvalPrediction,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
Seq2SeqTrainingArguments,
|
||||
)
|
||||
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
|
||||
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
|
||||
from datasets import load_dataset, DatasetDict, NamedSplit
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
|
||||
|
||||
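# Pads every sample's 'output_ids' (used for generation-based evaluation) to the longest
# sequence in the batch, optionally rounded up to pad_to_multiple_of, before delegating the
# rest of the batching to the stock Seq2Seq collator.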
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
|
||||
if output_ids is not None:
|
||||
max_output_length = max(len(out) for out in output_ids)
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_output_length = (
|
||||
(
|
||||
max_output_length + self.pad_to_multiple_of - 1) //
|
||||
self.pad_to_multiple_of * self.pad_to_multiple_of
|
||||
)
|
||||
for feature in features:
|
||||
remainder = [self.tokenizer.pad_token_id] * (
|
||||
max_output_length - len(feature['output_ids'])
|
||||
)
|
||||
if isinstance(feature['output_ids'], list):
|
||||
feature['output_ids'] = feature['output_ids'] + remainder
|
||||
else:
|
||||
feature['output_ids'] = np.concatenate(
|
||||
[feature['output_ids'], remainder]
|
||||
).astype(np.int64)
|
||||
return super().__call__(features, return_tensors)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(_Seq2SeqTrainer):
|
||||
# Not Support for apex
|
||||
def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
|
||||
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean()
|
||||
self.accelerator.backward(loss)
|
||||
detached_loss = loss.detach() / self.args.gradient_accumulation_steps
|
||||
del inputs
|
||||
torch.cuda.empty_cache()
|
||||
return detached_loss
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict,
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys=None,
|
||||
**gen_kwargs,
|
||||
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.predict_with_generate:
|
||||
output_ids = inputs.pop('output_ids', None)
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
prediction_loss_only=prediction_loss_only,
|
||||
ignore_keys=ignore_keys,
|
||||
**gen_kwargs
|
||||
)
|
||||
|
||||
if generated_tokens is not None:
|
||||
generated_tokens = generated_tokens[:, inputs["input_ids"].size()[1]:]
|
||||
|
||||
if self.args.predict_with_generate:
|
||||
labels = output_ids
|
||||
|
||||
del inputs, output_ids
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return loss, generated_tokens, labels
|
||||
|
||||
|
||||
@dc.dataclass
|
||||
class DataConfig(object):
|
||||
train_file: Optional[str] = None
|
||||
val_file: Optional[str] = None
|
||||
test_file: Optional[str] = None
|
||||
num_proc: Optional[int] = None
|
||||
|
||||
@property
|
||||
def data_format(self) -> str:
|
||||
return Path(self.train_file).suffix
|
||||
|
||||
@property
|
||||
def data_files(self) -> dict[NamedSplit, str]:
|
||||
return {
|
||||
split: data_file
|
||||
for split, data_file in zip(
|
||||
[Split.TRAIN, Split.VALIDATION, Split.TEST],
|
||||
[self.train_file, self.val_file, self.test_file],
|
||||
)
|
||||
if data_file is not None
|
||||
}
|
||||
|
||||
|
||||
@dc.dataclass
|
||||
class FinetuningConfig(object):
|
||||
data_config: DataConfig
|
||||
|
||||
max_input_length: int
|
||||
max_output_length: int
|
||||
|
||||
training_args: Seq2SeqTrainingArguments = dc.field(
|
||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
||||
)
|
||||
peft_config: Optional[PeftConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.training_args.do_eval or self.data_config.val_file is None:
|
||||
self.training_args.do_eval = False
|
||||
self.training_args.evaluation_strategy = 'no'
|
||||
self.data_config.val_file = None
|
||||
else:
|
||||
self.training_args.per_device_eval_batch_size = (
|
||||
self.training_args.per_device_eval_batch_size
|
||||
or self.training_args.per_device_train_batch_size
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
|
||||
training_args = kwargs.get('training_args', None)
|
||||
if training_args is not None and not isinstance(
|
||||
training_args, Seq2SeqTrainingArguments
|
||||
):
|
||||
gen_config = training_args.get('generation_config')
|
||||
if not isinstance(gen_config, GenerationConfig):
|
||||
training_args['generation_config'] = GenerationConfig(
|
||||
**gen_config
|
||||
)
|
||||
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
|
||||
|
||||
data_config = kwargs.get('data_config')
|
||||
if not isinstance(data_config, DataConfig):
|
||||
kwargs['data_config'] = DataConfig(**data_config)
|
||||
|
||||
peft_config = kwargs.get('peft_config', None)
|
||||
if peft_config is not None and not isinstance(peft_config, PeftConfig):
|
||||
kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
|
||||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
|
||||
path = Path(path)
|
||||
parser = yaml.YAML(typ='safe', pure=True)
|
||||
parser.indent(mapping=2, offset=2, sequence=4)
|
||||
parser.default_flow_style = False
|
||||
kwargs = parser.load(path)
|
||||
return cls.from_dict(**kwargs)
|
||||
|
||||
|
||||
def _load_datasets(
|
||||
data_dir: str,
|
||||
data_format: str,
|
||||
data_files: dict[NamedSplit, str],
|
||||
num_proc: Optional[int],
|
||||
) -> DatasetDict:
|
||||
if data_format == '.jsonl':
|
||||
dataset_dct = load_dataset(
|
||||
data_dir,
|
||||
data_files=data_files,
|
||||
split=None,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
|
||||
return dataset_dct
|
||||
|
||||
|
||||
class DataManager(object):
|
||||
def __init__(self, data_dir: str, data_config: DataConfig):
|
||||
self._num_proc = data_config.num_proc
|
||||
|
||||
self._dataset_dct = _load_datasets(
|
||||
data_dir,
|
||||
data_config.data_format,
|
||||
data_config.data_files,
|
||||
self._num_proc,
|
||||
)
|
||||
|
||||
def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
|
||||
return self._dataset_dct.get(split, None)
|
||||
|
||||
def get_dataset(
|
||||
self,
|
||||
split: NamedSplit,
|
||||
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
|
||||
batched: bool = True,
|
||||
remove_orig_columns: bool = True,
|
||||
) -> Optional[Dataset]:
|
||||
orig_dataset = self._get_dataset(split)
|
||||
if orig_dataset is None:
|
||||
return
|
||||
|
||||
if remove_orig_columns:
|
||||
remove_columns = orig_dataset.column_names
|
||||
else:
|
||||
remove_columns = None
|
||||
return orig_dataset.map(
|
||||
process_fn,
|
||||
batched=batched,
|
||||
remove_columns=remove_columns,
|
||||
num_proc=self._num_proc,
|
||||
)
|
||||
|
||||
|
||||
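# Builds training features for a batch of conversations: each message is tokenized with the
# chat template, non-assistant tokens are masked to -100 in 'labels', the end-of-turn token
# id 151336 is appended, and the preprocessed image tensor returned by apply_chat_template
# is kept for the GLM-4V vision tower.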
def process_batch(
|
||||
batch: Mapping[str, Sequence],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_input_length: int,
|
||||
max_output_length: int,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_input_ids = []
|
||||
batched_attention_mask = []
|
||||
batched_position_ids = []
|
||||
batched_labels = []
|
||||
batched_images = []
|
||||
|
||||
max_length = max_input_length + max_output_length
|
||||
|
||||
for conv in batched_conv:
|
||||
input_ids = [151331, 151333]
|
||||
attention_mask = [1, 1]
|
||||
position_ids = list(range(len(input_ids)))
|
||||
loss_masks = [False, False]
|
||||
images = []
|
||||
|
||||
for message in conv:
|
||||
if message.get('image'):
|
||||
image = Image.open(message['image']).convert('RGB')
|
||||
message['image'] = image
|
||||
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
[message],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
padding=True,
|
||||
)
|
||||
new_input_ids = new_input_ids_all['input_ids'][0][2:]
|
||||
new_attention_mask = new_input_ids_all['attention_mask'][0][2:]
|
||||
new_position_ids = list(range(position_ids[-1] + 1, position_ids[-1] + 1 + len(new_input_ids)))
|
||||
if message.get('image'): # Only One Image
|
||||
images.append(new_input_ids_all['images'])
|
||||
|
||||
new_loss_masks = [loss_mask_val] * len(new_input_ids)
|
||||
input_ids += new_input_ids
|
||||
attention_mask += new_attention_mask
|
||||
position_ids += new_position_ids
|
||||
loss_masks += new_loss_masks
|
||||
|
||||
input_ids.append(151336) # EOS
|
||||
attention_mask.append(1)
|
||||
position_ids.append(len(position_ids))
|
||||
loss_masks.append(False)
|
||||
|
||||
labels = []
|
||||
for input_id, mask in zip(input_ids, loss_masks):
|
||||
if mask:
|
||||
labels.append(input_id)
|
||||
else:
|
||||
labels.append(-100)
|
||||
|
||||
batched_input_ids.append(input_ids[:max_length])
|
||||
batched_attention_mask.append(attention_mask[:max_length])
|
||||
batched_position_ids.append(position_ids[:max_length])
|
||||
batched_labels.append(labels[:max_length])
|
||||
if images is not None:
|
||||
batched_images.append(images[0][0])
|
||||
|
||||
del batched_conv, conv, input_ids, attention_mask, position_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {
|
||||
'input_ids': batched_input_ids,
|
||||
'attention_mask': batched_attention_mask,
|
||||
'position_ids': batched_position_ids,
|
||||
'labels': batched_labels,
|
||||
'images': batched_images
|
||||
}
|
||||
|
||||
|
||||
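# Builds evaluation features: the whole conversation is tokenized once, the token stream is
# split at every occurrence of token id 151337, and each resulting prefix is emitted as one
# sample whose final segment (plus a trailing 151336) becomes 'output_ids' for scoring.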
def process_batch_eval(
|
||||
batch: Mapping[str, Sequence],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_input_length: int,
|
||||
max_output_length: int,
|
||||
) -> dict[str, list]:
|
||||
batched_conv = batch['messages']
|
||||
batched_input_ids = []
|
||||
batched_attention_mask = []
|
||||
batched_position_ids = []
|
||||
batched_output_ids = []
|
||||
batched_images = []
|
||||
|
||||
for conv in batched_conv:
|
||||
|
||||
if conv[0].get('image'):
|
||||
image = Image.open(conv[0]['image']).convert('RGB')
|
||||
conv[0]['image'] = image
|
||||
|
||||
new_input_ids_all = tokenizer.apply_chat_template(
|
||||
conv,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
padding=True
|
||||
)
|
||||
|
||||
input_ids = new_input_ids_all['input_ids'][0]
|
||||
attention_mask = new_input_ids_all['attention_mask'][0]
|
||||
position_ids = list(range(len(input_ids)))
|
||||
|
||||
dialogue_parts = [0]
|
||||
for idx, token_id in enumerate(input_ids):
|
||||
if token_id == 151337:
|
||||
dialogue_parts.append(idx + 1)
|
||||
|
||||
if not dialogue_parts or dialogue_parts[-1] != len(input_ids):
|
||||
dialogue_parts.append(len(input_ids))
|
||||
|
||||
# Split the conversation into multiple dialogue segments
|
||||
for end_idx in range(1, len(dialogue_parts)):
|
||||
input_segment = input_ids[:dialogue_parts[end_idx]]
|
||||
attention_segment = attention_mask[:dialogue_parts[end_idx]]
|
||||
position_segment = position_ids[:dialogue_parts[end_idx]]
|
||||
output_segment = input_ids[dialogue_parts[end_idx - 1]:dialogue_parts[end_idx]]
|
||||
output_segment.append(151336) # Add EOS token
|
||||
|
||||
batched_input_ids.append(input_segment[:max_input_length])
|
||||
batched_attention_mask.append(attention_segment[:max_input_length])
|
||||
batched_position_ids.append(position_segment[:max_input_length])
|
||||
batched_output_ids.append(output_segment[:max_output_length])
|
||||
batched_images.append(new_input_ids_all['images'][0])
|
||||
|
||||
del batched_conv, input_ids, attention_mask, position_ids, new_input_ids_all, output_segment
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {
|
||||
'input_ids': batched_input_ids,
|
||||
'attention_mask': batched_attention_mask,
|
||||
'position_ids': batched_position_ids,
|
||||
'output_ids': batched_output_ids,
|
||||
'images': batched_images
|
||||
}
|
||||
|
||||
|
||||
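# Loads the tokenizer and the GLM-4V model in bfloat16; when a PEFT config is supplied, the
# base model is wrapped with get_peft_model so that only the adapter (e.g. LoRA) weights train.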
def load_tokenizer_and_model(
|
||||
model_dir: str,
|
||||
peft_config: Optional[PeftConfig] = None,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
if peft_config is not None:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=True,
|
||||
empty_init=False,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.bfloat16 # Must use BFloat 16
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=True,
|
||||
empty_init=False,
|
||||
use_cache=False,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
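# Decodes predictions and references, segments the Chinese text with jieba, and reports
# ROUGE-1/2/L F1 scores and BLEU-4 (with smoothing) averaged over the evaluation samples.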
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
|
||||
batched_pred_ids, batched_label_ids = eval_preds
|
||||
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
||||
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
||||
pred_txt = tokenizer.decode(pred_ids).strip()
|
||||
label_txt = tokenizer.decode(label_ids).strip()
|
||||
pred_tokens = list(jieba.cut(pred_txt))
|
||||
label_tokens = list(jieba.cut(label_txt))
|
||||
rouge = Rouge()
|
||||
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
|
||||
for k, v in scores[0].items():
|
||||
metrics_dct[k].append(round(v['f'] * 100, 4))
|
||||
metrics_dct['bleu-4'].append(
|
||||
sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
|
||||
return {k: np.mean(v) for k, v in metrics_dct.items()}
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Annotated[str, typer.Argument(help='')],
|
||||
model_dir: Annotated[
|
||||
str,
|
||||
typer.Argument(
|
||||
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
|
||||
),
|
||||
],
|
||||
config_file: Annotated[str, typer.Argument(help='')],
|
||||
auto_resume_from_checkpoint: str = typer.Argument(
|
||||
default='',
|
||||
help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
|
||||
),
|
||||
):
|
||||
ft_config = FinetuningConfig.from_file(config_file)
|
||||
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
|
||||
data_manager = DataManager(data_dir, ft_config.data_config)
|
||||
|
||||
train_dataset = data_manager.get_dataset(
|
||||
Split.TRAIN,
|
||||
functools.partial(
|
||||
process_batch,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
),
|
||||
batched=True,
|
||||
)
|
||||
print('train_dataset:', train_dataset)
|
||||
val_dataset = data_manager.get_dataset(
|
||||
Split.VALIDATION,
|
||||
functools.partial(
|
||||
process_batch_eval,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
),
|
||||
batched=True,
|
||||
)
|
||||
if val_dataset is not None:
|
||||
print('val_dataset:', val_dataset)
|
||||
test_dataset = data_manager.get_dataset(
|
||||
Split.TEST,
|
||||
functools.partial(
|
||||
process_batch_eval,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=ft_config.max_input_length,
|
||||
max_output_length=ft_config.max_output_length,
|
||||
),
|
||||
batched=True,
|
||||
)
|
||||
if test_dataset is not None:
|
||||
print('test_dataset:', test_dataset)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=ft_config.training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
padding='longest',
|
||||
return_tensors='pt',
|
||||
),
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
|
||||
)
|
||||
|
||||
if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
|
||||
trainer.train()
|
||||
else:
|
||||
output_dir = ft_config.training_args.output_dir
|
||||
dirlist = os.listdir(output_dir)
|
||||
checkpoint_sn = 0
|
||||
for checkpoint_str in dirlist:
|
||||
if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
|
||||
checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
|
||||
if checkpoint > checkpoint_sn:
|
||||
checkpoint_sn = checkpoint
|
||||
if auto_resume_from_checkpoint.upper() == "YES":
|
||||
if checkpoint_sn > 0:
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
|
||||
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
|
||||
trainer.train(resume_from_checkpoint=checkpoint_directory)
|
||||
else:
|
||||
trainer.train()
|
||||
else:
|
||||
if auto_resume_from_checkpoint.isdigit():
|
||||
if int(auto_resume_from_checkpoint) > 0:
|
||||
checkpoint_sn = int(auto_resume_from_checkpoint)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
|
||||
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
|
||||
trainer.train(resume_from_checkpoint=checkpoint_directory)
|
||||
else:
|
||||
print(auto_resume_from_checkpoint,
|
||||
"The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
|
||||
|
||||
if test_dataset is not None:
|
||||
trainer.predict(test_dataset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app()
|
|
**`finetune_demo/inference.py`**

@@ -2,37 +2,45 @@ from pathlib import Path
|
|||
from typing import Annotated, Union
|
||||
|
||||
import typer
|
||||
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
|
||||
from peft import PeftModelForCausalLM
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast
|
||||
)
|
||||
|
||||
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
|
||||
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_dir: Union[str, Path], trust_remote_code: bool = True
|
||||
) -> tuple[ModelType, TokenizerType]:
|
||||
):
|
||||
model_dir = Path(model_dir).expanduser().resolve()
|
||||
if (model_dir / 'adapter_config.json').exists():
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
||||
model = AutoModel.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
||||
model = AutoModel.from_pretrained(
|
||||
model_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
device_map='auto',
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer_dir = model_dir
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_dir, trust_remote_code=trust_remote_code, encode_special_tokens=True, use_fast=False
|
||||
tokenizer_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
encode_special_tokens=True,
|
||||
use_fast=False
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
@ -41,6 +49,14 @@ def load_model_and_tokenizer(
|
|||
def main(
|
||||
model_dir: Annotated[str, typer.Argument(help='')],
|
||||
):
|
||||
# For GLM-4 Finetune Without Tools
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user", "content": "#裙子#夏天",
|
||||
# }
|
||||
# ]
|
||||
|
||||
# For GLM-4 Finetune With Tools
|
||||
messages = [
|
||||
{
|
||||
"role": "system", "content": "",
|
||||
|
@ -83,15 +99,25 @@ def main(
|
|||
"content": "Can you help me create a calendar event for my meeting tomorrow? The title is \"Team Meeting\". It starts at 10:00 AM and ends at 11:00 AM."
|
||||
},
|
||||
]
|
||||
|
||||
# For GLM-4V Finetune
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "女孩可能希望观众做什么?",
|
||||
# "image": Image.open("your Image").convert("RGB")
|
||||
# }
|
||||
# ]
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(model_dir)
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt"
|
||||
return_tensors="pt",
|
||||
return_dict=True
|
||||
).to(model.device)
|
||||
generate_kwargs = {
|
||||
"input_ids": inputs,
|
||||
"max_new_tokens": 1024,
|
||||
"do_sample": True,
|
||||
"top_p": 0.8,
|
||||
|
@ -99,8 +125,8 @@ def main(
|
|||
"repetition_penalty": 1.2,
|
||||
"eos_token_id": model.config.eos_token_id,
|
||||
}
|
||||
outputs = model.generate(**generate_kwargs)
|
||||
response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()
|
||||
outputs = model.generate(**inputs, **generate_kwargs)
|
||||
response = tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True).strip()
|
||||
print("=========")
|
||||
print(response)
|