diff --git a/README.md b/README.md
index 12ee103..cb281e3 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ Read this in [English](README_en.md)
 
 ## 项目更新
 
+- 🔥 **News**: ```2024/10/12```: 增加了 GLM-4v-9B 模型对vllm框架的支持
 - 🔥 **News**: ```2024/09/06```: 增加了在 GLM-4v-9B 模型上构建OpenAI API兼容的服务端
 - 🔥 **News**: ```2024/09/05``` 我们开源了使LLMs能够在长上下文问答中生成细粒度引用的模型 [longcite-glm4-9b](https://huggingface.co/THUDM/LongCite-glm4-9b)
   以及数据集 [LongCite-45k](https://huggingface.co/datasets/THUDM/LongCite-45k),
@@ -252,7 +253,39 @@ with torch.no_grad():
     print(tokenizer.decode(outputs[0]))
 ```
 
-注意: GLM-4V-9B 暂不支持使用 vLLM 方式调用。
+使用 vLLM 后端进行推理:
+
+```python
+from PIL import Image
+from vllm import LLM, SamplingParams
+
+model_name = "THUDM/glm-4v-9b"
+
+llm = LLM(model=model_name,
+          tensor_parallel_size=1,
+          max_model_len=8192,
+          trust_remote_code=True,
+          enforce_eager=True)
+stop_token_ids = [151329, 151336, 151338]
+sampling_params = SamplingParams(temperature=0.2,
+                                 max_tokens=1024,
+                                 stop_token_ids=stop_token_ids)
+
+prompt = "What's the content of the image?"
+image = Image.open("your image").convert('RGB')
+inputs = {
+    "prompt": prompt,
+    "multi_modal_data": {
+        "image": image
+    },
+}
+outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+for o in outputs:
+    generated_text = o.outputs[0].text
+    print(generated_text)
+
+```
 
 ## 完整项目列表
 
diff --git a/README_en.md b/README_en.md
index 2d52d50..3a339b3 100644
--- a/README_en.md
+++ b/README_en.md
@@ -9,6 +9,7 @@
 
 ## Update
 
+- 🔥 **News**: ```2024/10/12```: Add support for the GLM-4v-9B model on the vLLM framework.
 - 🔥 **News**: ```2024/09/06```: Add support for OpenAI API server on the GLM-4v-9B model.
 - 🔥 **News**: ```2024/09/05```: We open-sourced a model enabling LLMs to generate fine-grained citations in
   long-context Q&A: [longcite-glm4-9b](https://huggingface.co/THUDM/LongCite-glm4-9b), along with the
@@ -269,7 +270,39 @@ with torch.no_grad():
     print(tokenizer.decode(outputs[0]))
 ```
 
-Note: GLM-4V-9B does not support calling using vLLM method yet.
+Use the vLLM backend for inference:
+
+```python
+from PIL import Image
+from vllm import LLM, SamplingParams
+
+model_name = "THUDM/glm-4v-9b"
+
+llm = LLM(model=model_name,
+          tensor_parallel_size=1,
+          max_model_len=8192,
+          trust_remote_code=True,
+          enforce_eager=True)
+stop_token_ids = [151329, 151336, 151338]
+sampling_params = SamplingParams(temperature=0.2,
+                                 max_tokens=1024,
+                                 stop_token_ids=stop_token_ids)
+
+prompt = "What's the content of the image?"
+image = Image.open("your image").convert('RGB')
+inputs = {
+    "prompt": prompt,
+    "multi_modal_data": {
+        "image": image
+    },
+}
+outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+for o in outputs:
+    generated_text = o.outputs[0].text
+    print(generated_text)
+
+```
 
 ## Complete project list
 
diff --git a/basic_demo/vllm_cli_vision_demo.py b/basic_demo/vllm_cli_vision_demo.py
new file mode 100644
index 0000000..1ed67d5
--- /dev/null
+++ b/basic_demo/vllm_cli_vision_demo.py
@@ -0,0 +1,107 @@
+"""
+This script creates a CLI demo with a vLLM backend for the glm-4v-9b model,
+allowing users to interact with the model through a command-line interface.
+
+Usage:
+- Run the script to start the CLI demo.
+- Interact with the model by typing questions and receiving responses.
+
+Note: The demo streams the model's output to the terminal as it is generated,
+so responses appear incrementally during the conversation.
+""" +import time +import asyncio +from PIL import Image +from typing import List, Dict +from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine + +MODEL_PATH = 'THUDM/glm-4v-9b' + +def load_model_and_tokenizer(model_dir: str): + engine_args = AsyncEngineArgs( + model=model_dir, + tensor_parallel_size=1, + dtype="bfloat16", + trust_remote_code=True, + gpu_memory_utilization=0.9, + enforce_eager=True, + worker_use_ray=True, + disable_log_requests=True + # 如果遇见 OOM 现象,建议开启下述参数 + # enable_chunked_prefill=True, + # max_num_batched_tokens=8192 + ) + engine = AsyncLLMEngine.from_engine_args(engine_args) + return engine + +engine = load_model_and_tokenizer(MODEL_PATH) + +async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int): + inputs = messages[-1] + params_dict = { + "n": 1, + "best_of": 1, + "presence_penalty": 1.0, + "frequency_penalty": 0.0, + "temperature": temperature, + "top_p": top_p, + "top_k": -1, + "use_beam_search": False, + "length_penalty": 1, + "early_stopping": False, + "ignore_eos": False, + "max_tokens": max_dec_len, + "logprobs": None, + "prompt_logprobs": None, + "skip_special_tokens": True, + "stop_token_ids" :[151329, 151336, 151338] + } + sampling_params = SamplingParams(**params_dict) + + async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"): + yield output.outputs[0].text + + +async def chat(): + history = [] + max_length = 8192 + top_p = 0.8 + temperature = 0.6 + image = None + + print("Welcome to the GLM-4v-9B CLI chat. Type your messages below.") + image_path = input("Image Path:") + try: + image = Image.open(image_path).convert("RGB") + except: + print("Invalid image path. Continuing with text conversation.") + while True: + user_input = input("\nYou: ") + if user_input.lower() in ["exit", "quit"]: + break + history.append([user_input, ""]) + + messages = [] + for idx, (user_msg, model_msg) in enumerate(history): + if idx == len(history) - 1 and not model_msg: + messages.append({ + "prompt": user_msg, + "multi_modal_data": { + "image": image + },}) + break + if user_msg: + messages.append({"role": "user", "prompt": user_msg}) + if model_msg: + messages.append({"role": "assistant", "prompt": model_msg}) + + print("\nGLM-4v: ", end="") + current_length = 0 + output = "" + async for output in vllm_gen(messages, top_p, temperature, max_length): + print(output[current_length:], end="", flush=True) + current_length = len(output) + history[-1][1] = output + +if __name__ == "__main__": + asyncio.run(chat())