diff --git a/README.md b/README.md
index 37a0cfb..c58da67 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ Read this in [English](README_en.md)
 
 ## 项目更新
 
+- 🔥🔥 **News**: ```2024/09/04```: 增加了在 GLM-4-9B-Chat 模型上使用带有 LoRA adapter 的 vLLM 演示代码
 - 🔥🔥 **News**: ```2024/08/15```: 我们开源具备长文本输出能力(单轮对话大模型输出可超过1万token) 的模型 [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b) 以及数据集 [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k),
diff --git a/README_en.md b/README_en.md
index aaff0f1..a45447b 100644
--- a/README_en.md
+++ b/README_en.md
@@ -9,6 +9,7 @@
 
 ## Update
 
+- 🔥🔥 **News**: ```2024/09/04```: Added demo code for using vLLM with a LoRA adapter on the GLM-4-9B-Chat model.
 - 🔥🔥 **News**: ```2024/08/15```: We have open-sourced a model with long-text output capability (single turn LLM output can exceed 10K tokens) [longwriter-glm4-9b](https://huggingface.co/THUDM/LongWriter-glm4-9b) and the dataset [LongWriter-6k](https://huggingface.co/datasets/THUDM/LongWriter-6k). You're welcome
diff --git a/basic_demo/README.md b/basic_demo/README.md
index 4ffca52..6e3a7dd 100644
--- a/basic_demo/README.md
+++ b/basic_demo/README.md
@@ -112,6 +112,13 @@ python trans_batch_demo.py
 python vllm_cli_demo.py
 ```
 
++ 在 GLM-4-9B-Chat 模型上使用带有 LoRA adapter 的 vLLM
+```python
+# vllm_cli_demo.py
+# 将 LORA_PATH 设置为 LoRA adapter 的本地路径(留空则不启用 LoRA)
+```
+
+
 + 自行构建服务端,并使用 `OpenAI API` 的请求格式与 GLM-4-9B-Chat 模型进行对话。本 demo 支持 Function Call 和 All Tools功能。
 
 启动服务端:
diff --git a/basic_demo/README_en.md b/basic_demo/README_en.md
index 570e446..412ee90 100644
--- a/basic_demo/README_en.md
+++ b/basic_demo/README_en.md
@@ -117,6 +117,13 @@ python trans_batch_demo.py
 python vllm_cli_demo.py
 ```
 
++ Use a LoRA adapter with vLLM on the GLM-4-9B-Chat model.
+
+```python
+# vllm_cli_demo.py
+# set LORA_PATH to the local path of your LoRA adapter (leave it empty to run without LoRA)
+```
+
 + Build the server by yourself and use the request format of `OpenAI API` to communicate with the glm-4-9b model. This demo supports Function Call and All Tools functions.
@@ -138,4 +145,4 @@ Users can use this code to test the generation speed of the model on the transfo
 
 ```shell
 python trans_stress_test.py
-```
\ No newline at end of file
+```
diff --git a/basic_demo/vllm_cli_demo.py b/basic_demo/vllm_cli_demo.py
index b5cc0a3..24da1d4 100644
--- a/basic_demo/vllm_cli_demo.py
+++ b/basic_demo/vllm_cli_demo.py
@@ -14,14 +14,16 @@ import asyncio
 from transformers import AutoTokenizer
 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from typing import List, Dict
+from vllm.lora.request import LoRARequest
 
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+LORA_PATH = ''
 
-
-def load_model_and_tokenizer(model_dir: str):
+def load_model_and_tokenizer(model_dir: str, enable_lora: bool):
     engine_args = AsyncEngineArgs(
         model=model_dir,
         tokenizer=model_dir,
+        enable_lora=enable_lora,
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
@@ -42,11 +44,14 @@ def load_model_and_tokenizer(model_dir: str):
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     return engine, tokenizer
 
+enable_lora = False
+if LORA_PATH:
+    enable_lora = True
 
-engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
+engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora)
 
 
-async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
+async def vllm_gen(lora_path: str, enable_lora: bool, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
     inputs = tokenizer.apply_chat_template(
         messages,
         add_generation_prompt=True,
@@ -70,8 +75,12 @@ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: fl
         "skip_special_tokens": True,
     }
     sampling_params = SamplingParams(**params_dict)
-    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
-        yield output.outputs[0].text
+    if enable_lora:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)):
+            yield output.outputs[0].text
+    else:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
+            yield output.outputs[0].text
 
 
 async def chat():
@@ -100,7 +109,7 @@ async def chat():
         print("\nGLM-4: ", end="")
         current_length = 0
         output = ""
-        async for output in vllm_gen(messages, top_p, temperature, max_length):
+        async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length):
             print(output[current_length:], end="", flush=True)
             current_length = len(output)
             history[-1][1] = output
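For reference, the pattern this patch introduces (enabling LoRA on the engine and attaching a `LoRARequest` to each generation call) can also be exercised with vLLM's synchronous `LLM` API. The sketch below is not part of the change: `LORA_PATH` is a placeholder for a locally saved adapter directory, and exact argument names may vary slightly across vLLM versions.

```python
# Standalone sketch (not part of the diff): same enable_lora / LoRARequest idea
# as the async demo above, shown with vLLM's synchronous LLM API for brevity.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH = 'THUDM/glm-4-9b-chat'
LORA_PATH = '/path/to/your/lora/adapter'  # placeholder: local LoRA adapter directory

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
# enable_lora must be set when the engine is built, just as in AsyncEngineArgs.
llm = LLM(model=MODEL_PATH, trust_remote_code=True, enable_lora=bool(LORA_PATH))

# Build a chat-formatted prompt string with the model's chat template.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "你好"}],
    add_generation_prompt=True,
    tokenize=False,
)

sampling_params = SamplingParams(temperature=0.95, top_p=0.8, max_tokens=256)

# Attach the adapter per request; omit lora_request to generate with base weights.
outputs = llm.generate(
    [prompt],
    sampling_params,
    lora_request=LoRARequest("glm-4-lora", 1, LORA_PATH) if LORA_PATH else None,
)
print(outputs[0].outputs[0].text)
```

The modified `vllm_cli_demo.py` does the same thing per streamed request through `engine.generate(..., lora_request=...)` on the async engine.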