Merge pull request #585 from sixsixcoder/main
Add GLM-4v-9B model support for vllm framework
commit 5142bdb6e1

README.md (35 lines changed)
@@ -11,6 +11,7 @@ Read this in [English](README_en.md)

 ## Project Updates

+- 🔥 **News**: ```2024/10/12```: Added support for the GLM-4v-9B model on the vLLM framework.
 - 🔥 **News**: ```2024/09/06```: Added an OpenAI-API-compatible server built on the GLM-4v-9B model.
 - 🔥 **News**: ```2024/09/05```: We open-sourced a model that enables LLMs to generate fine-grained citations in long-context Q&A: [longcite-glm4-9b](https://huggingface.co/THUDM/LongCite-glm4-9b),
   along with the dataset [LongCite-45k](https://huggingface.co/datasets/THUDM/LongCite-45k),
@@ -252,7 +253,39 @@ with torch.no_grad():
     print(tokenizer.decode(outputs[0]))
 ```

-Note: GLM-4V-9B does not yet support inference via vLLM.
+Use the vLLM backend for inference:
+
+```python
+from PIL import Image
+from vllm import LLM, SamplingParams
+
+model_name = "THUDM/glm-4v-9b"
+
+llm = LLM(model=model_name,
+          tensor_parallel_size=1,
+          max_model_len=8192,
+          trust_remote_code=True,
+          enforce_eager=True)
+stop_token_ids = [151329, 151336, 151338]
+sampling_params = SamplingParams(temperature=0.2,
+                                 max_tokens=1024,
+                                 stop_token_ids=stop_token_ids)
+
+prompt = "What's the content of the image?"
+image = Image.open("your image").convert('RGB')
+inputs = {
+    "prompt": prompt,
+    "multi_modal_data": {
+        "image": image
+    },
+}
+outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+for o in outputs:
+    generated_text = o.outputs[0].text
+    print(generated_text)
+```

 ## Complete project list
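The vLLM example above hard-codes `stop_token_ids = [151329, 151336, 151338]`. As a hedged aside, these IDs can also be resolved from the tokenizer rather than hard-coded; the token names below are assumptions about GLM-4's chat stop tokens and should be checked against the model's tokenizer:

```python
# Minimal sketch: look up GLM-4 stop token IDs from the tokenizer instead of
# hard-coding them. The token names are assumptions; verify them against the
# tokenizer shipped with THUDM/glm-4v-9b.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
stop_token_ids = [tokenizer.convert_tokens_to_ids(t)
                  for t in ("<|endoftext|>", "<|user|>", "<|observation|>")]
print(stop_token_ids)  # expected to print [151329, 151336, 151338]
```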
README_en.md (35 lines changed)
@@ -9,6 +9,7 @@

 ## Update

+- 🔥 **News**: ```2024/10/12```: Added GLM-4v-9B model support for the vLLM framework.
 - 🔥 **News**: ```2024/09/06```: Added support for an OpenAI-API-compatible server on the GLM-4v-9B model.
 - 🔥 **News**: ```2024/09/05```: We open-sourced a model enabling LLMs to generate fine-grained citations in
   long-context Q&A: [longcite-glm4-9b](https://huggingface.co/THUDM/LongCite-glm4-9b), along with the
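The 2024/09/06 entry above mentions an OpenAI-API-compatible server on GLM-4v-9B. A minimal client sketch, assuming the server is reachable at `http://localhost:8000/v1` and serves the model under the name `glm-4v` (both the endpoint and the model name are assumptions; adjust them to the actual deployment):

```python
# Hypothetical client for an OpenAI-compatible GLM-4v-9B server.
# Assumptions: base_url http://localhost:8000/v1, model name "glm-4v",
# and "your_image.jpg" as a placeholder path.
import base64
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("your_image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

response = client.chat.completions.create(
    model="glm-4v",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's the content of the image?"},
            {"type": "image_url",
             "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
        ],
    }],
    temperature=0.2,
    max_tokens=1024,
)
print(response.choices[0].message.content)
```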
@@ -269,7 +270,39 @@ with torch.no_grad():
     print(tokenizer.decode(outputs[0]))
 ```

-Note: GLM-4V-9B does not yet support inference via vLLM.
+Use the vLLM backend for inference:
+
+```python
+from PIL import Image
+from vllm import LLM, SamplingParams
+
+model_name = "THUDM/glm-4v-9b"
+
+llm = LLM(model=model_name,
+          tensor_parallel_size=1,
+          max_model_len=8192,
+          trust_remote_code=True,
+          enforce_eager=True)
+stop_token_ids = [151329, 151336, 151338]
+sampling_params = SamplingParams(temperature=0.2,
+                                 max_tokens=1024,
+                                 stop_token_ids=stop_token_ids)
+
+prompt = "What's the content of the image?"
+image = Image.open("your image").convert('RGB')
+inputs = {
+    "prompt": prompt,
+    "multi_modal_data": {
+        "image": image
+    },
+}
+outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+for o in outputs:
+    generated_text = o.outputs[0].text
+    print(generated_text)
+```

 ## Complete project list
New file (107 added lines):

@@ -0,0 +1,107 @@
"""
This script creates a CLI demo with a vLLM backend for the glm-4v-9b model,
allowing users to interact with the model through a command-line interface.

Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.

Note: previous turns are kept in `history`, but only the latest prompt
(together with the image) is passed to the engine for generation.
"""
import time
import asyncio
from PIL import Image
from typing import List, Dict
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine

MODEL_PATH = 'THUDM/glm-4v-9b'


def load_model_and_tokenizer(model_dir: str):
    engine_args = AsyncEngineArgs(
        model=model_dir,
        tensor_parallel_size=1,
        dtype="bfloat16",
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        worker_use_ray=True,
        disable_log_requests=True
        # If you run into OOM, consider enabling the parameters below:
        # enable_chunked_prefill=True,
        # max_num_batched_tokens=8192
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    return engine


engine = load_model_and_tokenizer(MODEL_PATH)


async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
    # Only the latest message (prompt plus image) is sent to the engine.
    inputs = messages[-1]
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "ignore_eos": False,
        "max_tokens": max_dec_len,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
        "stop_token_ids": [151329, 151336, 151338]
    }
    sampling_params = SamplingParams(**params_dict)

    # Stream the cumulative generated text as the engine produces it.
    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
        yield output.outputs[0].text


async def chat():
    history = []
    max_length = 8192
    top_p = 0.8
    temperature = 0.6
    image = None

    print("Welcome to the GLM-4v-9B CLI chat. Type your messages below.")
    image_path = input("Image Path:")
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception:
        print("Invalid image path. Continuing with text conversation.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])

        # Rebuild the message list; the current turn carries the image and is
        # the only entry consumed by vllm_gen.
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({
                    "prompt": user_msg,
                    "multi_modal_data": {
                        "image": image
                    },
                })
                break
            if user_msg:
                messages.append({"role": "user", "prompt": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "prompt": model_msg})

        print("\nGLM-4v: ", end="")
        current_length = 0
        output = ""
        async for output in vllm_gen(messages, top_p, temperature, max_length):
            print(output[current_length:], end="", flush=True)
            current_length = len(output)
        history[-1][1] = output


if __name__ == "__main__":
    asyncio.run(chat())
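As a usage sketch, the script's `vllm_gen` helper can also be driven outside the interactive `chat()` loop. This assumes the code below lives in the same module as the demo script above; the image path and question are placeholders:

```python
# Sketch: stream a single response through vllm_gen (same module as the demo).
async def one_shot(image_path: str, question: str) -> str:
    image = Image.open(image_path).convert("RGB")
    messages = [{"prompt": question, "multi_modal_data": {"image": image}}]
    final = ""
    async for text in vllm_gen(messages, top_p=0.8, temperature=0.6, max_dec_len=1024):
        final = text  # vllm_gen yields the full text generated so far
    return final


# print(asyncio.run(one_shot("your_image.jpg", "Describe the image.")))
```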