Merge pull request #585 from sixsixcoder/main

Add GLM-4v-9B model support for vllm framework
This commit is contained in:
Yuxuan.Zhang 2024-10-15 14:55:35 +08:00 committed by GitHub
commit 5142bdb6e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 175 additions and 2 deletions

View File

@ -11,6 +11,7 @@ Read this in [English](README_en.md)
## 项目更新
- 🔥 **News**: ```2024/10/12```: 增加了 GLM-4v-9B 模型对vllm框架的支持
- 🔥 **News**: ```2024/09/06```: 增加了在 GLM-4v-9B 模型上构建OpenAI API兼容的服务端
- 🔥 **News**: ```2024/09/05``` 我们开源了使LLMs能够在长上下文问答中生成细粒度引用的模型 [longcite-glm4-9b](https://huggingface.co/THUDM/LongCite-glm4-9b)
以及数据集 [LongCite-45k](https://huggingface.co/datasets/THUDM/LongCite-45k),
@ -252,7 +253,39 @@ with torch.no_grad():
print(tokenizer.decode(outputs[0]))
```
注意: GLM-4V-9B 暂不支持使用 vLLM 方式调用。
使用 vLLM 后端进行推理:
```python
from PIL import Image
from vllm import LLM, SamplingParams
model_name = "THUDM/glm-4v-9b"
llm = LLM(model=model_name,
tensor_parallel_size=1,
max_model_len=8192,
trust_remote_code=True,
enforce_eager=True)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0.2,
max_tokens=1024,
stop_token_ids=stop_token_ids)
prompt = "What's the content of the image?"
image = Image.open("your image").convert('RGB')
inputs = {
"prompt": prompt,
"multi_modal_data": {
"image": image
},
}
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
```
## 完整项目列表

View File

@ -9,6 +9,7 @@
## Update
- 🔥 **News**: ```2024/10/12```: Add GLM-4v-9B model support for vllm framework.
- 🔥 **News**: ```2024/09/06```: Add support for OpenAI API server on the GLM-4v-9B model.
- 🔥 **News**: ```2024/09/05```: We open-sourced a model enabling LLMs to generate fine-grained citations in
long-context Q&A: [longcite-glm4-9b](https://huggingface.co/THUDM/LongCite-glm4-9b), along with the
@ -269,7 +270,39 @@ with torch.no_grad():
print(tokenizer.decode(outputs[0]))
```
Note: GLM-4V-9B does not support calling using vLLM method yet.
Use the vLLM backend for inference:
```python
from PIL import Image
from vllm import LLM, SamplingParams
model_name = "THUDM/glm-4v-9b"
llm = LLM(model=model_name,
tensor_parallel_size=1,
max_model_len=8192,
trust_remote_code=True,
enforce_eager=True)
stop_token_ids = [151329, 151336, 151338]
sampling_params = SamplingParams(temperature=0.2,
max_tokens=1024,
stop_token_ids=stop_token_ids)
prompt = "What's the content of the image?"
image = Image.open("your image").convert('RGB')
inputs = {
"prompt": prompt,
"multi_modal_data": {
"image": image
},
}
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
```
## Complete project list

View File

@ -0,0 +1,107 @@
"""
This script creates a CLI demo with vllm backand for the glm-4v-9b model,
allowing users to interact with the model through a command-line interface.
Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.
Note: The script includes a modification to handle markdown to plain text conversion,
ensuring that the CLI interface displays formatted text correctly.
"""
import time
import asyncio
from PIL import Image
from typing import List, Dict
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
MODEL_PATH = 'THUDM/glm-4v-9b'
def load_model_and_tokenizer(model_dir: str):
engine_args = AsyncEngineArgs(
model=model_dir,
tensor_parallel_size=1,
dtype="bfloat16",
trust_remote_code=True,
gpu_memory_utilization=0.9,
enforce_eager=True,
worker_use_ray=True,
disable_log_requests=True
# 如果遇见 OOM 现象,建议开启下述参数
# enable_chunked_prefill=True,
# max_num_batched_tokens=8192
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine
engine = load_model_and_tokenizer(MODEL_PATH)
async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
inputs = messages[-1]
params_dict = {
"n": 1,
"best_of": 1,
"presence_penalty": 1.0,
"frequency_penalty": 0.0,
"temperature": temperature,
"top_p": top_p,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"ignore_eos": False,
"max_tokens": max_dec_len,
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
"stop_token_ids" :[151329, 151336, 151338]
}
sampling_params = SamplingParams(**params_dict)
async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
yield output.outputs[0].text
async def chat():
history = []
max_length = 8192
top_p = 0.8
temperature = 0.6
image = None
print("Welcome to the GLM-4v-9B CLI chat. Type your messages below.")
image_path = input("Image Path:")
try:
image = Image.open(image_path).convert("RGB")
except:
print("Invalid image path. Continuing with text conversation.")
while True:
user_input = input("\nYou: ")
if user_input.lower() in ["exit", "quit"]:
break
history.append([user_input, ""])
messages = []
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({
"prompt": user_msg,
"multi_modal_data": {
"image": image
},})
break
if user_msg:
messages.append({"role": "user", "prompt": user_msg})
if model_msg:
messages.append({"role": "assistant", "prompt": model_msg})
print("\nGLM-4v: ", end="")
current_length = 0
output = ""
async for output in vllm_gen(messages, top_p, temperature, max_length):
print(output[current_length:], end="", flush=True)
current_length = len(output)
history[-1][1] = output
if __name__ == "__main__":
asyncio.run(chat())