init commit

duzx16 · 2024-06-05 10:22:16 +08:00 · commit c0a6d1e0fa
65 changed files with 11654 additions and 0 deletions

.github/ISSUE_TEMPLATE/bug_report.yaml

@ -0,0 +1,72 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve GLM-4-9B / 提交一个 Bug 问题报告来帮助我们改进 GLM-4-9B
body:
- type: textarea
id: system-info
attributes:
label: System Info / 系統信息
description: Your operating environment / 您的运行环境信息
placeholder: Includes CUDA version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括 CUDA 版本、Transformers 版本、Python 版本、操作系统、硬件信息(如果您怀疑是硬件方面的问题)...
validations:
required: true
- type: textarea
id: who-can-help
attributes:
label: Who can help? / 谁可以帮助到您?
description: |
Your issue will be replied to more quickly if you can figure out the right person to tag with @
All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.
Please tag fewer than 3 people.
如果您能找到合适的标签 @,您的问题会更快得到回复。
所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。
标记的人数应该不超过 3 个人。
If the bug does not fall under these three subsections, you do not need to specify a helper; our maintainers will find the right person in the development group to solve the problem.
如果不是这三个子版块的 bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。
placeholder: "@Username ..."
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information / 问题信息
description: 'The problem arises when using: / 问题出现在'
options:
- label: "The official example scripts / 官方的示例脚本"
- label: "My own modified scripts / 我自己修改的脚本和任务"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction / 复现过程
description: |
Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
If you have code snippets, error messages, stack traces, please provide them here as well.
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
placeholder: |
Steps to reproduce the behavior/复现Bug的步骤:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior / 期待表现
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"


@ -0,0 +1,34 @@
name: "\U0001F680 Feature request"
description: Submit a request for a new GLM-4-9B feature / 提交一个新的 GLM-4-9B 的功能建议
labels: [ "feature" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request / 功能建议
description: |
A brief description of the functional proposal. Links to corresponding papers and code are desirable.
对功能建议的简述。最好提供对应的论文和代码链接
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation / 动机
description: |
Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution / 您的贡献
description: |
Your PR link or any other link you can help with.
您的PR链接或者其他您能提供帮助的链接。


@ -0,0 +1,34 @@
# Raise valuable PR / 提出有价值的PR
## Caution/ 注意事项:
Users should keep the following points in mind when submitting PRs:
1. The proposed PR should be about this project.
2. The proposed PR should be focused: if it contains multiple different ideas or optimizations, split them into separate PRs.
用户在提交PR时候应该注意以下几点:
1. 提出的PR应该是关于本项目的。
2. 提出的PR应该具有针对性如果具有多个不同的想法和优化方案应该分配到不同的PR中。
## PRs that should not be proposed / 不应该提出的PR
If a developer proposes a PR about any of the following, it may be closed or rejected:
1. PRs that do not describe an improvement plan.
2. PRs that combine multiple issues of different types.
3. PRs that largely duplicate already existing PRs.
如果开发者提出关于以下方面的PR则可能会被直接关闭或拒绝通过。
1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。
# Check your PR / 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Are your PRs for only one issue? / 您的PR是否仅针对一个问题?

.gitignore

@ -0,0 +1,6 @@
*venv
*.DS_Store
*base_model
*multimodal
chat_model
*.idea/

LICENSE

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2024 GLM-4-9B Model Team @ Zhipu AI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md

@ -0,0 +1,272 @@
# GLM-4
<p align="center">
🤗 <a href="https://huggingface.co/THUDM/glm-4-9b-chat" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat" target="_blank">ModelScope</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 👋 加入我们的 <a href="https://join.slack.com/t/chatglm/shared_invite/zt-25ti5uohv-A_hs~am_D3Q8XPZMpj7wwQ" target="_blank">Slack</a><a href="resources/WECHAT.md" target="_blank">微信</a>
</p>
<p align="center">
📍在 <a href="https://open.bigmodel.cn">智谱AI开放平台</a> 体验和使用更大规模的 GLM 商业模型。
</p>
Read this in [English](README_en.md)
## 模型介绍
GLM-4-9B 是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开源版本。 在语义、数学、推理、代码和知识等多方面的数据集测评中,**GLM-4-9B**
及其人类偏好对齐的版本 **GLM-4-9B-Chat** 均表现出超越 Llama-3-8B 的卓越性能。除了能进行多轮对话GLM-4-9B-Chat
还具备网页浏览、代码执行、自定义工具调用(Function Call)和长文本推理(支持最大 128K 上下文)等高级功能。本代模型增加了多语言支持,支持包括日语、韩语、德语在内的
26 种语言。我们还推出了支持 1M 上下文长度(约 200 万中文字符)的 **GLM-4-9B-Chat-1M** 模型和基于 GLM-4-9B 的多模态模型
GLM-4V-9B。**GLM-4V-9B** 具备 1120 * 1120 高分辨率下的中英双语多轮对话能力,在中英文综合能力、感知推理、文字识别、图表理解等多方面多模态评测中,GLM-4V-9B 表现出超越 GPT-4-turbo-2024-04-09、Gemini
1.0 Pro、Qwen-VL-Max 和 Claude 3 Opus 的卓越性能。
## 模型列表
| Model | Seq Length | Download |
|------------------|------------|-----------------------------------------------------------------------------------------------------------------------------------------|
| GLM-4-9B | 8K | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-9b) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-9b) |
| GLM-4-9B-Chat | 128K | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-9b-chat) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat) |
| GLM-4-9B-Chat-1M | 1M | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-9b-chat-1m) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat-1m) |
| GLM-4V-9B | 8K | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4v-9b) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4v-9b) |
## 评测结果
### 对话模型典型任务
| Model | AlignBench | MT-Bench | IFEval | MMLU | C-Eval | GSM8K | MATH | HumanEval | NaturalCodeBench |
|:--------------------|:----------:|:--------:|:------:|:----:|:------:|:-----:|:----:|:---------:|:----------------:|
| Llama-3-8B-Instruct | 6.40 | 8.00 | 68.58 | 68.4 | 51.3 | 79.6 | 30.0 | 62.2 | 24.7 |
| ChatGLM3-6B | 5.18 | 5.50 | 28.1 | 66.4 | 69.0 | 72.3 | 25.7 | 58.5 | 11.3 |
| GLM-4-9B-Chat | 7.01 | 8.35 | 69.0 | 72.4 | 75.6 | 79.6 | 50.6 | 71.8 | 32.2 |
### 基座模型典型任务
| Model | MMLU | C-Eval | GPQA | GSM8K | MATH | HumanEval |
|:--------------------|:----:|:------:|:----:|:-----:|:----:|:---------:|
| Llama-3-8B | 66.6 | 51.2 | - | 45.8 | - | 33.5 |
| Llama-3-8B-Instruct | 68.4 | 51.3 | 34.2 | 79.6 | 30.0 | 62.2 |
| ChatGLM3-6B-Base | 61.4 | 69.0 | 26.8 | 72.3 | 25.7 | 58.5 |
| GLM-4-9B | 74.7 | 77.1 | 34.3 | 84.0 | 30.4 | 70.1 |
> 由于 `GLM-4-9B` 在预训练过程中加入了部分数学、推理、代码相关的 instruction 数据,所以将 Llama-3-8B-Instruct 也列入比较范围。
### 长文本
在 1M 的上下文长度下进行[大海捞针实验](https://github.com/LargeWorldModel/LWM/blob/main/scripts/eval_needle.py),结果如下:
![needle](resources/eval_needle.jpeg)
在 LongBench-Chat 上对长文本能力进行了进一步评测,结果如下:
<p align="center">
<img src="resources/longbench.png" alt="描述文字" style="display: block; margin: auto; width: 65%;">
</p>
### 多语言能力
在六个多语言数据集上对 GLM-4-9B-Chat 和 Llama-3-8B-Instruct 进行了测试,测试结果及数据集对应选取语言如下表
| Dataset | Llama-3-8B-Instruct | GLM-4-9B-Chat | Languages
|:------------|:-------------------:|:-------------:|:----------------------------------------------------------------------------------------------:|
| M-MMLU | 49.6 | 56.6 | all
| FLORES | 25.0 | 28.8 | ru, es, de, fr, it, pt, pl, ja, nl, ar, tr, cs, vi, fa, hu, el, ro, sv, uk, fi, ko, da, bg, no
| MGSM | 54.0 | 65.3 | zh, en, bn, de, es, fr, ja, ru, sw, te, th
| XWinograd | 61.7 | 73.1 | zh, en, fr, jp, ru, pt
| XStoryCloze | 84.7 | 90.7 | zh, en, ar, es, eu, hi, id, my, ru, sw, te
| XCOPA | 73.3 | 80.1 | zh, et, ht, id, it, qu, sw, ta, th, tr, vi
### 工具调用能力
我们在 [Berkeley Function Calling Leaderboard](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard)
上进行了测试并得到了以下结果:
| Model | Overall Acc. | AST Summary | Exec Summary | Relevance |
|:-----------------------|:------------:|:-----------:|:------------:|:---------:|
| Llama-3-8B-Instruct | 58.88 | 59.25 | 70.01 | 45.83 |
| gpt-4-turbo-2024-04-09 | 81.24 | 82.14 | 78.61 | 88.75 |
| ChatGLM3-6B | 57.88 | 62.18 | 69.78 | 5.42 |
| GLM-4-9B-Chat | 81.00 | 80.26 | 84.40 | 87.92 |
### 多模态能力
GLM-4V-9B 是一个多模态语言模型,具备视觉理解能力,其相关经典任务的评测结果如下:
| | **MMBench-EN-Test** | **MMBench-CN-Test** | **SEEDBench_IMG** | **MMStar** | **MMMU** | **MME** | **HallusionBench** | **AI2D** | **OCRBench** |
|----------------------------|---------------------|---------------------|-------------------|------------|----------|---------|--------------------|----------|--------------|
| **gpt-4o-2024-05-13** | 83.4 | 82.1 | 77.1 | 63.9 | 69.2 | 2310.3 | 55 | 84.6 | 736 |
| **gpt-4-turbo-2024-04-09** | 81.0 | 80.2 | 73.0 | 56.0 | 61.7 | 2070.2 | 43.9 | 78.6 | 656 |
| **gpt-4-1106-preview** | 77.0 | 74.4 | 72.3 | 49.7 | 53.8 | 1771.5 | 46.5 | 75.9 | 516 |
| **InternVL-Chat-V1.5** | 82.3 | 80.7 | 75.2 | 57.1 | 46.8 | 2189.6 | 47.4 | 80.6 | 720 |
| **LLaVA-Next-Yi-34B** | 81.1 | 79 | 75.7 | 51.6 | 48.8 | 2050.2 | 34.8 | 78.9 | 574 |
| **Step-1V** | 80.7 | 79.9 | 70.3 | 50.0 | 49.9 | 2206.4 | 48.4 | 79.2 | 625 |
| **MiniCPM-Llama3-V2.5** | 77.6 | 73.8 | 72.3 | 51.8 | 45.8 | 2024.6 | 42.4 | 78.4 | 725 |
| **Qwen-VL-Max** | 77.6 | 75.7 | 72.7 | 49.5 | 52 | 2281.7 | 41.2 | 75.7 | 684 |
| **Gemini 1.0 Pro** | 73.6 | 74.3 | 70.7 | 38.6 | 49 | 2148.9 | 45.7 | 72.9 | 680 |
| **Claude 3 Opus** | 63.3 | 59.2 | 64 | 45.7 | 54.9 | 1586.8 | 37.8 | 70.6 | 694 |
| **GLM-4V-9B** | 81.1 | 79.4 | 76.8 | 58.7 | 47.2 | 2163.8 | 46.6 | 81.1 | 786 |
## 快速调用
### 使用以下方法快速调用 GLM-4-9B-Chat 语言模型
使用 transformers 后端进行推理:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
query = "你好"
inputs = tokenizer.apply_chat_template([{"role": "user", "content": query}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True
)
inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained(
"THUDM/glm-4-9b-chat",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(device).eval()
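# note: top_k=1 with do_sample=True is effectively greedy decoding; raise top_k / add top_p for more varied replies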
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
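上面的示例一次性解码完整回复。If you would rather see tokens as they are produced, transformers' `TextStreamer` can be attached to the same `generate` call. The sketch below is optional and simply reuses the `model`, `tokenizer`, `inputs`, and `gen_kwargs` defined in the block above.
```python
from transformers import TextStreamer

# prints tokens to stdout as they are generated, reusing the objects from the example above
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.no_grad():
    model.generate(**inputs, streamer=streamer, **gen_kwargs)
```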
使用 vLLM 后端进行推理:
```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# GLM-4-9B-Chat-1M
# max_model_len, tp_size = 1048576, 4
# GLM-4-9B-Chat
max_model_len, tp_size = 131072, 1
model_name = "THUDM/glm-4-9b-chat"
prompt = '你好'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
max_model_len=max_model_len,
trust_remote_code=True,
enforce_eager=True,
# GLM-4-9B-Chat-1M 如果遇见 OOM 现象,建议开启下述参数
# enable_chunked_prefill=True,
# max_num_batched_tokens=8192
)
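# ids of GLM-4 special tokens that should terminate generation (end-of-text and role-switch tokens)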
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0.95, max_tokens=1024, stop_token_ids=stop_token_ids)
# build the chat prompt with the tokenizer's chat template
inputs = tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
                                       tokenize=True,
                                       add_generation_prompt=True)
outputs = llm.generate(prompt_token_ids=[inputs], sampling_params=sampling_params)
generated_text = [output.outputs[0].text for output in outputs]
print(generated_text)
```
### 使用以下方法快速调用 GLM-4V-9B 多模态模型
使用 transformers 后端进行推理:
```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
query = '描述这张图片'
image = Image.open("your image").convert('RGB')
inputs = tokenizer.apply_chat_template([{"role": "user", "image": image, "content": query}],
add_generation_prompt=True, tokenize=True, return_tensors="pt",
return_dict=True) # chat mode
inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained(
"THUDM/glm-4v-9b",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(device).eval()
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
print(tokenizer.decode(outputs[0]))
```
注意: GLM-4V-9B 暂不支持使用 vLLM 方式调用。
## 完整项目列表
如果你想更进一步了解 GLM-4-9B 系列开源模型,本开源仓库通过以下内容为开发者提供基础的 GLM-4-9B的使用和开发代码
+ [basic_demo](basic_demo/README.md): 在这里包含了
+ 使用 transformers 和 VLLM 后端的交互代码
+ OpenAI API 后端交互代码
+ Batch 推理代码
+ [composite_demo](composite_demo/README.md): 在这里包含了
+ GLM-4-9B 以及 GLM-4V-9B 开源模型的完整功能演示代码,包含了 All Tools 能力、长文档解读和多模态能力的展示。
+ [finetune_demo](finetune_demo/README.md): 在这里包含了
+ PEFT (LORA, P-Tuning) 微调代码
+ SFT 微调代码
## 协议
+ GLM-4 模型的权重的使用则需要遵循 [模型协议](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE)。
+ 本开源仓库的代码则遵循 [Apache 2.0](LICENSE) 协议。
请您严格遵循开源协议。
## 引用
如果你觉得我们的工作有帮助的话,请考虑引用下列论文。
```
@inproceedings{zeng2022glm,
title={{GLM-130B:} An Open Bilingual Pre-trained Model},
author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
booktitle={The Eleventh International Conference on Learning Representations,
{ICLR} 2023, Kigali, Rwanda, May 1-5, 2023},
year= {2023},
}
```
```
@inproceedings{du2022glm,
title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
pages={320--335},
year={2022}
}
```
```
@misc{wang2023cogvlm,
title={CogVLM: Visual Expert for Pretrained Language Models},
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
year={2023},
eprint={2311.03079},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

README_en.md

@ -0,0 +1,280 @@
# GLM-4
<p align="center">
🤗 <a href="https://huggingface.co/THUDM/glm-4-9b-chat" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat" target="_blank">ModelScope</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 👋 Join <a href="https://join.slack.com/t/chatglm/shared_invite/zt-25ti5uohv-A_hs~am_D3Q8XPZMpj7wwQ" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
</p>
<p align="center">
📍Experience and use a larger-scale GLM business model on the <a href="https://open.bigmodel.cn">Zhipu AI Open Platform</a>
</p>
## Model Introduction
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu
AI. In the evaluation of data sets in semantics, mathematics, reasoning, code, and knowledge, **GLM-4-9B**
and its human preference-aligned version **GLM-4-9B-Chat** have shown superior performance beyond Llama-3-8B. In addition to
multi-round conversations, GLM-4-9B-Chat
also has advanced features such as web browsing, code execution, custom tool calls (Function Call), and long text
reasoning (supporting up to 128K context). This generation of models has added multi-language support, supporting 26
languages including Japanese, Korean, and German. We have also launched the **GLM-4-9B-Chat-1M** model that supports 1M
context length (about 2 million Chinese characters) and the multimodal model GLM-4V-9B based on GLM-4-9B.
**GLM-4V-9B** possesses dialogue capabilities in both Chinese and English at a high resolution of 1120*1120.
In various multimodal evaluations, including comprehensive abilities in Chinese and English, perception & reasoning, text recognition, and chart understanding, GLM-4V-9B demonstrates superior performance compared to GPT-4-turbo-2024-04-09, Gemini 1.0 Pro, Qwen-VL-Max, and Claude 3 Opus.
## Model List
| Model | Seq Length | Download |
|------------------|------------|-----------------------------------------------------------------------------------------------------------------------------------------|
| GLM-4-9B | 8K | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-9b) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-9b) |
| GLM-4-9B-Chat | 128K | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-9b-chat) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat) |
| GLM-4-9B-Chat-1M | 1M | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-9b-chat-1m) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat-1m) |
| GLM-4V-9B | 8K | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4v-9b) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4v-9b) |
## Benchmark
### Typical Tasks
| Model | AlignBench | MT-Bench | IFEval | MMLU | C-Eval | GSM8K | MATH | HumanEval | NaturalCodeBench |
|:--------------------|:----------:|:--------:|:------:|:----:|:------:|:-----:|:----:|:---------:|:----------------:|
| Llama-3-8B-Instruct | 6.40 | 8.00 | 68.58 | 68.4 | 51.3 | 79.6 | 30.0 | 62.2 | 24.7 |
| ChatGLM3-6B | 5.18 | 5.50 | 28.1 | 66.4 | 69.0 | 72.3 | 25.7 | 58.5 | 11.3 |
| GLM-4-9B-Chat | 7.01 | 8.35 | 69.0 | 72.4 | 75.6 | 79.6 | 50.6 | 71.8 | 32.2 |
### Base Model
| Model | MMLU | C-Eval | GPQA | GSM8K | MATH | HumanEval |
|:--------------------|:----:|:------:|:----:|:-----:|:----:|:---------:|
| Llama-3-8B | 66.6 | 51.2 | - | 45.8 | - | 33.5 |
| Llama-3-8B-Instruct | 68.4 | 51.3 | 34.2 | 79.6 | 30.0 | 62.2 |
| ChatGLM3-6B-Base | 61.4 | 69.0 | 26.8 | 72.3 | 25.7 | 58.5 |
| GLM-4-9B | 74.7 | 77.1 | 34.3 | 84.0 | 30.4 | 70.1 |
> Since `GLM-4-9B` adds some math, reasoning, and code-related instruction data during pre-training, Llama-3-8B-Instruct
> is also included in the comparison range.
### Long Context
The [needle-in-the-haystack experiment](https://github.com/LargeWorldModel/LWM/blob/main/scripts/eval_needle.py) was
conducted with a context length of 1M, and the results are as follows:
![needle](resources/eval_needle.jpeg)
The long text capability was further evaluated on LongBench-Chat, and the results are as follows:
<p align="center">
<img src="resources/longbench.png" alt="Description text" style="display: block; margin: auto; width: 65%;">
</p>
### Multilingual Capabilities
The tests for GLM-4-9B-Chat and Llama-3-8B-Instruct are conducted on six multilingual datasets. The test results and the corresponding languages selected for each dataset are shown in the table below:
| Dataset | Llama-3-8B-Instruct | GLM-4-9B-Chat | Languages
|:------------|:-------------------:|:-------------:|:----------------------------------------------------------------------------------------------:|
| M-MMLU | 49.6 | 56.6 | all
| FLORES | 25.0 | 28.8 | ru, es, de, fr, it, pt, pl, ja, nl, ar, tr, cs, vi, fa, hu, el, ro, sv, uk, fi, ko, da, bg, no
| MGSM | 54.0 | 65.3 | zh, en, bn, de, es, fr, ja, ru, sw, te, th
| XWinograd | 61.7 | 73.1 | zh, en, fr, jp, ru, pt
| XStoryCloze | 84.7 | 90.7 | zh, en, ar, es, eu, hi, id, my, ru, sw, te
| XCOPA | 73.3 | 80.1 | zh, et, ht, id, it, qu, sw, ta, th, tr, vi
### Function Call
Tested
on [Berkeley Function Calling Leaderboard](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard).
| Model | Overall Acc. | AST Summary | Exec Summary | Relevance |
|:-----------------------|:------------:|:-----------:|:------------:|:---------:|
| Llama-3-8B-Instruct | 58.88 | 59.25 | 70.01 | 45.83 |
| gpt-4-turbo-2024-04-09 | 81.24 | 82.14 | 78.61 | 88.75 |
| ChatGLM3-6B | 57.88 | 62.18 | 69.78 | 5.42 |
| GLM-4-9B-Chat | 81.00 | 80.26 | 84.40 | 87.92 |
### Multi-Modal
GLM-4V-9B is a multimodal language model with visual understanding capabilities. The evaluation results of its related
classic tasks are as follows:
| | **MMBench-EN-Test** | **MMBench-CN-Test** | **SEEDBench_IMG** | **MMStar** | **MMMU** | **MME** | **HallusionBench** | **AI2D** | **OCRBench** |
|----------------------------|---------------------|---------------------|-------------------|------------|----------|---------|--------------------|----------|--------------|
| **gpt-4o-2024-05-13** | 83.4 | 82.1 | 77.1 | 63.9 | 69.2 | 2310.3 | 55 | 84.6 | 736 |
| **gpt-4-turbo-2024-04-09** | 81.0 | 80.2 | 73.0 | 56.0 | 61.7 | 2070.2 | 43.9 | 78.6 | 656 |
| **gpt-4-1106-preview** | 77.0 | 74.4 | 72.3 | 49.7 | 53.8 | 1771.5 | 46.5 | 75.9 | 516 |
| **InternVL-Chat-V1.5** | 82.3 | 80.7 | 75.2 | 57.1 | 46.8 | 2189.6 | 47.4 | 80.6 | 720 |
| **LLaVA-Next-Yi-34B** | 81.1 | 79 | 75.7 | 51.6 | 48.8 | 2050.2 | 34.8 | 78.9 | 574 |
| **Step-1V** | 80.7 | 79.9 | 70.3 | 50.0 | 49.9 | 2206.4 | 48.4 | 79.2 | 625 |
| **MiniCPM-Llama3-V2.5** | 77.6 | 73.8 | 72.3 | 51.8 | 45.8 | 2024.6 | 42.4 | 78.4 | 725 |
| **Qwen-VL-Max** | 77.6 | 75.7 | 72.7 | 49.5 | 52 | 2281.7 | 41.2 | 75.7 | 684 |
| **Gemini 1.0 Pro** | 73.6 | 74.3 | 70.7 | 38.6 | 49 | 2148.9 | 45.7 | 72.9 | 680 |
| **Claude 3 Opus** | 63.3 | 59.2 | 64 | 45.7 | 54.9 | 1586.8 | 37.8 | 70.6 | 694 |
| **GLM-4V-9B** | 81.1 | 79.4 | 76.8 | 58.7 | 47.2 | 2163.8 | 46.6 | 81.1 | 786 |
## Quick call
### Use the following method to quickly call the GLM-4-9B-Chat language model
Use the transformers backend for inference:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
query = "你好"
inputs = tokenizer.apply_chat_template([{"role": "user", "content": query}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True
)
inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained(
"THUDM/glm-4-9b-chat",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(device).eval()
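# note: top_k=1 with do_sample=True is effectively greedy decoding; raise top_k / add top_p for more varied replies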
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
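The block above decodes the full reply at once. If you would rather see tokens as they are produced, transformers' `TextStreamer` can be attached to the same `generate` call; the optional sketch below reuses the `model`, `tokenizer`, `inputs`, and `gen_kwargs` defined above.
```python
from transformers import TextStreamer

# prints tokens to stdout as they are generated, reusing the objects from the example above
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.no_grad():
    model.generate(**inputs, streamer=streamer, **gen_kwargs)
```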
Use the vLLM backend for inference:
```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# GLM-4-9B-Chat-1M
# max_model_len, tp_size = 1048576, 4
# GLM-4-9B-Chat
max_model_len, tp_size = 131072, 1
model_name = "THUDM/glm-4-9b-chat"
prompt = '你好'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
max_model_len=max_model_len,
trust_remote_code=True,
enforce_eager=True,
# GLM-4-9B-Chat-1M If you encounter OOM phenomenon, it is recommended to turn on the following parameters
# enable_chunked_prefill=True,
# max_num_batched_tokens=8192
)
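# ids of GLM-4 special tokens that should terminate generation (end-of-text and role-switch tokens)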
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0.95, max_tokens=1024, stop_token_ids=stop_token_ids)
# build the chat prompt with the tokenizer's chat template
inputs = tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
                                       tokenize=True,
                                       add_generation_prompt=True)
outputs = llm.generate(prompt_token_ids=[inputs], sampling_params=sampling_params)
generated_text = [output.outputs[0].text for output in outputs]
print(generated_text)
```
### Use the following method to quickly call the GLM-4V-9B multimodal model
Use the transformers backend for inference:
```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
query = 'Describe this image'
image = Image.open("your image").convert('RGB')
inputs = tokenizer.apply_chat_template([{"role": "user", "image": image, "content": query}],
add_generation_prompt=True, tokenize=True, return_tensors="pt",
return_dict=True) # chat mode
inputs = inputs.to(device)
model = AutoModelForCausalLM.from_pretrained(
"THUDM/glm-4v-9b",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(device).eval()
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
print(tokenizer.decode(outputs[0]))
```
Note: GLM-4V-9B does not support calling using vLLM method yet.
## Complete project list
If you want to learn more about the GLM-4-9B series open source models, this open source repository provides developers
with basic GLM-4-9B usage and development code through the following content
+ [basic_demo](basic_demo/README.md): Contains
+ Interaction code using transformers and VLLM backend
+ OpenAI API backend interaction code
+ Batch inference code
+ [composite_demo](composite_demo/README.md): Contains
+ Fully functional demonstration code for GLM-4-9B and GLM-4V-9B open source models, including All Tools capabilities,
long document interpretation, and multimodal capabilities.
+ [finetune_demo](finetune_demo/README.md): Contains
+ PEFT (LORA, P-Tuning) fine-tuning code
+ SFT fine-tuning code
## License
+ The use of GLM-4 model weights must follow
the [Model License](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE).
+ The code in this open source repository follows the [Apache 2.0](LICENSE) license.
Please strictly follow the open source license.
## Reference
If you find our work helpful, please consider citing the following papers.
```
@inproceedings{zeng2022glm,
title={{GLM-130B:} An Open Bilingual Pre-trained Model},
author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
booktitle={The Eleventh International Conference on Learning Representations,
{ICLR} 2023, Kigali, Rwanda, May 1-5, 2023},
year= {2023},
}
```
```
@inproceedings{du2022glm,
title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
pages={320--335},
year={2022}
}
```
```
@misc{wang2023cogvlm,
title={CogVLM: Visual Expert for Pretrained Language Models},
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
year={2023},
eprint={2311.03079},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

basic_demo/README.md

@ -0,0 +1,114 @@
# Basic Demo
Read this in [English](README_en.md)
本 demo 中,你将体验到如何使用 glm-4-9b 开源模型进行基本的任务。
请严格按照文档的步骤进行操作,以避免不必要的错误。
## 设备和依赖检查
### 相关推理测试数据
**本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。**
测试硬件信息:
+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8
相关推理的压力测试数据如下:
**所有测试均在单张GPU上进行测试,所有显存消耗都按照峰值左右进行测算**
| 精度 | 显存占用 | Prefilling / 首响 | Decode Speed | Remarks |
|------|----------|-----------------|------------------|--------------|
| BF16 | 19047MiB | 0.1554s | 27.8193 tokens/s | 输入长度为 1000 |
| BF16 | 20629MiB | 0.8199s | 31.8613 tokens/s | 输入长度为 8000 |
| BF16 | 27779MiB | 4.3554s | 14.4108 tokens/s | 输入长度为 32000 |
| BF16 | 57379MiB | 38.1467s | 3.4205 tokens/s | 输入长度为 128000 |
| BF16 | 74497MiB | 98.4930s | 2.3653 tokens/s | 输入长度为 200000 |
| 精度 | 显存占用 | Prefilling / 首响 | Decode Speed | Remarks |
|------|----------|-----------------|------------------|-------------|
| Int4 | 8251MiB | 0.1667s | 23.3903 tokens/s | 输入长度为 1000 |
| Int4 | 9613MiB | 0.8629s | 23.4248 tokens/s | 输入长度为 8000 |
| Int4 | 16065MiB | 4.3906s | 14.6553 tokens/s | 输入长度为 32000 |
### 最低硬件要求
如果您希望运行官方提供的最基础代码 (transformers 后端) 您需要:
+ Python >= 3.10
+ 内存不少于 32 GB
如果您希望运行官方提供的本文件夹的所有代码,您还需要:
+ Linux 操作系统 (Debian 系列最佳)
+ 大于 8GB 显存的、支持 CUDA 或者 ROCm 并且支持 `BF16` 推理的 GPU 设备(A100 以上的 GPU;V100、20 系列以及更老的 GPU 架构不受支持)。可参考下方的环境检查示例。
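Before installing the dependencies, a quick sanity check of the GPU can save time. The sketch below only uses PyTorch (already required by the demos) and prints the device name, BF16 support, and total memory; it is a convenience check written for this guide, not part of the official demo scripts.
```python
import torch

# quick environment check: CUDA availability, BF16 support, and total GPU memory
assert torch.cuda.is_available(), "No CUDA device found"
print("Device:", torch.cuda.get_device_name(0))
print("BF16 supported:", torch.cuda.is_bf16_supported())
print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GiB")
```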
安装依赖
```shell
pip install -r requirements.txt
```
## 基础功能调用
**除非特殊说明,本文件夹所有 demo 并不支持 Function Call 和 All Tools 等进阶用法**
### 使用 transformers 后端代码
+ 使用 命令行 与 glm-4-9b 模型进行对话。
```shell
python trans_cli_demo.py
```
+ 使用 Gradio 网页端与 glm-4-9b 模型进行对话。
```shell
python trans_web_demo.py
```
+ 使用 Batch 推理。
```shell
python cli_batch_request_demo.py
```
### 使用 VLLM 后端代码
+ 使用命令行与 glm-4-9b 模型进行对话。
```shell
python vllm_cli_demo.py
```
+ 自行构建服务端,并使用 `OpenAI API` 的请求格式与 glm-4-9b 模型进行对话。本 demo 支持 Function Call 和 All Tools功能。
启动服务端:
```shell
python openai_api_server.py
```
客户端请求:
```shell
python openai_api_request.py
```
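Equivalently, the same request can be issued directly with the `openai` Python client instead of running `openai_api_request.py`. The minimal sketch below assumes the server started above is listening on `127.0.0.1:8000` and serves the model under the name `glm-4`, as in `openai_api_server.py`.
```python
from openai import OpenAI

# assumes openai_api_server.py from this folder is running locally on port 8000
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "你好"}],
    max_tokens=256,
    temperature=0.8,
)
print(response.choices[0].message.content)
```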
## 压力测试
用户可以在自己的设备上使用本代码测试模型在 transformers后端的生成速度:
```shell
python trans_stress_test.py
```

basic_demo/README_en.md

@ -0,0 +1,114 @@
# Basic Demo
In this demo, you will experience how to use the glm-4-9b open source model to perform basic tasks.
Please follow the steps in the document strictly to avoid unnecessary errors.
## Device and dependency check
### Related inference test data
**The data in this document were measured in the hardware environment below. Actual requirements and GPU memory usage will differ slightly depending on your environment; treat your own setup as the reference.**
Test hardware information:
+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8
The inference stress-test results are as follows:
**All tests were performed on a single GPU; GPU memory consumption is reported at approximately its peak value.**
| Precision | GPU Memory Usage | Prefilling / First Token Latency | Decode Speed | Remarks |
|----------|--------------------|-------------------------|------------------|------------------------|
| BF16 | 19047MiB | 0.1554s | 27.8193 tokens/s | Input length is 1000 |
| BF16 | 20629MiB | 0.8199s | 31.8613 tokens/s | Input length is 8000 |
| BF16 | 27779MiB | 4.3554s | 14.4108 tokens/s | Input length is 32000 |
| BF16 | 57379MiB | 38.1467s | 3.4205 tokens/s | Input length is 128000 |
| BF16 | 74497MiB | 98.4930s | 2.3653 tokens/s | Input length is 200000 |
| Precision | GPU Memory Usage | Prefilling / First Token Latency | Decode Speed | Remarks |
|-----------|--------------|--------------------------|------------------|-----------------------|
| Int4 | 8251MiB | 0.1667s | 23.3903 tokens/s | Input length is 1000 |
| Int4 | 9613MiB | 0.8629s | 23.4248 tokens/s | Input length is 8000 |
| Int4 | 16065MiB | 4.3906s | 14.6553 tokens/s | Input length is 32000 |
### Minimum hardware requirements
If you want to run the most basic officially provided code (the transformers backend), you need:
+ Python >= 3.10
+ At least 32 GB of RAM
If you want to run all of the code in this folder, you also need:
+ A Linux operating system (Debian-based distributions work best)
+ A GPU with more than 8 GB of memory that supports CUDA or ROCm and `BF16` inference (A100 or above; V100, 20-series, and older GPU architectures are not supported); see the quick check sketch below
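Before installing the dependencies, a quick sanity check of the GPU can save time. The sketch below only uses PyTorch (already required by the demos) and prints the device name, BF16 support, and total memory; it is a convenience check written for this guide, not part of the official demo scripts.
```python
import torch

# quick environment check: CUDA availability, BF16 support, and total GPU memory
assert torch.cuda.is_available(), "No CUDA device found"
print("Device:", torch.cuda.get_device_name(0))
print("BF16 supported:", torch.cuda.is_bf16_supported())
print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GiB")
```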
Install dependencies
```shell
pip install -r requirements.txt
```
## Basic function calls
**Unless otherwise specified, all demos in this folder do not support advanced usage such as Function Call and All Tools.**
### Use transformers backend code
+ Use the command line to communicate with the glm-4-9b model.
```shell
python trans_cli_demo.py
```
+ Use the Gradio web client to communicate with the glm-4-9b model.
```shell
python trans_web_demo.py
```
+ Use Batch inference.
```shell
python cli_batch_request_demo.py
```
### Use VLLM backend code
+ Use the command line to communicate with the glm-4-9b model.
```shell
python vllm_cli_demo.py
```
+ Build the server by yourself and use the request format of `OpenAI API` to communicate with the glm-4-9b model. This
demo supports Function Call and All Tools functions.
Start the server:
```shell
python openai_api_server.py
```
Client request:
```shell
python openai_api_request.py
```
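Equivalently, the same request can be issued directly with the `openai` Python client instead of running `openai_api_request.py`. The minimal sketch below assumes the server started above is listening on `127.0.0.1:8000` and serves the model under the name `glm-4`, as in `openai_api_server.py`.
```python
from openai import OpenAI

# assumes openai_api_server.py from this folder is running locally on port 8000
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "你好"}],
    max_tokens=256,
    temperature=0.8,
)
print(response.choices[0].message.content)
```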
## Stress test
Users can use this code to test the generation speed of the model on the transformers backend on their own devices:
```shell
python trans_stress_test.py
```


@ -0,0 +1,88 @@
"""
This script creates a OpenAI Request demo for the glm-4-9b model, just Use OpenAI API to interact with the model.
"""
from openai import OpenAI
base_url = "http://127.0.0.1:8000/v1/"
client = OpenAI(api_key="EMPTY", base_url=base_url)
def function_chat():
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
# All Tools capability: image generation / 绘图
# messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
# tools = [{"type": "cogview"}]
#
# All Tools capability: web search / 联网查询
# messages = [{"role": "user", "content": "今天黄金的价格"}]
# tools = [{"type": "simple_browser"}]
response = client.chat.completions.create(
model="glm-4",
messages=messages,
tools=tools,
tool_choice="auto", # use "auto" to let the model choose the tool automatically
# tool_choice={"type": "function", "function": {"name": "my_function"}},
)
if response:
content = response.choices[0].message.content
print(content)
else:
print("Error:", response.status_code)
def simple_chat(use_stream=False):
messages = [
{
"role": "system",
"content": "你是 GLM-4请你热情回答用户的问题。",
},
{
"role": "user",
"content": "你好,请你用生动的话语给我讲一个小故事吧"
}
]
response = client.chat.completions.create(
model="glm-4",
messages=messages,
stream=use_stream,
max_tokens=1024,
temperature=0.8,
presence_penalty=1.1,
top_p=0.8)
if response:
if use_stream:
for chunk in response:
print(chunk.choices[0].delta.content)
else:
content = response.choices[0].message.content
print(content)
else:
print("Error:", response.status_code)
if __name__ == "__main__":
simple_chat()
function_chat()


@ -0,0 +1,543 @@
import os
import time
from asyncio.log import logger
import uvicorn
import gc
import json
import torch
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse
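# lengthen the SSE keep-alive ping interval (seconds) so pings rarely interleave with streamed tokens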
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b'
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class FunctionCallResponse(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "tool"]
content: str = None
name: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
class EmbeddingRequest(BaseModel):
input: Union[List[str], str]
model: str
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class EmbeddingResponse(BaseModel):
data: list
model: str
object: str
usage: CompletionUsage
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.8
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
tool_choice: Optional[Union[str, dict]] = "None"
repetition_penalty: Optional[float] = 1.1
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length", "function_call"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
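# if the logits contain NaN/Inf, reset them and force a fixed token id so generation can continue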
scores.zero_()
scores[..., 5] = 5e4
return scores
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
content = ""
for response in output.split("<|assistant|>"):
metadata, content = response.split("\n", maxsplit=1)
if not metadata.strip():
content = content.strip()
else:
if use_tool:
content = "\n".join(content.split("\n")[1:-1])
parameters = eval(content)
content = {
"name": metadata.strip(),
"arguments": json.dumps(parameters, ensure_ascii=False)
}
else:
content = {
"name": metadata.strip(),
"content": content
}
return content
@torch.inference_mode()
async def generate_stream_glm4(params):
messages = params["messages"]
tools = params["tools"]
tool_choice = params["tool_choice"]
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_tokens", 8192))
messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
params_dict = {
"n": 1,
"best_of": 1,
"presence_penalty": 1.0,
"frequency_penalty": 0.0,
"temperature": temperature,
"top_p": top_p,
"top_k": -1,
"repetition_penalty": repetition_penalty,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop_token_ids": [151329, 151336, 151338],
"ignore_eos": False,
"max_tokens": max_new_tokens,
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
}
sampling_params = SamplingParams(**params_dict)
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id="glm-4-9b"):
output_len = len(output.outputs[0].token_ids)
input_len = len(output.prompt_token_ids)
ret = {
"text": output.outputs[0].text,
"usage": {
"prompt_tokens": input_len,
"completion_tokens": output_len,
"total_tokens": output_len + input_len
},
"finish_reason": output.outputs[0].finish_reason,
}
yield ret
gc.collect()
torch.cuda.empty_cache()
def process_messages(messages, tools=None, tool_choice="none"):
_messages = messages
messages = []
msg_has_sys = False
def filter_tools(tool_choice, tools):
function_name = tool_choice.get('function', {}).get('name', None)
if not function_name:
return []
filtered_tools = [
tool for tool in tools
if tool.get('function', {}).get('name') == function_name
]
return filtered_tools
if tool_choice != "none":
if isinstance(tool_choice, dict):
tools = filter_tools(tool_choice, tools)
if tools:
messages.append(
{
"role": "system",
"content": None,
"tools": tools
}
)
msg_has_sys = True
# add to metadata
if isinstance(tool_choice, dict) and tools:
messages.append(
{
"role": "assistant",
"metadata": tool_choice["function"]["name"],
"content": ""
}
)
for m in _messages:
role, content, func_call = m.role, m.content, m.function_call
if role == "function":
messages.append(
{
"role": "observation",
"content": content
}
)
elif role == "assistant" and func_call is not None:
for response in content.split("<|assistant|>"):
metadata, sub_content = response.split("\n", maxsplit=1)
messages.append(
{
"role": role,
"metadata": metadata,
"content": sub_content.strip()
}
)
else:
if role == "system" and msg_has_sys:
msg_has_sys = False
continue
messages.append({"role": role, "content": content})
return messages
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(id="glm-4")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")
gen_params = dict(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=request.stream,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
tool_choice=request.tool_choice,
)
logger.debug(f"==== request ====\n{gen_params}")
if request.stream:
predict_stream_generator = predict_stream(request.model, gen_params)
output = await anext(predict_stream_generator)
if not output and 'get_' in output:
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
logger.debug(f"First result output\n{output}")
function_call = None
if output and request.tools:
try:
function_call = process_response(output, use_tool=True)
except:
logger.warning("Failed to parse tool call")
# CallFunction
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
tool_response = ""
if not gen_params.get("messages"):
gen_params["messages"] = []
gen_params["messages"].append(ChatMessage(role="assistant", content=output))
gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
generate = predict(request.model, gen_params)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
generate = parse_output_text(request.model, output)
return EventSourceResponse(generate, media_type="text/event-stream")
response = ""
async for response in generate_stream_glm4(gen_params):
pass
if response["text"].startswith("\n"):
response["text"] = response["text"][1:]
response["text"] = response["text"].strip()
usage = UsageInfo()
function_call, finish_reason = None, "stop"
if request.tools:
try:
function_call = process_response(response["text"], use_tool=True)
except:
logger.warning(
"Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
if isinstance(function_call, dict):
finish_reason = "function_call"
function_call = FunctionCallResponse(**function_call)
message = ChatMessage(
role="assistant",
content=response["text"],
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
)
logger.debug(f"==== message ====\n{message}")
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason=finish_reason,
)
task_usage = UsageInfo.model_validate(response["usage"])
for usage_key, usage_value in task_usage.model_dump().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse(
model=request.model,
id="", # for open_source model, id is empty
choices=[choice_data],
object="chat.completion",
usage=usage
)
async def predict(model_id: str, params: dict):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
previous_text = ""
async for new_response in generate_stream_glm4(params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(previous_text):]
previous_text = decoded_unicode
finish_reason = new_response["finish_reason"]
if len(delta_text) == 0 and finish_reason != "function_call":
continue
function_call = None
if finish_reason == "function_call":
try:
function_call = process_response(decoded_unicode, use_tool=True)
except:
logger.warning(
"Failed to parse tool call, maybe the response is not a tool call or have been answered.")
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
delta = DeltaMessage(
content=delta_text,
role="assistant",
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=delta,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
async def predict_stream(model_id, gen_params):
output = ""
is_function_call = False
has_send_first_chunk = False
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
output = decoded_unicode
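# Heuristic used by this demo: tool names start with the prefix 'get_' (e.g. get_weather),
# so once enough text has accumulated, output containing 'get_' is treated as a function call;
# streaming of deltas then stops and the raw output is yielded once generation finishes.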
if not is_function_call and len(output) > 7:
is_function_call = output and 'get_' in output
if is_function_call:
continue
finish_reason = new_response["finish_reason"]
if not has_send_first_chunk:
message = DeltaMessage(
content="",
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
send_msg = delta_text if has_send_first_chunk else output
has_send_first_chunk = True
message = DeltaMessage(
content=send_msg,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
choices=[choice_data],
created=int(time.time()),
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
if is_function_call:
yield output
else:
yield '[DONE]'
async def parse_output_text(model_id: str, value: str):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=value),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
engine_args = AsyncEngineArgs(
model=MODEL_PATH,
tokenizer=MODEL_PATH,
tokenizer_mode="slow",
tensor_parallel_size=1,
dtype="bfloat16",
trust_remote_code=True,
gpu_memory_utilization=0.3,
enforce_eager=True,
worker_use_ray=True,
engine_use_ray=False,
disable_log_requests=True
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
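# Example request (a sketch, not executed here): once this server is running, it can be queried
# as an OpenAI-compatible API. This assumes the /v1/chat/completions route defined earlier in this
# file and that the handler accepts an arbitrary model name such as "glm-4"; adjust host, port and
# model name to match your deployment.
#
#   curl -X POST http://127.0.0.1:8000/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "glm-4", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'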

View File

@ -0,0 +1,23 @@
torch>=2.3.0
torchvision>=0.18.0
transformers==4.40.0
huggingface-hub>=0.23.1
sentencepiece>=0.2.0
pydantic>=2.7.1
timm>=0.9.16
tiktoken>=0.7.0
accelerate>=0.30.1
sentence_transformers>=2.7.0
vllm>=0.4.3
# web demo
gradio>=4.31.5
# openai demo
openai>=1.30.3
einops>=0.7.0
sse-starlette>=2.1.0
# Int4
bitsandbytes>=0.43.1

View File

@ -0,0 +1,90 @@
"""
This is an example of making batch requests to glm-4-9b.
You need to build the conversation format yourself and then call the batch function to send batched requests.
Please note that memory consumption is significantly higher in this demo.
"""
from typing import Optional, Union
from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
MODEL_PATH = 'THUDM/glm-4-9b-chat'
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
encode_special_tokens=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
def process_model_outputs(inputs, outputs, tokenizer):
responses = []
for input_ids, output_ids in zip(inputs.input_ids, outputs):
response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
responses.append(response)
return responses
def batch(
model,
tokenizer,
messages: Union[str, list[str]],
max_input_tokens: int = 8192,
max_new_tokens: int = 8192,
num_beams: int = 1,
do_sample: bool = True,
top_p: float = 0.8,
temperature: float = 0.8,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
):
messages = [messages] if isinstance(messages, str) else messages
batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
max_length=max_input_tokens).to(model.device)
gen_kwargs = {
"max_new_tokens": max_new_tokens,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
"eos_token_id": model.config.eos_token_id
}
batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
return batched_response
if __name__ == "__main__":
batch_message = [
[
{"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
{"role": "assistant", "content": "因为他们结婚时你还没有出生"},
{"role": "user", "content": "我刚才的提问是"}
],
[
{"role": "user", "content": "你好,你是谁"}
]
]
batch_inputs = []
max_input_tokens = 1024
for i, messages in enumerate(batch_message):
new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
max_input_tokens = max(max_input_tokens, len(new_batch_input))
batch_inputs.append(new_batch_input)
gen_kwargs = {
"max_input_tokens": max_input_tokens,
"max_new_tokens": 8192,
"do_sample": True,
"top_p": 0.8,
"temperature": 0.8,
"num_beams": 1,
}
batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
for response in batch_responses:
print("=" * 10)
print(response)

View File

@ -0,0 +1,120 @@
"""
This script creates a CLI demo with transformers backend for the glm-4-9b model,
allowing users to interact with the model through a command-line interface.
Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.
Note: The script includes a modification to handle markdown to plain text conversion,
ensuring that the CLI interface displays formatted text correctly.
"""
import os
import torch
from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b')
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = Path(model_dir).expanduser().resolve()
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto')
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=trust_remote_code, device_map='auto')
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code, encode_special_tokens=True, use_fast=False
)
return model, tokenizer
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = model.config.eos_token_id
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
if __name__ == "__main__":
history = []
max_length = 8192
top_p = 0.8
temperature = 0.6
stop = StopOnTokens()
print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
while True:
user_input = input("\nYou: ")
if user_input.lower() in ["exit", "quit"]:
break
history.append([user_input, ""])
messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
model_inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt"
).to(model.device)
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=60,
skip_prompt=True,
skip_special_tokens=True
)
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"stopping_criteria": StoppingCriteriaList([stop]),
"repetition_penalty": 1.2,
"eos_token_id": model.config.eos_token_id,
}
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
print("GLM-4:", end="", flush=True)
for new_token in streamer:
if new_token:
print(new_token, end="", flush=True)
history[-1][1] += new_token
history[-1][1] = history[-1][1].strip()

View File

@ -0,0 +1,128 @@
import argparse
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
MODEL_PATH = 'THUDM/glm-4-9b-chat'
def stress_test(token_len, n, num_gpu):
device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
padding_side="left"
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
# quantization_config=BitsAndBytesConfig(load_in_4bit=True),
# low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16
).to(device).eval()
times = []
decode_times = []
print("Warming up...")
vocab_size = tokenizer.vocab_size
warmup_token_len = 20
random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
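# Wrap the random filler tokens in a GLM-4 chat template; the hard-coded ids below are the
# model's special tokens (approximately [gMASK] <sop> <|user|> \n ... <|assistant|>).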
start_tokens = [151331, 151333, 151336, 198]
end_tokens = [151337]
input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(
device)
attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
warmup_inputs = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids
}
with torch.no_grad():
_ = model.generate(
input_ids=warmup_inputs['input_ids'],
attention_mask=warmup_inputs['attention_mask'],
max_new_tokens=2048,
do_sample=False,
repetition_penalty=1.0,
eos_token_id=[151329, 151336, 151338]
)
print("Warming up complete. Starting stress test...")
for i in range(n):
random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(
0).to(device)
attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
test_inputs = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids
}
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=36000,
skip_prompt=True,
skip_special_tokens=True
)
generate_kwargs = {
"input_ids": test_inputs['input_ids'],
"attention_mask": test_inputs['attention_mask'],
"max_new_tokens": 512,
"do_sample": False,
"repetition_penalty": 1.0,
"eos_token_id": [151329, 151336, 151338],
"streamer": streamer
}
start_time = time.time()
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
first_token_time = None
all_token_times = []
for token in streamer:
current_time = time.time()
if first_token_time is None:
first_token_time = current_time
times.append(first_token_time - start_time)
all_token_times.append(current_time)
t.join()
end_time = time.time()
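# Decode speed is reported as tokens per second, measured from the first generated token to the end.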
avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
decode_times.append(avg_decode_time_per_token)
print(
f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Decode Speed: {avg_decode_time_per_token:.4f} tokens/second")
torch.cuda.empty_cache()
avg_first_token_time = sum(times) / n
avg_decode_time = sum(decode_times) / n
print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second")
return times, avg_first_token_time, decode_times, avg_decode_time
def main():
parser = argparse.ArgumentParser(description="Stress test for model inference")
parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens for each test')
parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPUs to use for inference')
args = parser.parse_args()
token_len = args.token_len
n = args.n
num_gpu = args.num_gpu
stress_test(token_len, n, num_gpu)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,165 @@
"""
This script creates an interactive web demo for the GLM-4-9B model using Gradio,
a Python library for building quick and easy UI components for machine learning models.
It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
allowing users to interact with the model through a chat-like interface.
"""
import os
import gradio as gr
import torch
from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
)
return model, tokenizer
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = model.config.eos_token_id
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
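# Convert the model's markdown-ish output into HTML-safe text for the Gradio Chatbot:
# fenced code blocks become <pre><code> sections, and special characters outside code are escaped.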
def parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def predict(history, max_length, top_p, temperature):
stop = StopOnTokens()
messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
model_inputs = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"stopping_criteria": StoppingCriteriaList([stop]),
"repetition_penalty": 1.2,
"eos_token_id": model.config.eos_token_id,
}
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
for new_token in streamer:
if new_token:
history[-1][1] += new_token
yield history
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
def user(query, history):
return "", history + [[parse_text(query), ""]]
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot, max_length, top_p, temperature], chatbot
)
emptyBtn.click(lambda: None, None, chatbot, queue=False)
demo.queue()
demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)

108
basic_demo/vllm_cli_demo.py Normal file
View File

@ -0,0 +1,108 @@
"""
This script creates a CLI demo with a vLLM backend for the glm-4-9b model,
allowing users to interact with the model through a command-line interface.
Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.
Note: The script includes a modification to handle markdown to plain text conversion,
ensuring that the CLI interface displays formatted text correctly.
"""
import time
import asyncio
from transformers import AutoTokenizer
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from typing import List, Dict
MODEL_PATH = 'THUDM/glm-4-9b'
def load_model_and_tokenizer(model_dir: str):
engine_args = AsyncEngineArgs(
model=model_dir,
tokenizer=model_dir,
tensor_parallel_size=1,
dtype="bfloat16",
trust_remote_code=True,
gpu_memory_utilization=0.3,
enforce_eager=True,
worker_use_ray=True,
engine_use_ray=False,
disable_log_requests=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_dir,
trust_remote_code=True,
encode_special_tokens=True
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine, tokenizer
engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False
)
params_dict = {
"n": 1,
"best_of": 1,
"presence_penalty": 1.0,
"frequency_penalty": 0.0,
"temperature": temperature,
"top_p": top_p,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop_token_ids": [151329, 151336, 151338],
"ignore_eos": False,
"max_tokens": max_dec_len,
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
}
sampling_params = SamplingParams(**params_dict)
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
yield output.outputs[0].text
async def chat():
history = []
max_length = 8192
top_p = 0.8
temperature = 0.6
print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
while True:
user_input = input("\nYou: ")
if user_input.lower() in ["exit", "quit"]:
break
history.append([user_input, ""])
messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
print("\nGLM-4: ", end="")
current_length = 0
output = ""
async for output in vllm_gen(messages, top_p, temperature, max_length):
print(output[current_length:], end="", flush=True)
current_length = len(output)
history[-1][1] = output
if __name__ == "__main__":
asyncio.run(chat())

181
composite_demo/.gitignore vendored Normal file
View File

@ -0,0 +1,181 @@
*venv
*.DS_Store
*model
*.idea/
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python

167
composite_demo/README.md Normal file
View File

@ -0,0 +1,167 @@
# GLM-4-9B Web Demo
Read this in [English](README_en.md)
![Demo webpage](assets/demo.png)
## 安装
我们建议通过 [Conda](https://docs.conda.io/en/latest/) 进行环境管理。
执行以下命令新建一个 conda 环境并安装所需依赖:
```bash
conda create -n glm-4-demo python=3.12
conda activate glm-4-demo
pip install -r requirements.txt
```
请注意,本项目需要 Python 3.10 或更高版本。
此外,使用 Code Interpreter 还需要安装 Jupyter 内核:
```bash
ipython kernel install --name glm-4-demo --user
```
您可以修改 `~/.local/share/jupyter/kernels/glm-4-demo/kernel.json` 来改变 Jupyter 内核的配置,包括内核的启动参数等。例如,若您希望在使用 All Tools 的 Python 代码执行能力时使用 Matplotlib 画图,可以在 `argv` 数组中添加 `"--matplotlib=inline"`
若要使用浏览器和搜索功能,还需要启动浏览器后端。首先,根据 [Node.js](https://nodejs.org/en/download/package-manager)
官网的指示安装 Node.js，然后安装包管理器 [PNPM](https://pnpm.io)，之后安装浏览器服务的依赖:
```bash
cd browser
npm install -g pnpm
pnpm install
```
## 运行
1. 修改 `browser/src/config.ts` 中的 `BING_SEARCH_API_KEY` 配置浏览器服务需要使用的 Bing 搜索 API Key
```diff
--- a/browser/src/config.ts
+++ b/browser/src/config.ts
@@ -3,7 +3,7 @@ export default {
BROWSER_TIMEOUT: 10000,
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0',
- BING_SEARCH_API_KEY: '',
+ BING_SEARCH_API_KEY: '<PUT_YOUR_BING_SEARCH_KEY_HERE>',
HOST: 'localhost',
PORT: 3000,
```
2. 文生图功能需要调用 CogView API。修改 `src/tools/config.py`
,提供文生图功能需要使用的 [智谱 AI 开放平台](https://open.bigmodel.cn) API Key
```diff
--- a/src/tools/config.py
+++ b/src/tools/config.py
@@ -2,5 +2,5 @@ BROWSER_SERVER_URL = 'http://localhost:3000'
IPYKERNEL = 'glm-4-demo'
-ZHIPU_AI_KEY = ''
+ZHIPU_AI_KEY = '<PUT_YOUR_ZHIPU_AI_KEY_HERE>'
COGVIEW_MODEL = 'cogview-3'
```
3. 启动浏览器后端,在单独的 shell 中:
```bash
cd browser
pnpm start
```
4. 运行以下命令在本地加载模型并启动 demo
```bash
streamlit run main.py
```
之后即可从命令行中看到 demo 的地址,点击即可访问。初次访问需要下载并加载模型,可能需要花费一定时间。
如果已经在本地下载了模型,可以通过 `export *_MODEL_PATH=/path/to/model` 来指定从本地加载模型。可以指定的模型包括:
- `CHAT_MODEL_PATH`: 用于 All Tools 模式与文档解读模式,默认为 `THUDM/glm-4-9b-chat`
- `VLM_MODEL_PATH`: 用于 VLM 模式,默认为 `THUDM/glm-4v-9b`
Chat 模型支持使用 [vLLM](https://github.com/vllm-project/vllm) 推理。若要使用,请安装 vLLM 并设置环境变量 `USE_VLLM=1`
如果需要自定义 Jupyter 内核,可以通过 `export IPYKERNEL=<kernel_name>` 来指定。
## 使用
GLM-4 Demo 拥有三种模式:
- All Tools: 具有完整工具调用能力的对话模式,原生支持网页浏览、代码执行、图片生成,并支持自定义工具。
- 文档解读: 支持上传文档进行文档解读与对话。
- 多模态: 支持上传图像进行图像理解与对话。
### All Tools
本模式兼容 ChatGLM3-6B 的工具注册流程。
+ 代码能力、绘图能力、联网能力已经自动集成，用户只需按照要求配置对应的 Key。
+ 本模式下不支持系统提示词,模型会自动构建提示词。
对话模式下,用户可以直接在侧边栏修改 top_p, temperature 等参数来调整模型的行为。
与模型对话时,模型将会自主决定进行工具调用。
![Tool calling](assets/tool.png)
由于原始结果可能较长,默认情况下工具调用结果被隐藏,可以通过展开折叠框查看原始的工具调用结果。
模型拥有进行网页搜索和 Python 代码执行的能力。同时,模型也可以连续调用多个工具。例如:
![Consecutive tool calling, 1](assets/web_plot_1.png)
此时模型通过调用浏览器工具进行搜索获取到了需要的数据,之后将会调用 Python 工具执行代码,利用 Matplotlib 绘图:
![Consecutive tool calling, 2](assets/web_plot_2.png)
如果提供了智谱开放平台 API Key，模型也可以调用 CogView 进行图像生成:
![Image generation](assets/cogview.png)
#### 自定义工具
可以通过在 `tool_registry.py` 中注册新的工具来增强模型的能力。只需要使用 `@register_tool`
装饰函数即可完成注册。对于工具声明,函数名称即为工具的名称,函数 docstring
即为工具的说明;对于工具的参数,使用 `Annotated[typ: type, description: str, required: bool]` 标注参数的类型、描述和是否必须。
例如,`get_weather` 工具的注册如下:
```python
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the weather for `city_name` in the following week
"""
...
```
![The model uses tool to query the weather of Bangkok.](assets/weather.png)
### 文档解读
用户可以上传文档,使用 GLM-4-9B 的长文本能力对文本进行理解。可以解析 pptx、docx、pdf 等文件。
+ 本模式下不支持工具调用和系统提示词。
+ 如果文本很长,可能导致模型需要的显存较高,请确认你的硬件配置。
![Doc reader demo](assets/doc_reader.png)
### 多模态
多模态模式下,用户可以利用 GLM-4V 的多模态理解能力,上传图像并与 GLM-4V 进行多轮对话:
用户可以上传图片,使用 GLM-4V-9B 的图像理解能力对图片进行理解。
+ 本模式必须使用 glm-4v-9b 模型。
+ 本模式下不支持工具调用和系统提示词。
+ 模型仅能对一张图片进行理解和联系对话,如需更换图片,需要开启一个新的对话。
+ 图像支持的分辨率为 1120 x 1120
![VLM demo](assets/vlm.png)

155
composite_demo/README_en.md Normal file
View File

@ -0,0 +1,155 @@
# GLM-4-9B Web Demo
![Demo webpage](assets/demo.png)
## Installation
We recommend using [Conda](https://docs.conda.io/en/latest/) for environment management.
Execute the following commands to create a conda environment and install the required dependencies:
```bash
conda create -n glm-4-demo python=3.12
conda activate glm-4-demo
pip install -r requirements.txt
```
Please note that this project requires Python 3.10 or higher.
In addition, you need to install the Jupyter kernel to use the Code Interpreter:
```bash
ipython kernel install --name glm-4-demo --user
```
You can modify `~/.local/share/jupyter/kernels/glm-4-demo/kernel.json` to change the configuration of the Jupyter
kernel, including the kernel startup parameters. For example, if you want to use Matplotlib to draw when using the
Python code execution capability of All Tools, you can add `"--matplotlib=inline"` to the `argv` array.
To use the browser and search functions, you also need to start the browser backend. First, install Node.js according to
the instructions on the [Node.js](https://nodejs.org/en/download/package-manager)
official website, then install the package manager [PNPM](https://pnpm.io) and then install the browser service
dependencies:
```bash
cd browser
npm install -g pnpm
pnpm install
```
## Run
1. Modify `BING_SEARCH_API_KEY` in `browser/src/config.ts` to configure the Bing Search API Key that the browser service
needs to use:
```diff
--- a/browser/src/config.ts
+++ b/browser/src/config.ts
@@ -3,7 +3,7 @@ export default {
BROWSER_TIMEOUT: 10000,
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0',
- BING_SEARCH_API_KEY: '',
+ BING_SEARCH_API_KEY: '<PUT_YOUR_BING_SEARCH_KEY_HERE>',
HOST: 'localhost',
PORT: 3000,
```
2. The text-to-image function needs to call the CogView API. Modify `src/tools/config.py`
and provide the [Zhipu AI Open Platform](https://open.bigmodel.cn) API Key required for the text-to-image function:
```diff
--- a/src/tools/config.py
+++ b/src/tools/config.py
@@ -2,5 +2,5 @@ BROWSER_SERVER_URL = 'http://localhost:3000'
IPYKERNEL = 'glm-4-demo'
-ZHIPU_AI_KEY = ''
+ZHIPU_AI_KEY = '<PUT_YOUR_ZHIPU_AI_KEY_HERE>'
COGVIEW_MODEL = 'cogview-3'
```
3. Start the browser backend in a separate shell:
```bash
cd browser
pnpm start
```
4. Run the following commands to load the model locally and start the demo:
```bash
streamlit run main.py
```
Then you can see the demo address from the command line and click it to access it. The first access requires downloading
and loading the model, which may take some time.
If you have already downloaded the model locally, you can load it from a local path
by setting `export *_MODEL_PATH=/path/to/model` (see the example after this list). The models that can be specified include:
- `CHAT_MODEL_PATH`: used for All Tools mode and document interpretation mode, the default is `THUDM/glm-4-9b-chat`.
- `VLM_MODEL_PATH`: used for VLM mode, the default is `THUDM/glm-4v-9b`.
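For example, assuming both models have already been downloaded under `/data/models` (the paths below are only illustrative):
```bash
export CHAT_MODEL_PATH=/data/models/glm-4-9b-chat
export VLM_MODEL_PATH=/data/models/glm-4v-9b
streamlit run main.py
```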
The chat model supports inference with [vLLM](https://github.com/vllm-project/vllm). To use it, please install vLLM and
set the environment variable `USE_VLLM=1`.
If you need to customize the Jupyter kernel, you can specify it by `export IPYKERNEL=<kernel_name>`.
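As a minimal sketch, both settings can be exported before starting the demo (the kernel name should match the one installed during setup):
```bash
export USE_VLLM=1
export IPYKERNEL=glm-4-demo
```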
## Usage
GLM-4 Demo has three modes:
- All Tools mode
- Text interpretation mode
- Image understanding (VLM) mode
### All Tools mode
You can enhance the model's capabilities by registering new tools in `tool_registry.py`. Just use `@register_tool`
decorated function to complete the registration. For tool declarations, the function name is the name of the tool, and
the function docstring
is the description of the tool; for tool parameters, use `Annotated[typ: type, description: str, required: bool]` to
annotate the parameter type, description, and whether it is required.
For example, the registration of the `get_weather` tool is as follows:
```python
@register_tool
def get_weather(
city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
"""
Get the weather for `city_name` in the following week
"""
...
```
This mode is compatible with the tool registration process of ChatGLM3-6B.
+ Code execution, drawing, and web-access capabilities are integrated automatically; users only need to
configure the corresponding API keys as required.
+ System prompts are not supported in this mode; the model builds its prompts automatically.
### Text interpretation mode
Users can upload documents and use the long text capability of GLM-4-9B to understand the text. It can parse pptx, docx,
pdf and other files.
+ Tool calls and system prompts are not supported in this mode.
+ If the text is very long, the model may require a large amount of GPU memory. Please confirm your hardware
configuration.
### Image understanding mode
Users can upload images and use the image understanding capabilities of GLM-4V-9B to understand the images.
+ This mode must use the glm-4v-9b model.
+ Tool calls and system prompts are not supported in this mode.
+ The model can only understand and communicate with one image. If you need to change the image, you need to open a new
conversation.
+ The supported image resolution is 1120 x 1120

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 615 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 603 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 684 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1013 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 936 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 851 KiB

144
composite_demo/browser/.gitignore vendored Normal file
View File

@ -0,0 +1,144 @@
# Created by https://www.toptal.com/developers/gitignore/api/node
# Edit at https://www.toptal.com/developers/gitignore?templates=node
### Node ###
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Snowpack dependency directory (https://snowpack.dev/)
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variable files
.env
.env.development.local
.env.test.local
.env.production.local
.env.local
# parcel-bundler cache (https://parceljs.org/)
.cache
.parcel-cache
# Next.js build output
.next
out
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# vuepress v2.x temp and cache directory
.temp
# Docusaurus cache and generated files
.docusaurus
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# yarn v2
.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
### Node Patch ###
# Serverless Webpack directories
.webpack/
# Optional stylelint cache
# SvelteKit build / generate output
.svelte-kit
# End of https://www.toptal.com/developers/gitignore/api/node

3575
composite_demo/browser/package-lock.json generated Normal file

File diff suppressed because it is too large

View File

@ -0,0 +1,26 @@
{
"name": "glm4-browser",
"version": "1.0.0",
"description": "Browser system for GLM-4",
"main": "src/server.ts",
"scripts": {
"dev": "npx nodemon src/server",
"start": "npx ts-node src/server.ts"
},
"license": "MIT",
"dependencies": {
"express": "^4.18.3",
"jsdom": "^24.0.0",
"pnpm": "^9.1.2",
"turndown": "^7.1.2",
"winston": "^3.11.0"
},
"devDependencies": {
"@types/express": "^4.17.21",
"@types/jsdom": "^21.1.6",
"@types/node": "^20.11.20",
"@types/turndown": "^5.0.4",
"nodemon": "^3.1.0",
"ts-node": "^10.9.2"
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,745 @@
import { JSDOM } from 'jsdom';
import TurndownService from 'turndown';
import config from './config';
import { Message, ToolObservation } from './types';
import { logger, withTimeout } from './utils';
// represent a quote from a display
interface Quote {
text: string;
metadata: Metadata[];
}
interface ActionResult {
contentType: string;
metadataList?: TetherQuoteMetadata[];
metadata?: any;
roleMetadata: string;
message: string;
}
// represent a piece of metadata to be marked in the final answer
interface Metadata {
type: string;
title: string;
url: string;
lines: string[];
}
interface TetherQuoteExtra {
cited_message_idx: number;
evidence_text: string;
}
interface TetherQuoteMetadata {
type: string;
title: string;
url: string;
text: string;
pub_date?: string;
extra?: TetherQuoteExtra;
}
interface Citation {
citation_format_type: string;
start_ix: number;
end_ix: number;
metadata?: TetherQuoteMetadata;
invalid_reason?: string;
}
interface PageState {
aCounter: number;
imgCounter: number;
url: URL;
url_string: string;
hostname: string;
links: string[];
links_meta: TetherQuoteMetadata[];
lines: string[];
line_source: Record<string, Metadata>; // string representation of number interval
title?: string;
}
interface BrowserState {
pageStack: PageState[];
quoteCounter: number;
quotes: Record<string, Quote>;
}
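// Heuristically strip navigation-heavy containers: any element whose text is dominated by
// link text (above ratioThreshold) is removed before the page is converted to markdown.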
function removeDenseLinks(document: Document, ratioThreshold: number = 0.5) {
// Remove nav elements
const navs = document.querySelectorAll('nav');
navs.forEach(nav => {
if (nav.parentNode) {
nav.parentNode.removeChild(nav);
}
});
// Query for lists, divs, spans, tables, and paragraphs
const elements = document.querySelectorAll('ul, ol, div, span, nav, table, p');
elements.forEach(element => {
if (element === null) return;
const children = Array.from(element.childNodes);
const links = element.querySelectorAll('a');
if (children.length <= 1) return;
const allText = element.textContent ? element.textContent.trim().replace(/\s+/g, '') : '';
const linksText = Array.from(links)
.map(link => (link.textContent ? link.textContent.trim() : ''))
.join('')
.replace(/\s+/g, '');
if (allText.length === 0 || linksText.length === 0) return;
let ratio = linksText.length / allText.length;
if (ratio > ratioThreshold && element.parentNode) {
element.parentNode.removeChild(element);
}
});
}
abstract class BaseBrowser {
public static toolName = 'browser' as const;
public description = 'BaseBrowser';
private turndownService = new TurndownService({
headingStyle: 'atx',
});
private state: BrowserState;
private transform(dom: JSDOM): string {
let state = this.lastPageState();
state.aCounter = 0;
state.imgCounter = 0;
state.links = [];
return this.turndownService.turndown(dom.window.document);
}
private formatPage(state: PageState): string {
let formatted_lines = state.lines.join('\n');
let formatted_title = state.title ? `TITLE: ${state.title}\n\n` : '';
let formatted_range = `\nVisible: 0% - 100%`;
let formatted_message = formatted_title + formatted_lines + formatted_range;
return formatted_message;
}
private newPageState(): PageState {
return {
aCounter: 0,
imgCounter: 0,
url: new URL('about:blank'),
url_string: 'about:blank',
hostname: '',
title: '',
links: [],
links_meta: [],
lines: [],
line_source: {},
};
}
private pushPageState(): PageState {
let state = this.newPageState();
this.state.pageStack.push(state);
return state;
}
private lastPageState(): PageState {
if (this.state.pageStack.length === 0) {
throw new Error('No page state');
}
return this.state.pageStack[this.state.pageStack.length - 1];
}
private formatErrorUrl(url: string): string {
let TRUNCATION_LIMIT = 80;
if (url.length <= TRUNCATION_LIMIT) {
return url;
}
return url.slice(0, TRUNCATION_LIMIT) + `... (URL truncated at ${TRUNCATION_LIMIT} chars)`;
}
protected functions = {
search: async (query: string, recency_days: number = -1) => {
logger.debug(`Searching for: ${query}`);
const search = new URLSearchParams({ q: query });
recency_days > 0 && search.append('recency_days', recency_days.toString());
return withTimeout(
config.BROWSER_TIMEOUT,
fetch(`${config.BING_SEARCH_API_URL}/search?${search.toString()}`, {
headers: {
'Ocp-Apim-Subscription-Key': config.BING_SEARCH_API_KEY,
}
}).then(
res =>
res.json() as Promise<{
queryContext: {
originalQuery: string;
};
webPages: {
webSearchUrl: string;
totalEstimatedMatches: number;
value: {
id: string;
name: string;
url: string;
datePublished: string; // 2018-05-18T08:00:00.0000000
datePublishedDisplayText: string;
isFamilyFriendly: boolean;
displayUrl: string;
snippet: string;
dateLastCrawled: string;
cachedPageUrl: string;
language: string;
isNavigational: boolean;
}[];
};
rankingResponse: {
mainline: {
items: {
answerType: 'WebPages';
resultIndex: number;
value: {
id: string;
};
}[];
};
};
}>,
),
)
.then(async ({ value: res }) => {
try {
let state = this.pushPageState();
let metadataList: TetherQuoteMetadata[] = [];
for (const [i, entry] of res.webPages.value.entries()) {
const url = new URL(entry.url);
const hostname = url.hostname;
state.lines.push(` # 【${i}†${entry.name}】${hostname}`);
state.lines.push(entry.snippet);
const quoteMetadata: Metadata = {
type: 'webpage',
title: entry.name,
url: entry.url,
lines: state.lines.slice(2 * i, 2 * i + 2),
};
state.line_source[`${2 * i}-${2 * i + 1}`] = quoteMetadata;
state.links[i] = entry.url;
const returnMetadata: TetherQuoteMetadata = {
type: quoteMetadata.type,
title: quoteMetadata.title,
url: quoteMetadata.url,
text: state.lines[2 * i + 1], // only content, not link
pub_date: entry.datePublished,
};
metadataList.push(returnMetadata);
}
const returnContentType = 'browser_result';
return {
contentType: returnContentType,
roleMetadata: returnContentType,
message: this.formatPage(state),
metadataList,
};
} catch (err) {
throw new Error(`parse error: ${err}`);
}
})
.catch(err => {
logger.error(err.message);
if (err.code === 'ECONNABORTED') {
throw new Error(`Timeout while executing search for: ${query}`);
}
throw new Error(`Network or server error occurred`);
});
},
open_url: (url: string) => {
logger.debug(`Opening ${url}`);
return withTimeout(
config.BROWSER_TIMEOUT,
fetch(url).then(res => res.text()),
)
.then(async ({ value: res, time }) => {
try {
const state = this.pushPageState();
state.url = new URL(url);
state.url_string = url;
state.hostname = state.url.hostname;
const html = res;
const dom = new JSDOM(html);
const title = dom.window.document.title;
const markdown = this.transform(dom);
state.title = title;
// Remove first line, because it will be served as the title
const lines = markdown.split('\n');
lines.shift();
// Remove consequent empty lines
let i = 0;
while (i < lines.length - 1) {
if (lines[i].trim() === '' && lines[i + 1].trim() === '') {
lines.splice(i, 1);
} else {
i++;
}
}
let page = lines.join('\n');
// The first line feed is not a typo
let text_result = `\nURL: ${url}\n${page}`;
state.lines = text_result.split('\n');
// all lines has only one source
state.line_source = {};
state.line_source[`0-${state.lines.length - 1}`] = {
type: 'webpage',
title: title,
url: url,
lines: state.lines,
};
let message = this.formatPage(state);
const returnContentType = 'browser_result';
return {
contentType: returnContentType,
roleMetadata: returnContentType,
message,
metadataList: state.links_meta,
};
} catch (err) {
throw new Error(`parse error: ${err}`);
}
})
.catch(err => {
logger.error(err.message);
if (err.code === 'ECONNABORTED') {
throw new Error(`Timeout while loading page w/ URL: ${url}`);
}
throw new Error(`Failed to load page w/ URL: ${url}`);
});
},
mclick: (ids: number[]) => {
logger.info('Entering mclick', ids);
let promises: Promise<ActionResult>[] = [];
let state = this.lastPageState();
for (let id of ids) {
if (isNaN(id) || id >= state.links.length) {
promises.push(
Promise.reject(
new Error(
`recorded='click(${id})' temporary=None permanent=None new_state=None final=None success=False feedback='Error parsing ID ${id}' metadata={}`,
),
),
);
continue;
}
let url: string;
try {
url = new URL(state.links[id], state.url).href;
} catch (err) {
logger.error(`Failed in getting ${state.links[id]}, ${state.url}`);
promises.push(
Promise.reject(
new Error(
`recorded='click(${id})' temporary=None permanent='${err}' new_state=None final=None success=False feedback='Error parsing URL for ID ${id}' metadata={}`,
),
),
);
continue;
}
const quoteIndex = this.state.quoteCounter++; // ascending in final results
promises.push(
withTimeout(
config.BROWSER_TIMEOUT,
fetch(url).then(res => res.text()),
)
.then(({ value: res, time }) => {
let state = this.newPageState();
state.url = new URL(url);
state.hostname = state.url.hostname;
try {
const html = res;
const dom = new JSDOM(html);
const title = dom.window.document.title;
state.title = title;
removeDenseLinks(dom.window.document);
let quoteText = this.transform(dom);
// remove consecutive newline
quoteText = quoteText.replace(/[\r\n]+/g, '\n');
const quoteLines = quoteText.split('\n');
state.lines = quoteLines;
const metadata = {
type: 'webpage',
title: title,
url: url,
lines: quoteLines,
};
const quoteMetadata = {
type: 'webpage',
title: title,
url: url,
text: quoteText,
};
state.line_source = {};
state.line_source[`0-${state.lines.length - 1}`] = metadata;
this.state.quotes[quoteIndex.toString()] = {
text: quoteText,
metadata: [metadata],
};
const returnContentType = 'quote_result';
return {
contentType: returnContentType,
roleMetadata: `${returnContentType} [${quoteIndex}†source]`,
message: quoteText,
metadataList: [quoteMetadata],
metadata: {
url,
},
};
} catch (err) {
throw new Error(`parse error: ${err}`);
}
})
.catch(err => {
logger.error(err.message);
if (err.code === 'ECONNABORTED') {
throw new Error(`Timeout while loading page w/ URL: ${this.formatErrorUrl(url)}`);
}
throw new Error(`Failed to load page w/ URL: ${this.formatErrorUrl(url)}`);
})
.catch(err => {
// format error message
const returnContentType = 'system_error';
throw {
contentType: returnContentType,
roleMetadata: returnContentType,
message: `recorded='click(${id})' temporary=None permanent='${
err.message
}' new_state=None final=None success=False feedback='Error fetching url ${this.formatErrorUrl(
url,
)}' metadata={}`,
metadata: {
failedURL: url,
},
} as ActionResult;
}),
);
}
return Promise.allSettled(promises).then(async results => {
const actionResults = results.map(r => {
if (r.status === 'fulfilled') {
return r.value;
} else {
logger.error(r.reason);
return r.reason as ActionResult;
}
});
if (results.filter(r => r.status === 'fulfilled').length === 0) {
// collect errors
const err_text = (results as PromiseRejectedResult[])
.map(r => (r.reason as ActionResult).message)
.join('\n');
throw new Error(err_text);
} else {
return actionResults;
}
});
},
};
constructor() {
this.state = {
pageStack: [],
quotes: {},
quoteCounter: 7,
};
this.turndownService.remove('script');
this.turndownService.remove('style');
// Add rules for turndown
this.turndownService.addRule('reference', {
filter: function (node, options: any): boolean {
return (
options.linkStyle === 'inlined' &&
node.nodeName === 'A' &&
node.getAttribute('href') !== undefined
);
},
replacement: (content, node, options): string => {
let state = this.state.pageStack[this.state.pageStack.length - 1];
if (!content || !('getAttribute' in node)) return '';
let href = undefined;
try {
if ('getAttribute' in node) {
const hostname = new URL(node.getAttribute('href')!).hostname;
// Do not append hostname when in the same domain
if (hostname === state.hostname || !hostname) {
href = '';
} else {
href = '†' + hostname;
}
}
} catch (e) {
// To prevent displaying links like '/foo/bar'
href = '';
}
if (href === undefined) return '';
const url = node.getAttribute('href')!;
let linkId = state.links.findIndex(link => link === url);
if (linkId === -1) {
linkId = state.aCounter++;
// logger.debug(`New link[${linkId}]: ${url}`);
state.links_meta.push({
type: 'webpage',
title: node.textContent!,
url: href,
text: node.textContent!,
});
state.links.push(url);
}
return `【${linkId}†${node.textContent}${href}】`;
},
});
this.turndownService.addRule('img', {
filter: 'img',
replacement: (content, node, options): string => {
let state = this.state.pageStack[this.state.pageStack.length - 1];
return `[Image ${state.imgCounter++}]`;
},
});
// Just to change indentation, wondering why this isn't exposed as an option
this.turndownService.addRule('list', {
filter: 'li',
replacement: function (content, node, options) {
content = content
.replace(/^\n+/, '') // remove leading newlines
.replace(/\n+$/, '\n') // replace trailing newlines with just a single one
.replace(/\n/gm, '\n '); // indent
let prefix = options.bulletListMarker + ' ';
const parent = node.parentNode! as Element;
if (parent.nodeName === 'OL') {
const start = parent.getAttribute('start');
const index = Array.prototype.indexOf.call(parent.children, node);
prefix = (start ? Number(start) + index : index + 1) + '. ';
}
return ' ' + prefix + content + (node.nextSibling && !/\n$/.test(content) ? '\n' : '');
},
});
// Remove bold; remove() doesn't work on this, I don't know why
this.turndownService.addRule('emph', {
filter: ['strong', 'b'],
replacement: function (content, node, options) {
if (!content.trim()) return '';
return content;
},
});
}
abstract actionLine(content: string): Promise<ActionResult | ActionResult[]>;
async action(content: string): Promise<ToolObservation[]> {
const lines = content.split('\n');
let results: ActionResult[] = [];
for (const line of lines) {
logger.info(`Action line: ${line}`)
try {
const lineActionResult = await this.actionLine(line);
logger.debug(`Action line result: ${JSON.stringify(lineActionResult, null, 2)}`);
if (Array.isArray(lineActionResult)) {
results = results.concat(lineActionResult);
} else {
results.push(lineActionResult);
}
} catch (err) {
const returnContentType = 'system_error';
results.push({
contentType: returnContentType,
roleMetadata: returnContentType,
message: `Error when executing command ${line}\n${err}`,
metadata: {
failedCommand: line,
},
});
}
}
const observations: ToolObservation[] = [];
for (const result of results) {
const observation: ToolObservation = {
contentType: result.contentType,
result: result.message,
roleMetadata: result.roleMetadata,
metadata: result.metadata ?? {},
};
if (result.metadataList) {
observation.metadata.metadata_list = result.metadataList;
}
observations.push(observation);
}
return observations;
}
postProcess(message: Message, metadata: any) {
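// Scan the model's answer for citation markers of the form 【<quote-id>†<evidence>】 and attach
// the corresponding quote metadata (or an invalid_reason) to metadata.citations.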
const quotePattern = /【(.+?)†(.*?)】/g;
const content = message.content;
let match;
let citations: Citation[] = [];
const citation_format_type = 'tether_og';
while ((match = quotePattern.exec(content))) {
logger.debug(`Citation match: ${match[0]}`);
const start_ix = match.index;
const end_ix = match.index + match[0].length;
let invalid_reason = undefined;
let metadata: TetherQuoteMetadata;
try {
let cited_message_idx = parseInt(match[1]);
let evidence_text = match[2];
let quote = this.state.quotes[cited_message_idx.toString()];
if (quote === undefined) {
invalid_reason = `'Referenced message ${cited_message_idx} in citation 【${cited_message_idx}†${evidence_text}】 is not a quote or tether browsing display.'`;
logger.error(`Triggered citation error with quote undefined: ${invalid_reason}`);
citations.push({
citation_format_type,
start_ix,
end_ix,
invalid_reason,
});
} else {
let extra: TetherQuoteExtra = {
cited_message_idx,
evidence_text,
};
const quote_metadata = quote.metadata[0];
metadata = {
type: 'webpage',
title: quote_metadata.title,
url: quote_metadata.url,
text: quote_metadata.lines.join('\n'),
extra,
};
citations.push({
citation_format_type,
start_ix,
end_ix,
metadata,
});
}
} catch (err) {
logger.error(`Triggered citation error: ${err}`);
invalid_reason = `Citation Error: ${err}`;
citations.push({
start_ix,
end_ix,
citation_format_type,
invalid_reason,
});
}
}
metadata.citations = citations;
}
getState() {
return this.state;
}
}
export class SimpleBrowser extends BaseBrowser {
public description = 'SimpleBrowser';
constructor() {
super();
}
async actionLine(content: string): Promise<ActionResult | ActionResult[]> {
const regex = /(\w+)\(([^)]*)\)/;
const matches = content.match(regex);
if (matches) {
const functionName = matches[1];
let args_string = matches[2];
if (functionName === 'mclick') {
args_string = args_string.trim().slice(1, -1); // remove '[' and ']'
}
const args = args_string.split(',').map(arg => arg.trim());
let result;
switch (functionName) {
case 'search':
logger.debug(`SimpleBrowser action search ${args[0].slice(1, -1)}`);
const recency_days = /(^|\D)(\d+)($|\D)/.exec(args[1])?.[2] as undefined | `${number}`;
result = await this.functions.search(
args[0].slice(1, -1), // slice quote "query"
recency_days && Number(recency_days),
);
break;
case 'open_url':
logger.debug(`SimpleBrowser action open_url ${args[0].slice(1, -1)}`);
result = await this.functions.open_url(args[0].slice(1, -1));
break;
case 'mclick':
logger.debug(`SimpleBrowser action mclick ${args}`);
result = await this.functions.mclick(args.map(x => parseInt(x)));
break;
default:
throw new Error(`Parse Error: ${content}`);
}
return result;
} else {
throw new Error('Parse Error');
}
}
}
if (require.main === module) {
(async () => {
let browser = new SimpleBrowser();
let demo = async (action: string) => {
logger.info(` ------ Begin of Action: ${action} ------`);
let results = await browser.action(action);
for (const [idx, result] of results.entries()) {
logger.info(`[Result ${idx}] contentType: ${result.contentType}`);
logger.info(`[Result ${idx}] roleMetadata: ${result.roleMetadata}`);
logger.info(`[Result ${idx}] result: ${result.result}`);
logger.info(`[Result ${idx}] metadata: ${JSON.stringify(result.metadata, null, 2)}`);
}
logger.info(` ------ End of Action: ${action} ------\n\n`);
};
await demo("search('Apple Latest News')");
await demo('mclick([0, 1, 5, 6])');
await demo('mclick([1, 999999])');
await demo("open_url('https://chatglm.cn')");
await demo("search('zhipu latest News')");
await demo('mclick([0, 1, 5, 6])');
})();
}

View File

@ -0,0 +1,10 @@
export default {
LOG_LEVEL: 'debug',
BROWSER_TIMEOUT: 10000,
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/',
BING_SEARCH_API_KEY: '',
HOST: 'localhost',
PORT: 3000,
};

View File

@ -0,0 +1,55 @@
import express, { Express, Request, Response } from 'express';
import { SimpleBrowser } from './browser';
import config from './config';
import { logger } from './utils';
const session_history: Record<string, SimpleBrowser> = {};
const app: Express = express();
app.use(express.json());
app.post('/', async (req: Request, res: Response) => {
const {
session_id,
action,
}: {
session_id: string;
action: string;
} = req.body;
logger.info(`session_id: ${session_id}`);
logger.info(`action: ${action}`);
if (!session_history[session_id]) {
session_history[session_id] = new SimpleBrowser();
}
const browser = session_history[session_id];
try {
res.json(await browser.action(action));
} catch (err) {
logger.error(err);
res.status(400).json(err);
}
})
process.on('SIGINT', () => {
process.exit(0);
});
process.on('uncaughtException', e => {
logger.error(e);
});
const { HOST, PORT } = config;
(async () => {
app.listen(PORT, HOST, () => {
logger.info(`⚡️[server]: Server is running at http://${HOST}:${PORT}`);
try {
(<any>process).send('ready');
} catch (err) {}
});
})();
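// Example request (a sketch): with the server running, a client drives the browser over HTTP.
// The route and body fields match the handler above; host and port come from config.ts.
//
//   curl -X POST http://localhost:3000/ \
//        -H "Content-Type: application/json" \
//        -d "{\"session_id\": \"demo\", \"action\": \"search('GLM-4 latest news')\"}"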

View File

@ -0,0 +1,25 @@
export interface File {
id: string;
name: string;
size: number;
}
export interface Metadata {
files?: File[];
reference?: string;
}
export interface Message {
role: 'user' | 'assistant' | 'system' | 'observation';
metadata: string;
content: string;
request_metadata?: Metadata;
}
export interface ToolObservation {
contentType: string;
result: string;
text?: string;
roleMetadata?: string; // metadata for <|observation|>${metadata}
metadata: any; // metadata for response
}

View File

@ -0,0 +1,56 @@
import winston from 'winston';
import config from './config';
export class TimeoutError extends Error {}
const logLevel = config.LOG_LEVEL;
export const logger = winston.createLogger({
level: logLevel,
format: winston.format.combine(
winston.format.colorize(),
winston.format.printf(info => {
return `${info.level}: ${info.message}`;
}),
),
transports: [new winston.transports.Console()],
});
console.log('LOG_LEVEL', logLevel);
export const parseHrtimeToMillisecond = (hrtime: [number, number]): number => {
return (hrtime[0] + hrtime[1] / 1e9) * 1000;
};
export const promiseWithTime = <T>(
promise: Promise<T>
): Promise<{
value: T;
time: number;
}> => {
return new Promise((resolve, reject) => {
const startTime = process.hrtime();
promise
.then(value => {
resolve({
value: value,
time: parseHrtimeToMillisecond(process.hrtime(startTime))
});
})
.catch(err => reject(err));
});
};
export const withTimeout = <T>(
millis: number,
promise: Promise<T>
): Promise<{
value: T;
time: number;
}> => {
const timeout = new Promise<{ value: T; time: number }>((_, reject) =>
setTimeout(() => reject(new TimeoutError()), millis)
);
return Promise.race([promiseWithTime(promise), timeout]);
};

View File

@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "es2022",
"lib": ["es2022", "dom"],
"module": "commonjs",
"rootDir": "./",
"outDir": "./dist",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": true,
},
"ts-node": {
"transpileOnly": true
}
}

View File

@ -0,0 +1,22 @@
accelerate
huggingface_hub>=0.19.4
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0
langchain
langchain-community
matplotlib
pillow>=10.1.0
pymupdf
python-docx
python-pptx
pyyaml>=6.0.1
requests>=2.31.0
sentencepiece
streamlit>=1.35.0
tiktoken
transformers==4.40.0
zhipuai>=2.1.0
# Please install vllm if you'd like to use the long-context model.
# vllm

View File

@ -0,0 +1,98 @@
"""
This is the client part of composite_demo.
Two clients are provided for interacting with the model: HFClient uses the transformers
backend, while VLLMClient uses the vLLM backend.
"""
import json
from collections.abc import Generator
from copy import deepcopy
from enum import Enum, auto
from typing import Protocol
import streamlit as st
from conversation import Conversation, build_system_prompt
from tools.tool_registry import ALL_TOOLS
class ClientType(Enum):
HF = auto()
VLLM = auto()
class Client(Protocol):
def __init__(self, model_path: str): ...
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]: ...
def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
chat_history = []
if len(tools) > 0:
chat_history.append(
{"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
)
for conversation in history:
role = str(conversation.role).removeprefix("<|").removesuffix("|>")
item = {
"role": role,
"content": conversation.content,
}
if conversation.metadata:
item["metadata"] = conversation.metadata
# Only append image for user
if role == "user" and conversation.image:
item["image"] = conversation.image
chat_history.append(item)
return chat_history
def process_response(output, history):
content = ""
history = deepcopy(history)
for response in output.split("<|assistant|>"):
if "\n" in response:
metadata, content = response.split("\n", maxsplit=1)
else:
metadata, content = "", response
if not metadata.strip():
content = content.strip()
history.append({"role": "assistant", "metadata": metadata, "content": content})
content = content.replace("[[训练时间]]", "2023年")
else:
history.append({"role": "assistant", "metadata": metadata, "content": content})
if history[0]["role"] == "system" and "tools" in history[0]:
parameters = json.loads(content)
content = {"name": metadata.strip(), "parameters": parameters}
else:
content = {"name": metadata.strip(), "content": content}
return content, history
# glm-4v-9b is not available in the vLLM backend; use HFClient for it instead.
@st.cache_resource(max_entries=1, show_spinner="Loading model...")
def get_client(model_path, typ: ClientType) -> Client:
match typ:
case ClientType.HF:
from clients.hf import HFClient
return HFClient(model_path)
case ClientType.VLLM:
try:
from clients.vllm import VLLMClient
except ImportError as e:
e.msg += "; did you forget to install vLLM?"
raise
return VLLMClient(model_path)
raise NotImplementedError(f"Client type {typ} is not supported.")
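# --- Usage sketch (not part of the demo flow) ---------------------------------
# A minimal illustration of how the pieces above fit together, assuming a local
# GLM-4-9B-Chat checkpoint and that this runs inside the Streamlit demo
# (get_client is cached via st.cache_resource). `history` stands for the
# list[Conversation] maintained by the demo UI; only the call shape is shown:
#
#   client = get_client("THUDM/glm-4-9b-chat", ClientType.HF)
#   for response, chat_history in client.generate_stream(
#       tools=[],            # no function-calling tools enabled
#       history=history,     # list[Conversation] built by the demo UI
#       max_new_tokens=256,
#   ):
#       ...                  # `response` is the latest parsed text or tool-call dict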

View File

@ -0,0 +1,59 @@
"""
HuggingFace client.
"""
import threading
from collections.abc import Generator
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from client import Client, process_input, process_response
from conversation import Conversation
class HFClient(Client):
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True,
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="cuda",
).eval()
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, tools)
model_inputs = self.tokenizer.apply_chat_template(
chat_history,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
).to(self.model.device)
streamer = TextIteratorStreamer(
tokenizer=self.tokenizer,
timeout=5,
skip_prompt=True,
)
generate_kwargs = {
**model_inputs,
"streamer": streamer,
"eos_token_id": [151329, 151336, 151338],
"do_sample": True,
}
generate_kwargs.update(parameters)
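        # Generation runs in a background thread while TextIteratorStreamer yields
        # decoded text pieces as they are produced; each loop iteration below
        # re-parses the cumulative output so the caller always receives the latest
        # (response, chat_history) pair from process_response.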
t = Thread(target=self.model.generate, kwargs=generate_kwargs)
t.start()
total_text = ""
for token_text in streamer:
total_text += token_text
yield process_response(total_text, chat_history)

View File

@ -0,0 +1,64 @@
"""
vLLM client.
Please install [vLLM](https://github.com/vllm-project/vllm) according to its
installation guide before running this client.
"""
import time
from collections.abc import Generator
from transformers import AutoTokenizer
from vllm import SamplingParams, LLMEngine, EngineArgs
from client import Client, process_input, process_response
from conversation import Conversation
class VLLMClient(Client):
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)
self.engine_args = EngineArgs(
model=model_path,
tensor_parallel_size=1,
dtype="bfloat16", # torch.bfloat16 is needed.
trust_remote_code=True,
gpu_memory_utilization=0.6,
enforce_eager=True,
worker_use_ray=False,
)
self.engine = LLMEngine.from_engine_args(self.engine_args)
def generate_stream(
self, tools: list[dict], history: list[Conversation], **parameters
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, tools)
model_inputs = self.tokenizer.apply_chat_template(
chat_history, add_generation_prompt=True, tokenize=False
)
parameters["max_tokens"] = parameters.pop("max_new_tokens")
params_dict = {
"n": 1,
"best_of": 1,
"top_p": 1,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop_token_ids": [151329, 151336, 151338],
"ignore_eos": False,
"logprobs": None,
"prompt_logprobs": None,
}
params_dict.update(parameters)
sampling_params = SamplingParams(**params_dict)
self.engine.add_request(
request_id=str(time.time()), inputs=model_inputs, params=sampling_params
)
while self.engine.has_unfinished_requests():
request_outputs = self.engine.step()
for request_output in request_outputs:
yield process_response(request_output.outputs[0].text, chat_history)

View File

@ -0,0 +1,165 @@
import json
import re
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from PIL.Image import Image
from tools.browser import Quote, quotes
QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")
SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
DATE_PROMPT = "当前日期: %Y-%m-%d"
TOOL_SYSTEM_PROMPTS = {
"python": "当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。",
"simple_browser": "你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时须选择3-10个页面。选择多个角度的页面同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。",
"cogview": "如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。",
}
FILE_TEMPLATE = "[File Name]\n{file_name}\n[File Content]\n{file_content}"
def build_system_prompt(
enabled_tools: list[str],
functions: list[dict],
):
value = SELFCOG_PROMPT
value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
value += "\n\n# 可用工具"
contents = []
for tool in enabled_tools:
contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
for function in functions:
content = f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
contents.append(content)
value += "".join(contents)
return value
def response_to_str(response: str | dict[str, str]) -> str:
"""
Convert response to string.
"""
if isinstance(response, dict):
return response.get("name", "") + response.get("content", "")
return response
class Role(Enum):
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
OBSERVATION = auto()
def __str__(self):
match self:
case Role.SYSTEM:
return "<|system|>"
case Role.USER:
return "<|user|>"
case Role.ASSISTANT | Role.TOOL:
return "<|assistant|>"
case Role.OBSERVATION:
return "<|observation|>"
# Get the message block for the given role
def get_message(self):
# Compare by value here, because the enum object in the session state
# is not the same as the enum cases here, due to streamlit's rerunning
# behavior.
match self.value:
case Role.SYSTEM.value:
return
case Role.USER.value:
return st.chat_message(name="user", avatar="user")
case Role.ASSISTANT.value:
return st.chat_message(name="assistant", avatar="assistant")
case Role.TOOL.value:
return st.chat_message(name="tool", avatar="assistant")
case Role.OBSERVATION.value:
return st.chat_message(name="observation", avatar="assistant")
case _:
st.error(f"Unexpected role: {self}")
@dataclass
class Conversation:
role: Role
content: str | dict
# Processed content
saved_content: str | None = None
metadata: str | None = None
image: str | Image | None = None
def __str__(self) -> str:
metadata_str = self.metadata if self.metadata else ""
return f"{self.role}{metadata_str}\n{self.content}"
# Human readable format
def get_text(self) -> str:
text = self.saved_content or self.content
match self.role.value:
case Role.TOOL.value:
text = f"Calling tool `{self.metadata}`:\n\n```python\n{text}\n```"
case Role.OBSERVATION.value:
text = f"```python\n{text}\n```"
return text
# Display as a markdown block
def show(self, placeholder: DeltaGenerator | None = None) -> str:
if placeholder:
message = placeholder
else:
message = self.role.get_message()
if self.image:
message.image(self.image, width=512)
if self.role == Role.OBSERVATION:
metadata_str = f"from {self.metadata}" if self.metadata else ""
message = message.expander(f"Observation {metadata_str}")
text = self.get_text()
if self.role != Role.USER:
show_text = text
else:
splitted = text.split('files uploaded.\n')
if len(splitted) == 1:
show_text = text
else:
# Show expander for document content
doc = splitted[0]
show_text = splitted[-1]
expander = message.expander(f'File Content')
expander.markdown(doc)
message.markdown(show_text)
def postprocess_text(text: str, replace_quote: bool) -> str:
text = text.replace("\(", "$")
text = text.replace("\)", "$")
text = text.replace("\[", "$$")
text = text.replace("\]", "$$")
text = text.replace("<|assistant|>", "")
text = text.replace("<|observation|>", "")
text = text.replace("<|system|>", "")
text = text.replace("<|user|>", "")
text = text.replace("<|endoftext|>", "")
# Replace quotes
if replace_quote:
for match in QUOTE_REGEX.finditer(text):
quote_id = match.group(1)
quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
text = text.replace(
match.group(0), f" (来源:[{quote.title}]({quote.url})) "
)
return text.strip()

356
composite_demo/src/main.py Normal file
View File

@ -0,0 +1,356 @@
"""
This demo shows the All Tools and long-context chat capabilities of GLM-4.
Please follow README.md to run the demo.
"""
import os
import traceback
from enum import Enum
from io import BytesIO
from uuid import uuid4
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from PIL import Image
from client import Client, ClientType, get_client
from conversation import (
FILE_TEMPLATE,
Conversation,
Role,
postprocess_text,
response_to_str,
)
from tools.tool_registry import dispatch_tool, get_tools
from utils import extract_pdf, extract_docx, extract_pptx, extract_text
CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
class Mode(str, Enum):
ALL_TOOLS = "🛠️ All Tools"
LONG_CTX = "📝 文档解读"
VLM = "🖼️ 多模态"
def append_conversation(
conversation: Conversation,
history: list[Conversation],
placeholder: DeltaGenerator | None = None,
) -> None:
"""
Append a conversation piece into history, meanwhile show it in a new markdown block
"""
history.append(conversation)
conversation.show(placeholder)
st.set_page_config(
page_title="GLM-4 Demo",
page_icon=":robot:",
layout="centered",
initial_sidebar_state="expanded",
)
st.title("GLM-4 Demo")
st.markdown(
"<sub>智谱AI 公开在线技术文档: https://zhipu-ai.feishu.cn/wiki/RuMswanpkiRh3Ok4z5acOABBnjf </sub> \n\n <sub> 更多 GLM-4 开源模型的使用方法请参考文档。</sub>",
unsafe_allow_html=True,
)
with st.sidebar:
top_p = st.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
top_k = st.slider("top_k", 1, 20, 10, step=1, key="top_k")
temperature = st.slider("temperature", 0.0, 1.5, 0.95, step=0.01)
repetition_penalty = st.slider("repetition_penalty", 0.0, 2.0, 1.0, step=0.01)
max_new_tokens = st.slider("max_new_tokens", 1, 4096, 2048, step=1)
cols = st.columns(2)
export_btn = cols[0]
clear_history = cols[1].button("Clear", use_container_width=True)
retry = export_btn.button("Retry", use_container_width=True)
if clear_history:
page = st.session_state.page
client = st.session_state.client
st.session_state.clear()
st.session_state.page = page
st.session_state.client = client
st.session_state.files_uploaded = False
st.session_state.uploaded_texts = ""
st.session_state.uploaded_file_nums = 0
st.session_state.history = []
if "files_uploaded" not in st.session_state:
st.session_state.files_uploaded = False
if "session_id" not in st.session_state:
st.session_state.session_id = uuid4()
if "history" not in st.session_state:
st.session_state.history = []
first_round = len(st.session_state.history) == 0
def build_client(mode: Mode) -> Client:
match mode:
case Mode.ALL_TOOLS:
st.session_state.top_k = 10
typ = ClientType.VLLM if USE_VLLM else ClientType.HF
return get_client(CHAT_MODEL_PATH, typ)
case Mode.LONG_CTX:
st.session_state.top_k = 10
typ = ClientType.VLLM if USE_VLLM else ClientType.HF
return get_client(CHAT_MODEL_PATH, typ)
case Mode.VLM:
st.session_state.top_k = 1
# vLLM is not available for VLM mode
return get_client(VLM_MODEL_PATH, ClientType.HF)
# Callback function for page change
def page_changed() -> None:
global client
new_page: str = st.session_state.page
st.session_state.history.clear()
st.session_state.client = build_client(Mode(new_page))
page = st.radio(
"选择功能",
[mode.value for mode in Mode],
key="page",
horizontal=True,
index=None,
label_visibility="hidden",
on_change=page_changed,
)
HELP = """
### 🎉 欢迎使用 GLM-4!
请在上方选取一个功能每次切换功能时将会重新加载模型并清空对话历史
文档解读模式与 VLM 模式仅支持在第一轮传入文档或图像
""".strip()
if page is None:
st.markdown(HELP)
exit()
if page == Mode.LONG_CTX:
if first_round:
uploaded_files = st.file_uploader(
"上传文件",
type=["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"],
accept_multiple_files=True,
)
if uploaded_files and not st.session_state.files_uploaded:
uploaded_texts = []
for uploaded_file in uploaded_files:
file_name: str = uploaded_file.name
random_file_name = str(uuid4())
file_extension = os.path.splitext(file_name)[1]
file_path = os.path.join("/tmp", random_file_name + file_extension)
with open(file_path, "wb") as f:
f.write(uploaded_file.getbuffer())
if file_name.endswith(".pdf"):
content = extract_pdf(file_path)
elif file_name.endswith(".docx"):
content = extract_docx(file_path)
elif file_name.endswith(".pptx"):
content = extract_pptx(file_path)
else:
content = extract_text(file_path)
uploaded_texts.append(
FILE_TEMPLATE.format(file_name=file_name, file_content=content)
)
os.remove(file_path)
st.session_state.uploaded_texts = "\n\n".join(uploaded_texts)
st.session_state.uploaded_file_nums = len(uploaded_files)
else:
st.session_state.uploaded_texts = ""
st.session_state.uploaded_file_nums = 0
elif page == Mode.VLM:
if first_round:
uploaded_image = st.file_uploader(
"上传图片",
type=["png", "jpg", "jpeg", "bmp", "tiff", "webp"],
accept_multiple_files=False,
)
if uploaded_image:
data: bytes = uploaded_image.read()
image = Image.open(BytesIO(data)).convert("RGB")
st.session_state.uploaded_image = image
else:
st.session_state.uploaded_image = None
prompt_text = st.chat_input("Chat with GLM-4!", key="chat_input")
if prompt_text == "" and retry == False:
print("\n== Clean ==\n")
st.session_state.history = []
exit()
history: list[Conversation] = st.session_state.history
if retry:
print("\n== Retry ==\n")
last_user_conversation_idx = None
for idx, conversation in enumerate(history):
if conversation.role.value == Role.USER.value:
last_user_conversation_idx = idx
if last_user_conversation_idx is not None:
prompt_text = history[last_user_conversation_idx].content
print(f"New prompt: {prompt_text}, idx = {last_user_conversation_idx}")
del history[last_user_conversation_idx:]
for conversation in history:
conversation.show()
tools = get_tools() if page == Mode.ALL_TOOLS else []
client: Client = st.session_state.client
def main(prompt_text: str):
global client
assert client is not None
if prompt_text:
prompt_text = prompt_text.strip()
# Append uploaded files
uploaded_texts = st.session_state.get("uploaded_texts")
if page == Mode.LONG_CTX and uploaded_texts and first_round:
meta_msg = "{} files uploaded.\n".format(
st.session_state.uploaded_file_nums
)
prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text
# Clear after first use
st.session_state.files_uploaded = True
st.session_state.uploaded_texts = ""
st.session_state.uploaded_file_nums = 0
image = st.session_state.get("uploaded_image")
if page == Mode.VLM and image and first_round:
st.session_state.uploaded_image = None
role = Role.USER
append_conversation(Conversation(role, prompt_text, image=image), history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(
name="assistant", avatar="assistant"
)
markdown_placeholder = message_placeholder.empty()
def add_new_block():
nonlocal message_placeholder, markdown_placeholder
message_placeholder = placeholder.chat_message(
name="assistant", avatar="assistant"
)
markdown_placeholder = message_placeholder.empty()
def commit_conversation(
role: Role,
text: str,
metadata: str | None = None,
image: str | None = None,
new: bool = False,
):
processed_text = postprocess_text(text, role.value == Role.ASSISTANT.value)
conversation = Conversation(role, text, processed_text, metadata, image)
# Use different placeholder for new block
placeholder = message_placeholder if new else markdown_placeholder
append_conversation(
conversation,
history,
placeholder,
)
response = ""
for _ in range(10):
last_response = None
history_len = None
try:
for response, chat_history in client.generate_stream(
tools=tools,
history=history,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_tokens,
):
if history_len is None:
history_len = len(chat_history)
elif history_len != len(chat_history):
commit_conversation(Role.ASSISTANT, last_response)
add_new_block()
history_len = len(chat_history)
last_response = response
replace_quote = chat_history[-1]["role"] == "assistant"
markdown_placeholder.markdown(
postprocess_text(
str(response) + "", replace_quote=replace_quote
)
)
else:
metadata = (
page == Mode.ALL_TOOLS
and isinstance(response, dict)
and response.get("name")
or None
)
role = Role.TOOL if metadata else Role.ASSISTANT
text = (
response.get("content")
if metadata
else response_to_str(response)
)
commit_conversation(role, text, metadata)
if metadata:
add_new_block()
try:
with markdown_placeholder:
with st.spinner(f"Calling tool {metadata}..."):
observations = dispatch_tool(
metadata, text, str(st.session_state.session_id)
)
except Exception as e:
traceback.print_exc()
st.error(f'Uncaught exception in `"{metadata}"`: {e}')
break
for observation in observations:
observation.text = observation.text
commit_conversation(
Role.OBSERVATION,
observation.text,
observation.role_metadata,
observation.image_url,
new=True,
)
add_new_block()
continue
else:
break
except Exception as e:
traceback.print_exc()
st.error(f"Uncaught exception: {traceback.format_exc()}")
else:
st.error("Too many chaining function calls!")
main(prompt_text)

View File

@ -0,0 +1,61 @@
"""
Simple browser tool.
# Usage
Please start the backend browser server according to the instructions in the README.
"""
from pprint import pprint
import re
import requests
import streamlit as st
from dataclasses import dataclass
from .config import BROWSER_SERVER_URL
from .interface import ToolObservation
QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")
@dataclass
class Quote:
title: str
url: str
# Quotes for displaying reference
if "quotes" not in st.session_state:
st.session_state.quotes = {}
quotes: dict[str, Quote] = st.session_state.quotes
def map_response(response: dict) -> ToolObservation:
# Save quotes for reference
print('===BROWSER_RESPONSE===')
pprint(response)
role_metadata = response.get("roleMetadata")
metadata = response.get("metadata")
if role_metadata.split()[0] == 'quote_result' and metadata:
quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1)
quote: dict[str, str] = metadata['metadata_list'][0]
quotes[quote_id] = Quote(quote['title'], quote['url'])
elif role_metadata == 'browser_result' and metadata:
for i, quote in enumerate(metadata['metadata_list']):
quotes[str(i)] = Quote(quote['title'], quote['url'])
return ToolObservation(
content_type=response.get("contentType"),
text=response.get("result"),
role_metadata=role_metadata,
metadata=metadata,
)
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
request = {
"session_id": session_id,
"action": code,
}
response = requests.post(BROWSER_SERVER_URL, json=request).json()
return list(map(map_response, response))
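# --- Usage sketch (not part of the demo flow) ---------------------------------
# tool_call() simply forwards the browser action string emitted by the model to
# the Node.js browser server (see the browser directory and its README), which is
# assumed to be listening at BROWSER_SERVER_URL. A call looks roughly like:
#
#   observations = tool_call("search('GLM-4 release notes', 7)", "demo-session")
#   for obs in observations:
#       print(obs.role_metadata, obs.text[:200] if obs.text else obs.text)
#
# Each ToolObservation mirrors one entry of the JSON array returned by the server.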

View File

@ -0,0 +1,23 @@
import streamlit as st
from zhipuai import ZhipuAI
from zhipuai.types.image import GeneratedImage
from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
from .interface import ToolObservation
@st.cache_resource
def get_zhipu_client():
return ZhipuAI(api_key=ZHIPU_AI_KEY)
def map_response(img: GeneratedImage):
return ToolObservation(
content_type='image',
text='CogView 已经生成并向用户展示了生成的图片。',
image_url=img.url,
role_metadata='cogview_result'
)
def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
client = get_zhipu_client()
response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data
return list(map(map_response, response))

View File

@ -0,0 +1,6 @@
BROWSER_SERVER_URL = 'http://localhost:3000'
IPYKERNEL = 'glm-4-demo'
ZHIPU_AI_KEY = ''
COGVIEW_MODEL = 'cogview-3'

View File

@ -0,0 +1,10 @@
from dataclasses import dataclass
from typing import Any
@dataclass
class ToolObservation:
content_type: str
text: str
image_url: str | None = None
role_metadata: str | None = None
metadata: Any = None

View File

@ -0,0 +1,200 @@
from pprint import pprint
import queue
import re
from subprocess import PIPE
from typing import Literal
import jupyter_client
import streamlit as st
from .config import IPYKERNEL
from .interface import ToolObservation
ANSI_ESCAPE = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
CODE = re.compile(r'```([^\n]*)\n(.*?)```', re.DOTALL)
class CodeKernel:
def __init__(self,
kernel_name='kernel',
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
self.python_path = python_path
self.ipython_path = ipython_path
self.init_file_path = init_file_path
self.verbose = verbose
if python_path is None and ipython_path is None:
env = None
else:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_manager.connection_file))
if verbose:
pprint(self.kernel_manager.get_connection_info())
# Initialize the code kernel
self.kernel = self.kernel_manager.blocking_client()
# self.kernel.load_connection_file()
self.kernel.start_channels()
print("Code kernel started.")
def execute(self, code):
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
while True:
msg_out = io_msg_content
### Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
break
except queue.Empty:
break
return shell_msg, msg_out
except Exception as e:
print(e)
return None
def execute_interactive(self, code, verbose=False):
shell_msg = self.kernel.execute_interactive(code)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def inspect(self, code, verbose=False):
msg_id = self.kernel.inspect(code)
shell_msg = self.kernel.get_shell_msg(timeout=30)
if shell_msg is queue.Empty:
if verbose:
print("Timeout waiting for shell message.")
self.check_msg(shell_msg, verbose=verbose)
return shell_msg
def get_error_msg(self, msg, verbose=False) -> str | None:
if msg['content']['status'] == 'error':
try:
error_msg = msg['content']['traceback']
except:
try:
error_msg = msg['content']['traceback'][-1].strip()
except:
error_msg = "Traceback Error"
if verbose:
print("Error: ", error_msg)
return error_msg
return None
def check_msg(self, msg, verbose=False):
status = msg['content']['status']
if status == 'ok':
if verbose:
print("Execution succeeded.")
elif status == 'error':
for line in msg['content']['traceback']:
if verbose:
print(line)
def shutdown(self):
# Shutdown the backend kernel
self.kernel_manager.shutdown_kernel()
print("Backend kernel shutdown.")
# Shutdown the code kernel
self.kernel.shutdown()
print("Code kernel shutdown.")
def restart(self):
# Restart the backend kernel
self.kernel_manager.restart_kernel()
# print("Backend kernel restarted.")
def interrupt(self):
# Interrupt the backend kernel
self.kernel_manager.interrupt_kernel()
# print("Backend kernel interrupted.")
def is_alive(self):
return self.kernel.is_alive()
def clean_ansi_codes(input_string):
return ANSI_ESCAPE.sub('', input_string)
def extract_code(text: str) -> str:
    matches = CODE.findall(text)
return matches[-1][1]
def execute(
code: str,
kernel: CodeKernel
) -> tuple[Literal['text', 'image'] | None, str]:
res = ""
res_type = None
code = code.replace("<|observation|>", "")
code = code.replace("<|assistant|>python", "")
code = code.replace("<|assistant|>", "")
code = code.replace("<|user|>", "")
code = code.replace("<|system|>", "")
msg, output = kernel.execute(code)
if msg['metadata']['status'] == "timeout":
return res_type, 'Timed out'
elif msg['metadata']['status'] == 'error':
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
if 'text' in output:
res_type = "text"
res = output['text']
elif 'data' in output:
for key in output['data']:
if 'text/plain' in key:
res_type = "text"
res = output['data'][key]
elif 'image/png' in key:
res_type = "image"
res = output['data'][key]
break
return res_type, res
@st.cache_resource
def get_kernel() -> CodeKernel:
return CodeKernel()
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
kernel = get_kernel()
res_type, res = execute(code, kernel)
# Convert base64 to data uri
text = '[Image]' if res_type == 'image' else res
image = f'data:image/png;base64,{res}' if res_type == 'image' else None
return [ToolObservation(res_type, text, image)]
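# --- Local sanity check (a sketch, not part of the demo flow) ------------------
# Assumes an ipykernel spec matching IPYKERNEL ("glm-4-demo") has been installed,
# e.g. `python -m ipykernel install --user --name glm-4-demo`, and that this file
# is run as a module because of the relative imports, e.g. `python -m tools.python`.
if __name__ == "__main__":
    demo_kernel = CodeKernel()
    demo_type, demo_result = execute("print(21 * 2)", demo_kernel)
    print(demo_type, demo_result)  # expected roughly: text 42
    demo_kernel.shutdown()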

View File

@ -0,0 +1,188 @@
"""
This module implements tool registration. Registering a tool makes it callable by the model,
extending the model with the ability to invoke and interact with a variety of utilities
through defined interfaces.
"""
from collections.abc import Callable
import copy
import inspect
import json
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated
import subprocess
from .interface import ToolObservation
from .browser import tool_call as browser
from .cogview import tool_call as cogview
from .python import tool_call as python
ALL_TOOLS = {
"simple_browser": browser,
"python": python,
"cogview": cogview,
}
_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = []
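# Tools are plain Python functions whose parameters use
# `Annotated[<type>, "<description>", <required: bool>]` annotations; register_tool
# reads the docstring and these annotations to build the JSON tool description that
# is exposed to the model. A hypothetical registration looks like:
#
#   @register_tool
#   def get_time(tz: Annotated[str, "IANA timezone name", True]) -> str:
#       """Return the current time in the given timezone."""
#       ...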
def register_tool(func: Callable):
tool_name = func.__name__
tool_description = inspect.getdoc(func).strip()
python_params = inspect.signature(func).parameters
tool_params = []
for name, param in python_params.items():
annotation = param.annotation
if annotation is inspect.Parameter.empty:
raise TypeError(f"Parameter `{name}` missing type annotation")
if get_origin(annotation) != Annotated:
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
typ, (description, required) = annotation.__origin__, annotation.__metadata__
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
if not isinstance(description, str):
raise TypeError(f"Description for `{name}` must be a string")
if not isinstance(required, bool):
raise TypeError(f"Required for `{name}` must be a bool")
tool_params.append(
{
"name": name,
"description": description,
"type": typ,
"required": required,
}
)
tool_def = {
"name": tool_name,
"description": tool_description,
"params": tool_params,
}
# print("[registered tool] " + pformat(tool_def))
_TOOL_HOOKS[tool_name] = func
_TOOL_DESCRIPTIONS.append(tool_def)
return func
def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObservation]:
# Dispatch predefined tools
if tool_name in ALL_TOOLS:
return ALL_TOOLS[tool_name](code, session_id)
    code = code.strip().removesuffix('<|observation|>').strip()
# Dispatch custom tools
try:
tool_params = json.loads(code)
except json.JSONDecodeError as e:
err = f"Error decoding JSON: {e}"
return [ToolObservation("system_error", err)]
if tool_name not in _TOOL_HOOKS:
err = f"Tool `{tool_name}` not found. Please use a provided tool."
return [ToolObservation("system_error", err)]
tool_hook = _TOOL_HOOKS[tool_name]
try:
ret: str = tool_hook(**tool_params)
return [ToolObservation(tool_name, str(ret))]
except:
err = traceback.format_exc()
return [ToolObservation("system_error", err)]
def get_tools() -> list[dict]:
return copy.deepcopy(_TOOL_DESCRIPTIONS)
# Tool Definitions
@register_tool
def random_number_generator(
seed: Annotated[int, "The random seed used by the generator", True],
range: Annotated[tuple[int, int], "The range of the generated numbers", True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
"""
if not isinstance(seed, int):
raise TypeError("Seed must be an integer")
if not isinstance(range, tuple):
raise TypeError("Range must be a tuple")
if not isinstance(range[0], int) or not isinstance(range[1], int):
raise TypeError("Range must be a tuple of integers")
import random
return random.Random(seed).randint(*range)
@register_tool
def get_weather(
city_name: Annotated[str, "The name of the city to be queried", True],
) -> str:
"""
Get the current weather for `city_name`
"""
if not isinstance(city_name, str):
raise TypeError("City name must be a string")
key_selection = {
"current_condition": [
"temp_C",
"FeelsLikeC",
"humidity",
"weatherDesc",
"observation_time",
],
}
import requests
try:
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
resp.raise_for_status()
resp = resp.json()
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
except:
import traceback
ret = (
"Error encountered while fetching weather data!\n" + traceback.format_exc()
)
return str(ret)
@register_tool
def get_shell(
query: Annotated[str, "The command should run in Linux shell", True],
) -> str:
"""
Use shell to run command
"""
if not isinstance(query, str):
raise TypeError("Command must be a string")
try:
result = subprocess.run(
query,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return result.stdout
except subprocess.CalledProcessError as e:
return e.stderr
if __name__ == "__main__":
    # print(dispatch_tool("get_shell", '{"query": "pwd"}', "demo-session"))
print(get_tools())

View File

@ -0,0 +1,29 @@
from langchain_community.document_loaders import PyMuPDFLoader
import docx
from pptx import Presentation
def extract_text(path):
    with open(path, 'r') as f:
        return f.read()
def extract_pdf(path):
loader = PyMuPDFLoader(path)
data = loader.load()
data = [x.page_content for x in data]
content = '\n\n'.join(data)
return content
def extract_docx(path):
    doc = docx.Document(path)
    data = []
    for paragraph in doc.paragraphs:
        data.append(paragraph.text)
    content = '\n\n'.join(data)
    return content
def extract_pptx(path):
prs = Presentation(path)
text = ""
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
text += shape.text + "\n"
return text

249
finetune_demo/README.md Normal file
View File

@ -0,0 +1,249 @@
# GLM-4-9B Chat 对话模型微调
Read this in [English](README_en.md)
本 demo 中,你将体验到如何微调 glm-4-9b 对话开源模型(不支持视觉理解模型)。 请严格按照文档的步骤进行操作,以避免不必要的错误。
## 硬件检查
**本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。**
测试硬件信息:
+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8
| 微调方案 | 显存占用 | 权重保存点大小 |
|--------------------|-----------------------------------|---------|
| lora (PEFT) | 21531MiB | 17M |
| p-tuning v2 (PEFT) | 21381MiB | 121M |
| SFT (Zero3 method) | 80935MiB<br/>(每张 GPU需要使用 8 张 GPU) | 20G |
在开始微调之前,请你先安装`basic_demo`中的依赖,同时您需要安装本目录下的依赖项:
```bash
pip install -r requirements.txt
```
## 多轮对话格式
多轮对话微调示例采用 GLM-4 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`
对于数据文件,样例采用如下格式
如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。
```
[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>",
"tools": [
{
"name": "<tool name>",
"args": {
"<arg name>": "<arg value>"
}
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// If Tool Using:
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "observation",
"content": "<observation prompt text>"
},
{
"role": "assistant",
"content": "<assistant response observation>"
},
// Multi_turns:
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
这里是一个不带有工具的例子:
```
{"messages": [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}, {"role": "assistant", "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质其柔然的手感和细腻的质地在穿着舒适的同时透露着清纯甜美的个性气质。除此之外流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致不失为一款随性出街的必备单品。"}]}
```
这是一个带有工具调用的例子:
```
{"messages": [{"role": "system", "content": "", "tools": [{"type": "function", "function": {"name": "get_recommended_books", "description": "Get recommended books based on user's interests", "parameters": {"type": "object", "properties": {"interests": {"type": "array", "items": {"type": "string"}, "description": "The interests to recommend books for"}}, "required": ["interests"]}}}]}, {"role": "user", "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."}, {"role": "assistant", "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"}, {"role": "observation", "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"}, {"role": "assistant", "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."}]}
```
- `system` 角色为可选角色,但若存在 `system` 角色,其必须出现在 `user`
角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `system` 角色。
- `tools` 字段为可选字段,若存在 `tools` 字段,其必须出现在 `system`
角色之后,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `tools` 字段。当 `tools` 字段存在时,`system`
角色必须存在并且 `content` 字段为空。
## 配置文件
微调配置文件位于 `configs` 目录下,包括以下文件:
1. `ds_zero_2.json / ds_zero_3.json`: deepspeed 配置文件。
2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: 模型不同微调方式的配置文件,包括模型参数、优化器参数、训练参数等。部分重要参数解释如下:
+ data_config 部分
+ train_file: 训练数据集的文件路径。
+ val_file: 验证数据集的文件路径。
+ test_file: 测试数据集的文件路径。
+ num_proc: 在加载数据时使用的进程数量。
+ max_input_length: 输入序列的最大长度。
+ max_output_length: 输出序列的最大长度。
+ training_args 部分
+ output_dir: 用于保存模型和其他输出的目录。
+ max_steps: 训练的最大步数。
+ per_device_train_batch_size: 每个设备(如 GPU的训练批次大小。
+ dataloader_num_workers: 加载数据时使用的工作线程数量。
+ remove_unused_columns: 是否移除数据中未使用的列。
+ save_strategy: 模型保存策略(例如,每隔多少步保存一次)。
+ save_steps: 每隔多少步保存一次模型。
+ log_level: 日志级别(如 info
+ logging_strategy: 日志记录策略。
+ logging_steps: 每隔多少步记录一次日志。
+ per_device_eval_batch_size: 每个设备的评估批次大小。
+ evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。
+ eval_steps: 每隔多少步进行一次评估。
+ predict_with_generate: 是否使用生成模式进行预测。
+ generation_config 部分
+ max_new_tokens: 生成的最大新 token 数量。
+ peft_config 部分
+ peft_type: 使用的参数有效调整类型 (支持 LORA 和 PREFIX_TUNING)。
+ task_type: 任务类型,这里是因果语言模型 (不要改动)。
+ Lora 参数:
+ r: LoRA 的秩。
+ lora_alpha: LoRA 的缩放因子。
+ lora_dropout: 在 LoRA 层使用的 dropout 概率。
+ P-TuningV2 参数:
+ num_virtual_tokens: 虚拟 token 的数量。
+ num_attention_heads: P-TuningV2 的注意力头数,固定为 2不要改动
+ token_dim: P-TuningV2 的 token 维度,固定为 256不要改动
## 开始微调
通过以下代码执行 **单机多卡/多机多卡** 运行,这是使用 `deepspeed` 作为加速方案的,您需要安装 `deepspeed`
```shell
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/glm-4-9b configs/lora.yaml
```
通过以下代码执行 **单机单卡** 运行。
```shell
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
```
## 从保存点进行微调
如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式:
1. `yes`, 自动从最后一个保存的 Checkpoint开始训练
2. `XX`, 断点号数字 例 `600` 则从序号600 Checkpoint开始训练
例如,这就是一个从最后一个保存点继续微调的示例代码
```shell
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
```
## 使用微调后的模型
### 在 inference.py 中验证微调后的模型
您可以在 `finetune_demo/inference.py` 中使用我们的微调后的模型,仅需要一行代码就能简单的进行测试。
```shell
python inference.py your_finetune_path
```
这样,得到的回答就微调后的回答了。
### 在本仓库的其他 demo 或者外部仓库使用微调后的模型
您可以在任何一个 demo 内使用我们的 `LORA` 和 全参微调的模型。这需要你自己按照以下教程进行修改代码。
1. 使用`finetune_demo/inference.py`中读入模型的方式替换 demo 中读入模型的方式。
> 请注意,对于 LORA 和 P-TuningV2 我们没有合并训练后的模型,而是在`adapter_config.json`
> 中记录了微调型的路径,如果你的原始模型位置发生更改,则你应该修改`adapter_config.json`中`base_model_name_or_path`的路径。
```python
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
```
2. 读取微调的模型,请注意,你应该使用微调模型的位置,例如,若你的模型位置为`/path/to/finetune_adapter_model`
,原始模型地址为`path/to/base_model`,则你应该使用`/path/to/finetune_adapter_model`作为`model_dir`。
3. 完成上述操作后,就能正常使用微调的模型了,其他的调用方式没有变化。
## 参考文献
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

245
finetune_demo/README_en.md Normal file
View File

@ -0,0 +1,245 @@
# GLM-4-9B Chat dialogue model fine-tuning
This demo walks you through fine-tuning the open-source GLM-4-9B chat model (the visual-understanding model is not supported). Please follow the steps in this document strictly to avoid unnecessary errors.
## Hardware check
**All data in this document were measured in the hardware environment below. Actual requirements and GPU memory usage vary slightly with the operating environment; use your own environment as the reference.**
Test hardware information:
+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.12.3
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8
| Fine-tuning solution | GPU memory usage | Checkpoint size |
|--------------------|-----------------------------------|---------|
| lora (PEFT) | 21531MiB | 17M |
| p-tuning v2 (PEFT) | 21381MiB | 121M |
| SFT (Zero3 method) | 80935MiB<br/>(per GPU; 8 GPUs required) | 20G |
Before starting fine-tuning, please install the dependencies in `basic_demo` first. You also need to install the dependencies in this directory:
```bash
pip install -r requirements.txt
```
## Multi-round dialogue format
The multi-turn dialogue fine-tuning example follows the GLM-4 dialogue format convention, applying a different `loss_mask` to each role so that the `loss` for all assistant replies in a conversation is computed in a single pass.
For data files, the sample uses the following format:
```
[
{
"messages": [
{
"role": "system",
"content": "<system prompt text>",
"tools": [
{
"name": "<tool name>",
"args": {
"<arg name>": "<arg value>"
}
},
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
// If Tool Using:
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
},
{
"role": "observation",
"content": "<observation prompt text>"
},
{
"role": "assistant",
"content": "<assistant response observation>"
},
// Multi_turns:
{
"role": "user",
"content": "<user prompt text>"
},
{
"role": "assistant",
"content": "<assistant response text>"
}
]
}
// ...
]
```
This is a sample without tools:
```
{"messages": [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}, {"role": "assistant", "content": "3x1的这款牛仔裤采用浅白的牛仔面料为裤身材质其柔然的手感和细腻的质地在穿着舒适的同时透露着清纯甜美的个性气质。除此之外流畅的裤身剪裁将性感的腿部曲线彰显的淋漓尽致不失为一款随性出街的必备单品。"}]}
```
This is a sample with tools:
```
{"messages": [{"role": "system", "content": "", "tools": [{"type": "function", "function": {"name": "get_recommended_books", "description": "Get recommended books based on user's interests", "parameters": {"type": "object", "properties": {"interests": {"type": "array", "items": {"type": "string"}, "description": "The interests to recommend books for"}}, "required": ["interests"]}}}]}, {"role": "user", "content": "Hi, I am looking for some book recommendations. I am interested in history and science fiction."}, {"role": "assistant", "content": "{\"name\": \"get_recommended_books\", \"arguments\": {\"interests\": [\"history\", \"science fiction\"]}}"}, {"role": "observation", "content": "{\"books\": [\"Sapiens: A Brief History of Humankind by Yuval Noah Harari\", \"A Brief History of Time by Stephen Hawking\", \"Dune by Frank Herbert\", \"The Martian by Andy Weir\"]}"}, {"role": "assistant", "content": "Based on your interests in history and science fiction, I would recommend the following books: \"Sapiens: A Brief History of Humankind\" by Yuval Noah Harari, \"A Brief History of Time\" by Stephen Hawking, \"Dune\" by Frank Herbert, and \"The Martian\" by Andy Weir."}]}
```
- The `system` role is optional, but if it exists, it must appear before the `user` role, and a complete conversation data (whether single-round or multi-round conversation) can only have one `system` role.
- The `tools` field is optional. If it exists, it must appear after the `system` role, and a complete conversation data (whether single-round or multi-round conversation) can only have one `tools` field. When the `tools` field exists, the `system` role must exist and the `content` field is empty.
## Configuration file
The fine-tuning configuration files are located in the `configs` directory and include the following files:
1. `ds_zero_2.json / ds_zero_3.json`: deepspeed configuration files.
2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: configuration files for the different fine-tuning approaches, covering model parameters, optimizer parameters, training parameters, etc. Some important parameters are explained as follows:
+ data_config section
+ train_file: File path of training dataset.
+ val_file: File path of validation dataset.
+ test_file: File path of test dataset.
+ num_proc: Number of processes to use when loading data.
+ max_input_length: Maximum length of input sequence.
+ max_output_length: Maximum length of output sequence.
+ training_args section
+ output_dir: Directory for saving model and other outputs.
+ max_steps: Maximum number of training steps.
+ per_device_train_batch_size: Training batch size per device (such as GPU).
+ dataloader_num_workers: Number of worker threads to use when loading data.
+ remove_unused_columns: Whether to remove unused columns in data.
+ save_strategy: Model saving strategy (for example, how many steps to save).
+ save_steps: How many steps to save the model.
+ log_level: Log level (such as info).
+ logging_strategy: logging strategy.
+ logging_steps: how many steps to log at.
+ per_device_eval_batch_size: per-device evaluation batch size.
+ evaluation_strategy: evaluation strategy (e.g. how many steps to evaluate at).
+ eval_steps: how many steps to evaluate at.
+ predict_with_generate: whether to use generation mode for prediction.
+ generation_config section
+ max_new_tokens: maximum number of new tokens to generate.
+ peft_config section
+ peft_type: type of parameter tuning to use (supports LORA and PREFIX_TUNING).
+ task_type: task type, here is causal language model (don't change).
+ Lora parameters:
+ r: rank of LoRA.
+ lora_alpha: scaling factor of LoRA.
+ lora_dropout: dropout probability to use in LoRA layer.
+ P-TuningV2 parameters:
+ num_virtual_tokens: the number of virtual tokens.
+ num_attention_heads: the number of attention heads used by P-Tuning v2, fixed at 2 (do not change).
+ token_dim: the token dimension used by P-Tuning v2, fixed at 256 (do not change).
## Start fine-tuning
Run **single-node multi-GPU / multi-node multi-GPU** fine-tuning with the following command. It uses `deepspeed` for acceleration, so you need to install `deepspeed` first.
```shell
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune.py data/AdvertiseGen/ THUDM/glm-4-9b configs/lora.yaml
```
Run **single-machine single-GPU** fine-tuning with the following command.
```shell
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
```
## Fine-tune from a saved point
If you train as described above, every fine-tuning run starts from scratch. To resume fine-tuning from a partially trained model, add a fourth parameter, which can be passed in two ways:
1. `yes`: automatically resume training from the last saved checkpoint.
2. `XX`: a checkpoint number, e.g. `600`, to resume training from checkpoint 600.
For example, this command continues fine-tuning from the last saved checkpoint:
```shell
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
```
## Use the fine-tuned model
### Verify the fine-tuned model in inference.py
You can use the fine-tuned model in `finetune_demo/inference.py` and easily test it with just one command.
```shell
python inference.py your_finetune_path
```
The answers you get will then come from the fine-tuned model.
### Use the fine-tuned model in other demos in this repository or external repositories
You can use our `LORA` and fully fine-tuned models in any demo. This requires you to modify the code yourself according to the following tutorial.
1. Replace the model-loading code in the demo with the model-loading code from `finetune_demo/inference.py`.
> Please note that for LORA and P-TuningV2 we do not merge the trained weights; instead, the path of the fine-tuned adapter is recorded in `adapter_config.json`.
> If the location of your original model changes, you should modify the `base_model_name_or_path` path in `adapter_config.json`.
```python
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code
)
return model, tokenizer
```
2. Load the fine-tuned model from its own location. For example, if your fine-tuned adapter is at `/path/to/finetune_adapter_model`
and the original model is at `path/to/base_model`, use `/path/to/finetune_adapter_model` as `model_dir`.
3. After completing the above operations, you can use the fine-tuned model normally. Other calling methods remain unchanged.
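For reference, here is a minimal (hypothetical) inference sketch that follows the loading logic above, assuming a LoRA adapter saved at `/path/to/finetune_adapter_model`; adjust paths and generation settings to your setup:
```python
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

adapter_dir = "/path/to/finetune_adapter_model"  # your fine-tuned adapter directory
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_dir, trust_remote_code=True, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(
    model.peft_config["default"].base_model_name_or_path, trust_remote_code=True
)

# Build a single-turn chat prompt (the sample query is taken from the dataset example above).
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "类型#裤*材质#牛仔布*风格#性感"}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    eos_token_id=[151329, 151336, 151338],  # GLM-4 end-of-turn token ids used elsewhere in this repo
)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```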
## Reference
```
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
Papers)},
pages={61--68},
year={2022}
}
@misc{tang2023toolalpaca,
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
year={2023},
eprint={2306.05301},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

View File

@ -0,0 +1,29 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@ -0,0 +1,31 @@
{
"train_micro_batch_size_per_gpu": "auto",
"zero_allow_untested_optimizer": true,
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"contiguous_gradients": true,
"overlap_comm": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}

View File

@ -0,0 +1,44 @@
data_config:
train_file: train.jsonl
val_file: dev.jsonl
test_file: dev.jsonl
num_proc: 1
max_input_length: 512
max_output_length: 512
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-4
# settings for data loading
per_device_train_batch_size: 1
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 4
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
#deepspeed: ds_zero_2.json
peft_config:
peft_type: LORA
task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1

View File

@ -0,0 +1,44 @@
data_config:
train_file: train.jsonl
val_file: dev.jsonl
test_file: dev.jsonl
num_proc: 1
max_input_length: 128
max_output_length: 128
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-4
# settings for data loading
per_device_train_batch_size: 4
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 500
# settings for evaluation
per_device_eval_batch_size: 16
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 512
# set your absolute deepspeed path here
#deepspeed: ds_zero_3.json
peft_config:
peft_type: PREFIX_TUNING
task_type: CAUSAL_LM
num_virtual_tokens: 512
num_attention_heads: 2
token_dim: 256

View File

@ -0,0 +1,37 @@
data_config:
train_file: train.jsonl
val_file: dev.jsonl
test_file: dev.jsonl
num_proc: 1
max_input_length: 256
max_output_length: 512
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
# needed to be fit for the dataset
learning_rate: 5e-5
# settings for data loading
per_device_train_batch_size: 1
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 16
evaluation_strategy: steps
eval_steps: 500
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
generation_config:
max_new_tokens: 512
  # path to the DeepSpeed config used for full-parameter fine-tuning
deepspeed: configs/ds_zero_3.json
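
Unlike the two adapter configs above, this file has no `peft_config` block, so finetune.py updates all model parameters; accordingly the DeepSpeed ZeRO-3 entry is active rather than commented out and the learning rate is an order of magnitude lower. A rough sketch of how these YAML files are consumed (it mirrors `FinetuningConfig.from_file` / `from_dict` in finetune.py below; the path is only an example):

# Rough sketch, not part of the repo.
import ruamel.yaml as yaml
from transformers import GenerationConfig, Seq2SeqTrainingArguments

parser = yaml.YAML(typ='safe', pure=True)
with open('configs/sft.yaml') as f:        # example path
    cfg = parser.load(f)

args = dict(cfg['training_args'])
args['generation_config'] = GenerationConfig(**args['generation_config'])
training_args = Seq2SeqTrainingArguments(**args)
print(training_args.deepspeed)             # 'configs/ds_zero_3.json'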

447
finetune_demo/finetune.py Normal file
View File

@ -0,0 +1,447 @@
# -*- coding: utf-8 -*-
import json
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
EvalPrediction,
GenerationConfig,
PreTrainedTokenizer,
Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
app = typer.Typer(pretty_exceptions_show_locals=False)
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
if output_ids is not None:
max_output_length = max(len(out) for out in output_ids)
if self.pad_to_multiple_of is not None:
max_output_length = (
(
max_output_length + self.pad_to_multiple_of - 1) //
self.pad_to_multiple_of * self.pad_to_multiple_of
)
for feature in features:
remainder = [self.tokenizer.pad_token_id] * (
max_output_length - len(feature['output_ids'])
)
if isinstance(feature['output_ids'], list):
feature['output_ids'] = feature['output_ids'] + remainder
else:
feature['output_ids'] = np.concatenate(
[feature['output_ids'], remainder]
).astype(np.int64)
return super().__call__(features, return_tensors)
class Seq2SeqTrainer(_Seq2SeqTrainer):
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, Any],
prediction_loss_only: bool,
ignore_keys=None,
**gen_kwargs,
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.args.predict_with_generate:
output_ids = inputs.pop('output_ids')
input_ids = inputs['input_ids']
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
)
generated_tokens = generated_tokens[:, input_ids.size()[1]:]
labels = output_ids
return loss, generated_tokens, labels
@dc.dataclass
class DataConfig(object):
train_file: Optional[str] = None
val_file: Optional[str] = None
test_file: Optional[str] = None
num_proc: Optional[int] = None
@property
def data_format(self) -> str:
return Path(self.train_file).suffix
@property
def data_files(self) -> dict[NamedSplit, str]:
return {
split: data_file
for split, data_file in zip(
[Split.TRAIN, Split.VALIDATION, Split.TEST],
[self.train_file, self.val_file, self.test_file],
)
if data_file is not None
}
@dc.dataclass
class FinetuningConfig(object):
data_config: DataConfig
max_input_length: int
max_output_length: int
training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
)
peft_config: Optional[PeftConfig] = None
def __post_init__(self):
if not self.training_args.do_eval or self.data_config.val_file is None:
self.training_args.do_eval = False
self.training_args.evaluation_strategy = 'no'
self.data_config.val_file = None
else:
self.training_args.per_device_eval_batch_size = (
self.training_args.per_device_eval_batch_size
or self.training_args.per_device_train_batch_size
)
@classmethod
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
training_args = kwargs.get('training_args', None)
if training_args is not None and not isinstance(
training_args, Seq2SeqTrainingArguments
):
gen_config = training_args.get('generation_config')
# TODO: a bit hacky
if not isinstance(gen_config, GenerationConfig):
training_args['generation_config'] = GenerationConfig(
**gen_config
)
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
data_config = kwargs.get('data_config')
if not isinstance(data_config, DataConfig):
kwargs['data_config'] = DataConfig(**data_config)
peft_config = kwargs.get('peft_config', None)
if peft_config is not None and not isinstance(peft_config, PeftConfig):
kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
return cls(**kwargs)
@classmethod
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
path = Path(path)
parser = yaml.YAML(typ='safe', pure=True)
parser.indent(mapping=2, offset=2, sequence=4)
parser.default_flow_style = False
kwargs = parser.load(path)
return cls.from_dict(**kwargs)
def _load_datasets(
data_dir: str,
data_format: str,
data_files: dict[NamedSplit, str],
num_proc: Optional[int],
) -> DatasetDict:
if data_format == '.jsonl':
dataset_dct = load_dataset(
data_dir,
data_files=data_files,
split=None,
num_proc=num_proc,
)
else:
raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
return dataset_dct
class DataManager(object):
def __init__(self, data_dir: str, data_config: DataConfig):
self._num_proc = data_config.num_proc
self._dataset_dct = _load_datasets(
data_dir,
data_config.data_format,
data_config.data_files,
self._num_proc,
)
def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
return self._dataset_dct.get(split, None)
def get_dataset(
self,
split: NamedSplit,
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
batched: bool = True,
remove_orig_columns: bool = True,
) -> Optional[Dataset]:
orig_dataset = self._get_dataset(split)
if orig_dataset is None:
return
if remove_orig_columns:
remove_columns = orig_dataset.column_names
else:
remove_columns = None
return orig_dataset.map(
process_fn,
batched=batched,
remove_columns=remove_columns,
num_proc=self._num_proc,
)
def process_message(message):
if 'tools' in message and message['role'] == 'system':
for tool in message['tools']:
parameters = tool['function']['parameters']['properties']
tool['function']['parameters']['properties'] = \
{k: v for k, v in parameters.items() if
v is not None}
elif 'tools' in message:
del message['tools']
return message
def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_labels = []
for conv in batched_conv:
        input_ids = [151331, 151333]  # fixed prompt prefix (the [gMASK] and <sop> tokens in the GLM-4 vocabulary)
        loss_masks = [False, False]   # no loss on the prefix
for message in conv:
message = process_message(message)
loss_mask_val = False if message['role'] in ('system', 'user') else True
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
new_loss_masks = [loss_mask_val] * len(new_input_ids)
input_ids += new_input_ids
loss_masks += new_loss_masks
input_ids.append(tokenizer.eos_token_id)
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
labels.append(input_id)
else:
labels.append(-100)
max_length = max_input_length + max_output_length + 1
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return {'input_ids': batched_input_ids, 'labels': batched_labels}
def process_batch_eval(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_output_ids = []
for conv in batched_conv:
input_ids = [151331, 151333]
for message in conv:
if len(input_ids) >= max_input_length:
break
else:
message = process_message(message)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
if message['role'] == 'assistant':
output_prompt, output_ids = (
new_input_ids[:1],
new_input_ids[1:],
)
output_ids.append(tokenizer.eos_token_id)
batched_input_ids.append(
input_ids[:max_input_length] + output_prompt[:1]
)
batched_output_ids.append(output_ids[:max_output_length])
input_ids += new_input_ids
return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
def load_tokenizer_and_model(
model_dir: str,
peft_config: Optional[PeftConfig] = None,
):
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
if peft_config is not None:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False,
torch_dtype=torch.bfloat16 # Must use BFloat 16
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False,
torch_dtype=torch.bfloat16
)
return tokenizer, model
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
batched_pred_ids, batched_label_ids = eval_preds
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
pred_txt = tokenizer.decode(pred_ids).strip()
label_txt = tokenizer.decode(label_ids).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
for k, v in scores[0].items():
metrics_dct[k].append(round(v['f'] * 100, 4))
metrics_dct['bleu-4'].append(
sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
return {k: np.mean(v) for k, v in metrics_dct.items()}
@app.command()
def main(
data_dir: Annotated[str, typer.Argument(help='')],
model_dir: Annotated[
str,
typer.Argument(
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
),
],
config_file: Annotated[str, typer.Argument(help='')],
auto_resume_from_checkpoint: str = typer.Argument(
default='',
            help='If set to "yes", automatically resume from the latest saved checkpoint. If a number such as 12 or 15 is given, resume from the checkpoint saved at that step. If empty or "no", start training from scratch.'
),
):
ft_config = FinetuningConfig.from_file(config_file)
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:', train_dataset)
val_dataset = data_manager.get_dataset(
Split.VALIDATION,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if val_dataset is not None:
print('val_dataset:', val_dataset)
test_dataset = data_manager.get_dataset(
Split.TEST,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if test_dataset is not None:
print('test_dataset:', test_dataset)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
trainer = Seq2SeqTrainer(
model=model,
args=ft_config.training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding='longest',
return_tensors='pt',
),
train_dataset=train_dataset,
        # evaluate on at most 50 validation samples to keep evaluation cheap
        eval_dataset=val_dataset.select(list(range(min(50, len(val_dataset))))) if val_dataset is not None else None,
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
)
    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() in ("", "NO"):
trainer.train()
else:
output_dir = ft_config.training_args.output_dir
dirlist = os.listdir(output_dir)
checkpoint_sn = 0
for checkpoint_str in dirlist:
if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
if checkpoint > checkpoint_sn:
checkpoint_sn = checkpoint
if auto_resume_from_checkpoint.upper() == "YES":
if checkpoint_sn > 0:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
trainer.train()
else:
if auto_resume_from_checkpoint.isdigit():
if int(auto_resume_from_checkpoint) > 0:
checkpoint_sn = int(auto_resume_from_checkpoint)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
            print(auto_resume_from_checkpoint,
                  "The specified checkpoint (" + auto_resume_from_checkpoint + ") has not been saved. Please check the model output directory for a valid checkpoint.")
if test_dataset is not None:
trainer.predict(test_dataset)
if __name__ == '__main__':
app()
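
Both `process_batch` and `process_batch_eval` above read `batch['messages']`, so each line of `train.jsonl` / `dev.jsonl` is expected to be one JSON object holding a `messages` list in the usual chat format (roles `system` / `user` / `assistant`, with an optional `tools` list on the system message, as handled by `process_message`). A minimal sketch of writing such a record; the sample content itself is invented:

# Minimal sketch of the expected .jsonl record layout (field names taken from
# process_batch / process_message above; the sample content is made up).
import json

record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who are you?"},
        {"role": "assistant", "content": "I am an assistant fine-tuned from GLM-4-9B."},
    ]
}
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")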

109
finetune_demo/inference.py Normal file
View File

@ -0,0 +1,109 @@
from pathlib import Path
from typing import Annotated, Union
import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = Path(model_dir).expanduser().resolve()
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code, encode_special_tokens=True, use_fast=False
)
return model, tokenizer
@app.command()
def main(
model_dir: Annotated[str, typer.Argument(help='')],
):
messages = [
{
"role": "system", "content": "",
"tools":
[
{
"type": "function",
"function": {
"name": "create_calendar_event",
"description": "Create a new calendar event",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "The title of the event"
},
"start_time": {
"type": "string",
"description": "The start time of the event in the format YYYY-MM-DD HH:MM"
},
"end_time": {
"type": "string",
"description": "The end time of the event in the format YYYY-MM-DD HH:MM"
}
},
"required": [
"title",
"start_time",
"end_time"
]
}
}
}
]
},
{
"role": "user",
"content": "Can you help me create a calendar event for my meeting tomorrow? The title is \"Team Meeting\". It starts at 10:00 AM and ends at 11:00 AM."
},
]
model, tokenizer = load_model_and_tokenizer(model_dir)
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt"
).to(model.device)
generate_kwargs = {
"input_ids": inputs,
"max_new_tokens": 1024,
"do_sample": True,
"top_p": 0.8,
"temperature": 0.8,
"repetition_penalty": 1.2,
"eos_token_id": model.config.eos_token_id,
}
outputs = model.generate(**generate_kwargs)
response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()
print("=========")
print(response)
if __name__ == '__main__':
app()
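
`load_model_and_tokenizer` above detects a PEFT checkpoint by the presence of `adapter_config.json` and keeps the adapter separate from the base model at inference time. If a single merged checkpoint is preferred for deployment, a trained LoRA adapter can be folded into the base weights with peft's `merge_and_unload` (LoRA only; a prefix-tuning adapter cannot be merged this way). A sketch under assumed paths:

# Optional follow-up, not part of the repo; paths are examples.
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_dir = "./output/checkpoint-3000"
model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir, trust_remote_code=True)
merged = model.merge_and_unload()          # fold the LoRA deltas into the base weights
merged.save_pretrained("./output/merged")

base_dir = model.peft_config['default'].base_model_name_or_path
AutoTokenizer.from_pretrained(base_dir, trust_remote_code=True).save_pretrained("./output/merged")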

View File

@ -0,0 +1,5 @@
jieba>=0.42.1
datasets>=2.19.1
peft>=0.11.0
deepspeed>=0.13.3
nltk==3.8.1

7
resources/WECHAT.md Normal file
View File

@ -0,0 +1,7 @@
<div align="center">
<img src="wechat.jpg" width="60%"/>
<p> 扫码关注公众号加入「GLM-4交流群」 </p>
<p> Scan the QR code to follow the official account and join the "GLM-4 Discussion Group" </p>
</div>

BIN
resources/eval_needle.jpeg Normal file

Binary file not shown.

Size: 452 KiB

BIN
resources/longbench.png Normal file

Binary file not shown.

Size: 164 KiB

BIN
resources/wechat.jpg Normal file

Binary file not shown.

Size: 151 KiB