LoRA adapter with vLLM
commit d4a3b7ddba
parent fafa33d351
@@ -14,14 +14,16 @@ import asyncio
 from transformers import AutoTokenizer
 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from typing import List, Dict
+from vllm.lora.request import LoRARequest
 
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+LORA_PATH = ''
 
-
-def load_model_and_tokenizer(model_dir: str):
+def load_model_and_tokenizer(model_dir: str, enable_lora: bool):
     engine_args = AsyncEngineArgs(
         model=model_dir,
         tokenizer=model_dir,
+        enable_lora=enable_lora,
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
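For quick reference, this is how the patched function reads once the hunk applies, assembled as a minimal sketch: only the AsyncEngineArgs fields visible in the diff are reproduced (the remaining engine arguments hidden by the diff context are omitted), and the AutoTokenizer.from_pretrained line is an assumption based on the import above rather than text from this commit.

from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine


def load_model_and_tokenizer(model_dir: str, enable_lora: bool):
    # enable_lora must be set on the engine itself before any LoRARequest
    # can be passed to generate() later in the file.
    engine_args = AsyncEngineArgs(
        model=model_dir,
        tokenizer=model_dir,
        enable_lora=enable_lora,
        tensor_parallel_size=1,
        dtype="bfloat16",
        trust_remote_code=True,
        # ...remaining engine arguments elided by the diff context...
    )
    # Assumption: the tokenizer is loaded via transformers, matching the
    # AutoTokenizer import; the commit does not show this line.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    return engine, tokenizer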
@@ -42,11 +44,14 @@ def load_model_and_tokenizer(model_dir: str):
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     return engine, tokenizer
 
+enable_lora = False
+if LORA_PATH:
+    enable_lora = True
 
-engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
+engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora)
 
 
-async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
+async def vllm_gen(lora_path: str, enable_lora: bool, messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
     inputs = tokenizer.apply_chat_template(
         messages,
         add_generation_prompt=True,
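The three-line flag added above can equivalently be computed in a single expression; a small alternative sketch, assuming the same module-level names as in the patched file:

# Assumes MODEL_PATH, LORA_PATH and load_model_and_tokenizer from the patched file.
# Equivalent to the if-block in the hunk: an empty LORA_PATH keeps LoRA disabled.
enable_lora = bool(LORA_PATH)
engine, tokenizer = load_model_and_tokenizer(MODEL_PATH, enable_lora)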
@@ -70,8 +75,12 @@ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: fl
         "skip_special_tokens": True,
     }
     sampling_params = SamplingParams(**params_dict)
-    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
-        yield output.outputs[0].text
+    if enable_lora:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}", lora_request=LoRARequest("glm-4-lora", 1, lora_path=lora_path)):
+            yield output.outputs[0].text
+    else:
+        async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
+            yield output.outputs[0].text
 
 
 async def chat():
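The two async-for branches added here differ only in the lora_request argument, which is optional (defaults to None) in vLLM's generate(). A hypothetical helper, not part of the commit, that would let a single loop cover both cases:

from typing import Optional

from vllm.lora.request import LoRARequest


def maybe_lora_request(enable_lora: bool, lora_path: str) -> Optional[LoRARequest]:
    """Return the LoRARequest used in the hunk above, or None when LoRA is off."""
    if not enable_lora:
        return None
    # "glm-4-lora" is the adapter name, 1 is its integer id inside the engine,
    # and lora_path points at the adapter weights on disk (mirrors the diff).
    return LoRARequest("glm-4-lora", 1, lora_path=lora_path)

With that, a single engine.generate(..., lora_request=maybe_lora_request(enable_lora, lora_path)) loop would cover both cases; the commit keeps the explicit if/else instead.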
@@ -100,7 +109,7 @@ async def chat():
         print("\nGLM-4: ", end="")
         current_length = 0
         output = ""
-        async for output in vllm_gen(messages, top_p, temperature, max_length):
+        async for output in vllm_gen(LORA_PATH, enable_lora, messages, top_p, temperature, max_length):
             print(output[current_length:], end="", flush=True)
             current_length = len(output)
         history[-1][1] = output
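To exercise the LoRA path end to end, the only remaining edit is pointing LORA_PATH at an adapter directory; the value below is purely illustrative and not from the commit:

MODEL_PATH = 'THUDM/glm-4-9b-chat'
# Hypothetical example: any directory holding a LoRA adapter trained on the base
# model (e.g. a PEFT output folder with adapter_config.json and adapter weights).
# Leaving LORA_PATH as an empty string keeps enable_lora False and runs the base model only.
LORA_PATH = '/path/to/glm-4-9b-chat-lora'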