Add LoRA adapter support with vLLM

sixsixcoder 2024-09-04 10:28:22 +00:00
parent fafa33d351
commit d4a3b7ddba
1 changed file with 16 additions and 7 deletions


@@ -14,14 +14,16 @@ import asyncio
 from transformers import AutoTokenizer
 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from typing import List, Dict
+from vllm.lora.request import LoRARequest
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+LORA_PATH = ''
-def load_model_and_tokenizer(model_dir: str):
+def load_model_and_tokenizer(model_dir: str, enable_lora: bool):
     engine_args = AsyncEngineArgs(
         model=model_dir,
         tokenizer=model_dir,
+        enable_lora=enable_lora,
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
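Note on the engine-args change above: enable_lora only switches LoRA support on. If the adapter was trained with a rank larger than vLLM's default (16), the engine also needs max_lora_rank raised. The snippet below is a minimal sketch of that variant, not part of this commit; max_lora_rank=64 is an assumed example value.

from vllm import AsyncEngineArgs

MODEL_PATH = 'THUDM/glm-4-9b-chat'

# Sketch only: same engine args as in the diff, plus max_lora_rank for
# adapters trained with rank > 16 (64 here is an assumed example value).
engine_args = AsyncEngineArgs(
    model=MODEL_PATH,
    tokenizer=MODEL_PATH,
    enable_lora=True,
    max_lora_rank=64,
    tensor_parallel_size=1,
    dtype="bfloat16",
    trust_remote_code=True,
)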
@@ -42,11 +44,14 @@ def load_model_and_tokenizer(model_dir: str):
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     return engine, tokenizer
+enable_lora = False
+if LORA_PATH:
+    enable_lora = True
-engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
+engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora)
-async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
+async def vllm_gen(lora_path: str, enable_lora: bool, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
     inputs = tokenizer.apply_chat_template(
         messages,
         add_generation_prompt=True,
@@ -70,8 +75,12 @@ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: fl
         "skip_special_tokens": True,
     }
     sampling_params = SamplingParams(**params_dict)
-    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
-        yield output.outputs[0].text
+    if enable_lora:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)):
+            yield output.outputs[0].text
+    else:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
+            yield output.outputs[0].text
 async def chat():
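On the LoRARequest constructed above: the first argument is a human-readable adapter name, the second is an integer id that vLLM uses to distinguish and cache adapters (it must be a positive integer, unique per adapter), and lora_path points at the adapter directory. A small sketch with two hypothetical adapters; the names and paths are placeholders, not part of this commit:

from vllm.lora.request import LoRARequest

# Placeholders: neither adapter directory exists in this repo.
chat_lora = LoRARequest("glm-4-chat-lora", 1, lora_path="./adapters/chat")
tool_lora = LoRARequest("glm-4-tool-lora", 2, lora_path="./adapters/tool")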
@@ -100,7 +109,7 @@ async def chat():
         print("\nGLM-4: ", end="")
         current_length = 0
         output = ""
-        async for output in vllm_gen(messages, top_p, temperature, max_length):
+        async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length):
             print(output[current_length:], end="", flush=True)
             current_length = len(output)
             history[-1][1] = output
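With this change, running the patched demo against an adapter only requires setting LORA_PATH to the adapter directory; enable_lora is derived from it, and every generation call then carries the LoRARequest. For a quick check outside the async demo, the same pattern also works with vLLM's synchronous LLM API. A minimal sketch, assuming a hypothetical adapter at ./glm-4-lora (the path and prompt are placeholders):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH = 'THUDM/glm-4-9b-chat'
LORA_PATH = './glm-4-lora'  # placeholder adapter directory

# enable_lora must be set when the engine is built, as in the diff above.
llm = LLM(model=MODEL_PATH, enable_lora=True, trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, top_p=0.8, max_tokens=256)

# Pass a LoRARequest per call; omit it to generate with the base weights.
outputs = llm.generate(
    ["Hello, who are you?"],
    sampling_params,
    lora_request=LoRARequest("glm-4-lora", 1, lora_path=LORA_PATH),
)
print(outputs[0].outputs[0].text)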