LoRA adapter with vLLM
parent f0d67ff4a4
commit fafa33d351
@@ -11,6 +11,7 @@ Read this in [English](README_en.md)

## Project Updates

- 🔥🔥 **News**: ```2024/09/04```: Added demo code for using vLLM with a LoRA adapter on the GLM-4-9B-Chat model.
- 🔥🔥 **News**: ```2024/08/15```: We open-sourced [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b), a model with long-text output capability (a single-turn response can exceed 10,000 tokens), together with the dataset [LongWriter-6k](https://huggingface.co/datasets/LongWriter-6k),

@@ -9,6 +9,7 @@

## Update

- 🔥🔥 **News**: ```2024/09/04```: Add demo code for using vLLM with a LoRA adapter on the GLM-4-9B-Chat model.
- 🔥🔥 **News**: ```2024/08/15```: We have open-sourced a model with long-text output capability (single-turn LLM output can exceed 10K tokens), [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b), and the dataset [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k). You're welcome

@@ -126,6 +126,12 @@ python openai_api_server.py

python openai_api_request.py
```

### Using vLLM with a LoRA adapter on the GLM-4-9B-Chat model

```shell
python vllm_cli_lora_demo.py
```

## Stress test

Users can run this code on their own devices to test the model's generation speed on the transformers backend:

@@ -132,6 +132,13 @@ Client request:

python openai_api_request.py
```

### LoRA adapters with vLLM

Use LoRA adapters with vLLM on the GLM-4-9B-Chat model.

```shell
python vllm_cli_lora_demo.py
```
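For orientation, the same idea can also be sketched with vLLM's synchronous `LLM` API. This is only a minimal sketch, not the demo script itself: the adapter directory `./glm4-lora` is a placeholder, and the actual demo builds the prompt with the tokenizer's chat template and streams tokens through the asynchronous engine.

```python
# Minimal sketch of per-request LoRA with vLLM's offline LLM API.
# The adapter path ./glm4-lora is a placeholder; point it at your own adapter directory.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="THUDM/glm-4-9b-chat", trust_remote_code=True, enable_lora=True)
params = SamplingParams(temperature=0.8, top_p=0.8, max_tokens=256)

# A LoRARequest names the adapter, assigns it an integer id, and points at its weights;
# vLLM applies it on top of the base model for this request only.
outputs = llm.generate(
    ["Hello, please introduce yourself."],
    params,
    lora_request=LoRARequest("glm-4-lora", 1, lora_path="./glm4-lora"),
)
print(outputs[0].outputs[0].text)
```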

## Stress test

Users can run this code on their own devices to test the model's generation speed on the transformers backend:

@@ -0,0 +1,116 @@
"""
This script creates a CLI demo that uses LoRA adapters with the vLLM backend for the GLM-4-9B model,
allowing users to interact with the model through a command-line interface.

Usage:
- Set MODEL_PATH and LORA_PATH below, then run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.

Note: LORA_PATH must point to a directory containing a trained LoRA adapter for the base model.
"""
import time
import asyncio
from typing import List, Dict

from transformers import AutoTokenizer
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from vllm.lora.request import LoRARequest

MODEL_PATH = 'THUDM/GLM-4'
LORA_PATH = ''  # path to your LoRA adapter directory


def load_model_and_tokenizer(model_dir: str):
    engine_args = AsyncEngineArgs(
        model=model_dir,
        tokenizer=model_dir,
        enable_lora=True,    # newly added: enable LoRA adapter support
        max_loras=1,         # newly added: max number of LoRA adapters in a single batch
        max_lora_rank=8,     # newly added: must cover the rank of your adapter
        max_num_seqs=256,    # newly added
        tensor_parallel_size=2,
        dtype="bfloat16",
        trust_remote_code=True,
        gpu_memory_utilization=0.5,
        max_model_len=2048,
        enforce_eager=True,
        worker_use_ray=True,
        engine_use_ray=False,
        disable_log_requests=True,
        # If you run into OOM, consider enabling the parameters below
        # enable_chunked_prefill=True,
        # max_num_batched_tokens=8192
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir,
        trust_remote_code=True,
        encode_special_tokens=True
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    return engine, tokenizer


engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)


async def vllm_gen(lora_path: str, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "ignore_eos": False,
        "max_tokens": max_dec_len,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True
    }
    sampling_params = SamplingParams(**params_dict)
    # Pass a LoRARequest so vLLM applies the adapter on top of the base weights for this request.
    async for output in engine.generate(
            inputs=inputs,
            sampling_params=sampling_params,
            request_id=f"{time.time()}",
            lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)
    ):
        yield output.outputs[0].text


async def chat():
    history = []
    max_length = 8192
    top_p = 0.8
    temperature = 0

    print("Welcome to the GLM-4-9B CLI (LoRA) chat. Type your messages below.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])

        # Rebuild the message list from the running history in the chat-template format.
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})

        print("\nGLM-4: ", end="")
        current_length = 0
        output = ""
        # Stream the generation, printing only the newly produced suffix each step.
        async for output in vllm_gen(LORA_PATH, messages, top_p, temperature, max_length):
            print(output[current_length:], end="", flush=True)
            current_length = len(output)
        history[-1][1] = output


if __name__ == "__main__":
    asyncio.run(chat())
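A possible follow-up, not part of this commit: the same adapter could also be exposed through vLLM's own OpenAI-compatible server. A hedged sketch, assuming the adapter lives in a local directory `./glm4-lora` (placeholder) and reusing the base model path from MODEL_PATH above:

```shell
# Sketch only: serve the base model plus the LoRA adapter via vLLM's OpenAI-compatible server.
# ./glm4-lora is a placeholder for your adapter directory.
python -m vllm.entrypoints.openai.api_server \
    --model THUDM/GLM-4 \
    --trust-remote-code \
    --enable-lora \
    --lora-modules glm-4-lora=./glm4-lora \
    --max-lora-rank 8
```

Clients would then select the adapter by passing its name (here `glm-4-lora`) as the model in their requests.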