Merge pull request from sixsixcoder/main

vLLM with LoRA adapter
zR 2024-09-05 09:52:59 +08:00 committed by GitHub
commit 19f2f91fb5
5 changed files with 33 additions and 8 deletions

View File

@@ -11,6 +11,7 @@ Read this in [English](README_en.md)
## Project Updates
- 🔥🔥 **News**: ```2024/09/04```: Added demo code for using vLLM with a LoRA adapter on the GLM-4-9B-Chat model.
- 🔥🔥 **News**: ```2024/08/15```: We open-sourced [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b), a model with long-text output capability (a single-turn response can exceed 10K tokens),
  together with the dataset [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k),

View File

@@ -9,6 +9,7 @@
## Update
- 🔥🔥 **News**: ```2024/09/04```: Add demo code for using vLLM with LoRA adapter on the GLM-4-9B-Chat model.
- 🔥🔥 **News**: ```2024/08/15```: We have open-sourced a model with long-text output capability (single turn LLM output can exceed
  10K tokens) [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b) and the
  dataset [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k). You're welcome

View File

@@ -112,6 +112,13 @@ python trans_batch_demo.py
python vllm_cli_demo.py
```
+ Use vLLM with a LoRA adapter on the GLM-4-9B-Chat model:
```python
# vllm_cli_demo.py
# add LORA_PATH = ''
```
+ Build the server yourself and chat with the GLM-4-9B-Chat model using the `OpenAI API` request format. This demo supports the Function Call and All Tools features.
Start the server:

View File

@@ -117,6 +117,13 @@ python trans_batch_demo.py
python vllm_cli_demo.py
```
+ Use LoRA adapters with vLLM on the GLM-4-9B-Chat model:
```python
# vllm_cli_demo.py
# add LORA_PATH = ''
```
+ Build the server yourself and use the `OpenAI API` request format to communicate with the glm-4-9b model. This demo supports the Function Call and All Tools features.
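The `LORA_PATH` constant this README points to is left empty by default; to actually load an adapter it must point at a local LoRA checkpoint directory. A minimal sketch of the intended edit follows; the directory shown is a placeholder, not a path from the repository:

```python
# vllm_cli_demo.py (top of file) -- illustrative values only
MODEL_PATH = 'THUDM/glm-4-9b-chat'

# Placeholder: point this at a local LoRA adapter directory
# (e.g. one produced by LoRA fine-tuning of GLM-4-9B-Chat).
# Leave it as '' to run the base model without LoRA.
LORA_PATH = '/path/to/glm-4-9b-chat-lora-adapter'
```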

View File

@@ -14,14 +14,16 @@ import asyncio
 from transformers import AutoTokenizer
 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from typing import List, Dict
+from vllm.lora.request import LoRARequest
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+LORA_PATH = ''
-def load_model_and_tokenizer(model_dir: str):
+def load_model_and_tokenizer(model_dir: str, enable_lora: bool):
     engine_args = AsyncEngineArgs(
         model=model_dir,
         tokenizer=model_dir,
+        enable_lora=enable_lora,
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
@@ -42,11 +44,14 @@ def load_model_and_tokenizer(model_dir: str):
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     return engine, tokenizer
+enable_lora = False
+if LORA_PATH:
+    enable_lora = True
-engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
+engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora)
-async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
+async def vllm_gen(lora_path: str, enable_lora: bool, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
     inputs = tokenizer.apply_chat_template(
         messages,
         add_generation_prompt=True,
@@ -70,6 +75,10 @@ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
         "skip_special_tokens": True,
     }
     sampling_params = SamplingParams(**params_dict)
-    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
-        yield output.outputs[0].text
+    if enable_lora:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)):
+            yield output.outputs[0].text
+    else:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
+            yield output.outputs[0].text
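As an aside (not part of this commit), the per-request LoRA routing that the hunk above relies on can also be exercised with vLLM's synchronous `LLM` API. The following is a minimal sketch under stated assumptions: the adapter directory path and the adapter name/id are placeholders chosen for illustration.

```python
# Standalone sketch of per-request LoRA with vLLM's synchronous API.
# The adapter path below is a placeholder; any local LoRA adapter directory works.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="THUDM/glm-4-9b-chat", enable_lora=True, trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.6, top_p=0.8, max_tokens=256)

lora_request = LoRARequest("glm-4-lora", 1, lora_path="/path/to/lora/adapter")

# Passing lora_request applies the adapter weights for this request only;
# omitting it falls back to the base model, mirroring the if/else branch above.
outputs = llm.generate(["Hello, introduce yourself briefly."],
                       sampling_params,
                       lora_request=lora_request)
print(outputs[0].outputs[0].text)
```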
@@ -100,7 +109,7 @@ async def chat():
         print("\nGLM-4: ", end="")
         current_length = 0
         output = ""
-        async for output in vllm_gen(messages, top_p, temperature, max_length):
+        async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length):
             print(output[current_length:], end="", flush=True)
             current_length = len(output)
         history[-1][1] = output
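A note on the design as it reads from the diff: `enable_lora` is derived solely from whether `LORA_PATH` is non-empty, so leaving it as `''` keeps the demo on its original base-model code path, while a non-empty path routes every generation call through a single `LoRARequest` with the fixed adapter name "glm-4-lora" and id 1. Because vLLM accepts `lora_request` per call, the same engine could in principle serve several adapters, though this demo wires up only one.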