From fafa33d3510d336d76baace69fa48779513cb9e2 Mon Sep 17 00:00:00 2001 From: sixsixcoder Date: Wed, 4 Sep 2024 09:10:03 +0000 Subject: [PATCH 1/5] lora adapter with vllm --- README.md | 1 + README_en.md | 1 + basic_demo/README.md | 6 ++ basic_demo/README_en.md | 7 ++ basic_demo/vllm_cli_lora_demo.py | 116 +++++++++++++++++++++++++++++++ 5 files changed, 131 insertions(+) create mode 100644 basic_demo/vllm_cli_lora_demo.py diff --git a/README.md b/README.md index 37a0cfb..c58da67 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Read this in [English](README_en.md) ## 项目更新 +- 🔥🔥 **News**: ```2024/09/04```: 增加了在 GLM-4-9B-Chat 模型上使用带有 Lora adapter 的 vLLM 演示代码 - 🔥🔥 **News**: ```2024/08/15```: 我们开源具备长文本输出能力(单轮对话大模型输出可超过1万token) 的模型 [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b) 以及数据集 [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k), diff --git a/README_en.md b/README_en.md index aaff0f1..a45447b 100644 --- a/README_en.md +++ b/README_en.md @@ -9,6 +9,7 @@ ## Update +- 🔥🔥 **News**: ```2024/09/04```: Add demo code for using vLLM with LoRA adapter on the GLM-4-9B-Chat model. - 🔥🔥 **News**: ```2024/08/15```: We have open-sourced a model with long-text output capability (single turn LLM output can exceed 10K tokens) [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b) and the dataset [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k). You're welcome diff --git a/basic_demo/README.md b/basic_demo/README.md index 4ffca52..30cd311 100644 --- a/basic_demo/README.md +++ b/basic_demo/README.md @@ -126,6 +126,12 @@ python openai_api_server.py python openai_api_request.py ``` +### 在 GLM-4-9B-Chat 模型上使用带有 Lora adapter 的 vLLM + +```shell +python vllm_cli_lora_demo.py +``` + ## 压力测试 用户可以在自己的设备上使用本代码测试模型在 transformers后端的生成速度: diff --git a/basic_demo/README_en.md b/basic_demo/README_en.md index 570e446..07086d1 100644 --- a/basic_demo/README_en.md +++ b/basic_demo/README_en.md @@ -132,6 +132,13 @@ Client request: python openai_api_request.py ``` +### LoRA adapters with vLLM ++ use LoRA adapters with vLLM on GLM-4-9B-Chat model. + +```shell +python vllm_cli_lora_demo.py +``` + ## Stress test Users can use this code to test the generation speed of the model on the transformers backend on their own devices: diff --git a/basic_demo/vllm_cli_lora_demo.py b/basic_demo/vllm_cli_lora_demo.py new file mode 100644 index 0000000..5f6834d --- /dev/null +++ b/basic_demo/vllm_cli_lora_demo.py @@ -0,0 +1,116 @@ +""" +This script creates a CLI demo that utilizes LoRA adapters with vLLM backend for the GLM-4-9b model, +allowing users to interact with the model through a command-line interface. + +Usage: +- Run the script to start the CLI demo. +- Interact with the model by typing questions and receiving responses. + +Note: The script includes a modification to handle markdown to plain text conversion, +ensuring that the CLI interface displays formatted text correctly. 
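
Set LORA_PATH below to the directory of your fine-tuned LoRA adapter before starting the demo.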
+""" +import time +import asyncio +from transformers import AutoTokenizer +from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine +from typing import List, Dict +from vllm.lora.request import LoRARequest + +MODEL_PATH = 'THUDM/GLM-4' +LORA_PATH = '' # 你的 lora adapter 路径 + +def load_model_and_tokenizer(model_dir: str): + engine_args = AsyncEngineArgs( + model=model_dir, + tokenizer=model_dir, + enable_lora=True, # 新增 + max_loras=1, # 新增 + max_lora_rank=8, ## 新增 + max_num_seqs=256, ## 新增 + tensor_parallel_size=2, + dtype="bfloat16", + trust_remote_code=True, + gpu_memory_utilization=0.5, + max_model_len=2048, + enforce_eager=True, + worker_use_ray=True, + engine_use_ray=False, + disable_log_requests=True + # 如果遇见 OOM 现象,建议开启下述参数 + # enable_chunked_prefill=True, + # max_num_batched_tokens=8192 + ) + tokenizer = AutoTokenizer.from_pretrained( + model_dir, + trust_remote_code=True, + encode_special_tokens=True + ) + engine = AsyncLLMEngine.from_engine_args(engine_args) + return engine, tokenizer + + +engine, tokenizer = load_model_and_tokenizer(MODEL_PATH) + + +async def vllm_gen(lora_path: str, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int): + inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False + ) + params_dict = { + "n": 1, + "best_of": 1, + "presence_penalty": 1.0, + "frequency_penalty": 0.0, + "temperature": temperature, + "top_p": top_p, + "top_k": -1, + "use_beam_search": False, + "length_penalty": 1, + "early_stopping": False, + "ignore_eos": False, + "max_tokens": max_dec_len, + "logprobs": None, + "prompt_logprobs": None, + "skip_special_tokens": True + } + sampling_params = SamplingParams(**params_dict) + async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)): + yield output.outputs[0].text + + +async def chat(): + history = [] + max_length = 8192 + top_p = 0.8 + temperature = 0 + + print("Welcome to the GLM-4-9B CLI (Lora) chat. 
Type your messages below.") + while True: + user_input = input("\nYou: ") + if user_input.lower() in ["exit", "quit"]: + break + history.append([user_input, ""]) + + messages = [] + for idx, (user_msg, model_msg) in enumerate(history): + if idx == len(history) - 1 and not model_msg: + messages.append({"role": "user", "content": user_msg}) + break + if user_msg: + messages.append({"role": "user", "content": user_msg}) + if model_msg: + messages.append({"role": "assistant", "content": model_msg}) + + print("\nGLM-4: ", end="") + current_length = 0 + output = "" + async for output in vllm_gen(LORA_PATH, messages, top_p, temperature, max_length): + print(output[current_length:], end="", flush=True) + current_length = len(output) + history[-1][1] = output + + +if __name__ == "__main__": + asyncio.run(chat()) \ No newline at end of file From d4a3b7ddbaf4643460ce0a25dc2e846cf6d804f1 Mon Sep 17 00:00:00 2001 From: sixsixcoder Date: Wed, 4 Sep 2024 10:28:22 +0000 Subject: [PATCH 2/5] lora adapter with vllm --- basic_demo/vllm_cli_demo.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/basic_demo/vllm_cli_demo.py b/basic_demo/vllm_cli_demo.py index b5cc0a3..24da1d4 100644 --- a/basic_demo/vllm_cli_demo.py +++ b/basic_demo/vllm_cli_demo.py @@ -14,14 +14,16 @@ import asyncio from transformers import AutoTokenizer from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine from typing import List, Dict +from vllm.lora.request import LoRARequest MODEL_PATH = 'THUDM/glm-4-9b-chat' +LORA_PATH = '' - -def load_model_and_tokenizer(model_dir: str): +def load_model_and_tokenizer(model_dir: str, enable_lora: bool): engine_args = AsyncEngineArgs( model=model_dir, tokenizer=model_dir, + enable_lora=enable_lora, tensor_parallel_size=1, dtype="bfloat16", trust_remote_code=True, @@ -42,11 +44,14 @@ def load_model_and_tokenizer(model_dir: str): engine = AsyncLLMEngine.from_engine_args(engine_args) return engine, tokenizer +enable_lora = False +if LORA_PATH: + enable_lora = True -engine, tokenizer = load_model_and_tokenizer(MODEL_PATH) +engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora) -async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int): +async def vllm_gen(lora_path: str, enable_lora: bool, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int): inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, @@ -70,8 +75,12 @@ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: fl "skip_special_tokens": True, } sampling_params = SamplingParams(**params_dict) - async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"): - yield output.outputs[0].text + if enable_lora: + async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)): + yield output.outputs[0].text + else: + async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"): + yield output.outputs[0].text async def chat(): @@ -100,7 +109,7 @@ async def chat(): print("\nGLM-4: ", end="") current_length = 0 output = "" - async for output in vllm_gen(messages, top_p, temperature, max_length): + async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length): print(output[current_length:], end="", flush=True) current_length = 
len(output) history[-1][1] = output From af2fc4558547dedfd0d49e6cc94229189a517aeb Mon Sep 17 00:00:00 2001 From: sixsixcoder Date: Wed, 4 Sep 2024 10:30:21 +0000 Subject: [PATCH 3/5] lora adapter with vllm --- basic_demo/vllm_cli_lora_demo.py | 116 ------------------------------- 1 file changed, 116 deletions(-) delete mode 100644 basic_demo/vllm_cli_lora_demo.py diff --git a/basic_demo/vllm_cli_lora_demo.py b/basic_demo/vllm_cli_lora_demo.py deleted file mode 100644 index 5f6834d..0000000 --- a/basic_demo/vllm_cli_lora_demo.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -This script creates a CLI demo that utilizes LoRA adapters with vLLM backend for the GLM-4-9b model, -allowing users to interact with the model through a command-line interface. - -Usage: -- Run the script to start the CLI demo. -- Interact with the model by typing questions and receiving responses. - -Note: The script includes a modification to handle markdown to plain text conversion, -ensuring that the CLI interface displays formatted text correctly. -""" -import time -import asyncio -from transformers import AutoTokenizer -from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine -from typing import List, Dict -from vllm.lora.request import LoRARequest - -MODEL_PATH = 'THUDM/GLM-4' -LORA_PATH = '' # 你的 lora adapter 路径 - -def load_model_and_tokenizer(model_dir: str): - engine_args = AsyncEngineArgs( - model=model_dir, - tokenizer=model_dir, - enable_lora=True, # 新增 - max_loras=1, # 新增 - max_lora_rank=8, ## 新增 - max_num_seqs=256, ## 新增 - tensor_parallel_size=2, - dtype="bfloat16", - trust_remote_code=True, - gpu_memory_utilization=0.5, - max_model_len=2048, - enforce_eager=True, - worker_use_ray=True, - engine_use_ray=False, - disable_log_requests=True - # 如果遇见 OOM 现象,建议开启下述参数 - # enable_chunked_prefill=True, - # max_num_batched_tokens=8192 - ) - tokenizer = AutoTokenizer.from_pretrained( - model_dir, - trust_remote_code=True, - encode_special_tokens=True - ) - engine = AsyncLLMEngine.from_engine_args(engine_args) - return engine, tokenizer - - -engine, tokenizer = load_model_and_tokenizer(MODEL_PATH) - - -async def vllm_gen(lora_path: str, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int): - inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=False - ) - params_dict = { - "n": 1, - "best_of": 1, - "presence_penalty": 1.0, - "frequency_penalty": 0.0, - "temperature": temperature, - "top_p": top_p, - "top_k": -1, - "use_beam_search": False, - "length_penalty": 1, - "early_stopping": False, - "ignore_eos": False, - "max_tokens": max_dec_len, - "logprobs": None, - "prompt_logprobs": None, - "skip_special_tokens": True - } - sampling_params = SamplingParams(**params_dict) - async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)): - yield output.outputs[0].text - - -async def chat(): - history = [] - max_length = 8192 - top_p = 0.8 - temperature = 0 - - print("Welcome to the GLM-4-9B CLI (Lora) chat. 
Type your messages below.") - while True: - user_input = input("\nYou: ") - if user_input.lower() in ["exit", "quit"]: - break - history.append([user_input, ""]) - - messages = [] - for idx, (user_msg, model_msg) in enumerate(history): - if idx == len(history) - 1 and not model_msg: - messages.append({"role": "user", "content": user_msg}) - break - if user_msg: - messages.append({"role": "user", "content": user_msg}) - if model_msg: - messages.append({"role": "assistant", "content": model_msg}) - - print("\nGLM-4: ", end="") - current_length = 0 - output = "" - async for output in vllm_gen(LORA_PATH, messages, top_p, temperature, max_length): - print(output[current_length:], end="", flush=True) - current_length = len(output) - history[-1][1] = output - - -if __name__ == "__main__": - asyncio.run(chat()) \ No newline at end of file From 7422d118e881bfe15f915030439416bebc86fe48 Mon Sep 17 00:00:00 2001 From: sixgod Date: Wed, 4 Sep 2024 21:19:37 +0800 Subject: [PATCH 4/5] Update README.md --- basic_demo/README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/basic_demo/README.md b/basic_demo/README.md index 30cd311..6e3a7dd 100644 --- a/basic_demo/README.md +++ b/basic_demo/README.md @@ -112,6 +112,13 @@ python trans_batch_demo.py python vllm_cli_demo.py ``` ++ 在 GLM-4-9B-Chat 模型上使用带有 Lora adapter 的 vLLM +```python +# vllm_cli_demo.py +# 添加 LORA_PATH = '' +``` + + + 自行构建服务端,并使用 `OpenAI API` 的请求格式与 GLM-4-9B-Chat 模型进行对话。本 demo 支持 Function Call 和 All Tools功能。 启动服务端: @@ -126,12 +133,6 @@ python openai_api_server.py python openai_api_request.py ``` -### 在 GLM-4-9B-Chat 模型上使用带有 Lora adapter 的 vLLM - -```shell -python vllm_cli_lora_demo.py -``` - ## 压力测试 用户可以在自己的设备上使用本代码测试模型在 transformers后端的生成速度: From cb038cd2d3a7c705bc484895b6aba0b86dd51802 Mon Sep 17 00:00:00 2001 From: sixgod Date: Wed, 4 Sep 2024 21:20:45 +0800 Subject: [PATCH 5/5] Update README_en.md --- basic_demo/README_en.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/basic_demo/README_en.md b/basic_demo/README_en.md index 07086d1..412ee90 100644 --- a/basic_demo/README_en.md +++ b/basic_demo/README_en.md @@ -117,6 +117,13 @@ python trans_batch_demo.py python vllm_cli_demo.py ``` ++ use LoRA adapters with vLLM on GLM-4-9B-Chat model. + +```python +# vllm_cli_demo.py +# add LORA_PATH = '' +``` + + Build the server by yourself and use the request format of `OpenAI API` to communicate with the glm-4-9b model. This demo supports Function Call and All Tools functions. @@ -132,17 +139,10 @@ Client request: python openai_api_request.py ``` -### LoRA adapters with vLLM -+ use LoRA adapters with vLLM on GLM-4-9B-Chat model. - -```shell -python vllm_cli_lora_demo.py -``` - ## Stress test Users can use this code to test the generation speed of the model on the transformers backend on their own devices: ```shell python trans_stress_test.py -``` \ No newline at end of file +```
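
One practical note for anyone wiring these demos up: `LORA_PATH` must point at a PEFT-format adapter directory (the output of `save_pretrained()`, for example from a LoRA fine-tuning run of GLM-4-9B-Chat). The sketch below is a minimal sanity check of such a directory before setting `LORA_PATH` in `vllm_cli_demo.py`; the adapter path is a hypothetical placeholder, not something defined by these patches.

```python
from pathlib import Path

# Hypothetical placeholder -- replace with the directory produced by your own
# LoRA fine-tuning run (e.g. PEFT's save_pretrained output).
LORA_PATH = Path("/path/to/glm-4-9b-chat-lora")

# vLLM's LoRARequest loads adapters in PEFT format: adapter_config.json plus
# the adapter weights (adapter_model.safetensors or adapter_model.bin).
config_file = LORA_PATH / "adapter_config.json"
weight_files = [LORA_PATH / "adapter_model.safetensors", LORA_PATH / "adapter_model.bin"]

if not config_file.is_file():
    raise FileNotFoundError(f"{LORA_PATH} has no adapter_config.json; not a LoRA adapter directory")
if not any(w.is_file() for w in weight_files):
    raise FileNotFoundError(f"{LORA_PATH} has no adapter weight file")

print(f"{LORA_PATH} looks like a usable LoRA adapter; set LORA_PATH in vllm_cli_demo.py to it.")
```

If the check passes, the demo's `enable_lora` branch attaches the adapter through `LoRARequest` exactly as shown in the patch above.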