From 293ca3b83421a3f548b56fcbddf7c1e421cb26df Mon Sep 17 00:00:00 2001
From: yybear
Date: Wed, 12 Jun 2024 23:59:22 +0800
Subject: [PATCH] composite_demo: add OpenAI API backend

---
 basic_demo/openai_api_server.py      |  4 +-
 composite_demo/README.md             |  2 +
 composite_demo/README_en.md          |  2 +
 composite_demo/src/client.py         |  8 +++-
 composite_demo/src/clients/openai.py | 65 ++++++++++++++++++++++++++++
 composite_demo/src/main.py           |  3 +-
 6 files changed, 80 insertions(+), 4 deletions(-)
 create mode 100644 composite_demo/src/clients/openai.py

diff --git a/basic_demo/openai_api_server.py b/basic_demo/openai_api_server.py
index e758621..64df78b 100644
--- a/basic_demo/openai_api_server.py
+++ b/basic_demo/openai_api_server.py
@@ -16,8 +16,8 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse

 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
-
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+import os
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
 MAX_MODEL_LENGTH = 8192

diff --git a/composite_demo/README.md b/composite_demo/README.md
index 20c6de7..6e7e853 100644
--- a/composite_demo/README.md
+++ b/composite_demo/README.md
@@ -82,6 +82,8 @@ pnpm install

 Chat 模型支持使用 [vLLM](https://github.com/vllm-project/vllm) 推理。若要使用，请安装 vLLM 并设置环境变量 `USE_VLLM=1`。

+Chat 模型支持使用 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 推理。若要使用，请先启动 `basic_demo` 目录下的 `openai_api_server.py`，并设置环境变量 `USE_API=1`。该功能可以将推理服务器与 demo 服务器解耦。
+
 如果需要自定义 Jupyter 内核，可以通过 `export IPYKERNEL=` 来指定。

 ## 使用

diff --git a/composite_demo/README_en.md b/composite_demo/README_en.md
index 5826227..79aa6a1 100644
--- a/composite_demo/README_en.md
+++ b/composite_demo/README_en.md
@@ -91,6 +91,8 @@ by `export *_MODEL_PATH=/path/to/model`. The models that can be specified includ

 The Chat model supports reasoning using [vLLM](https://github.com/vllm-project/vllm). To use it, please install vLLM and set the environment variable `USE_VLLM=1`.

+The Chat model also supports inference through the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). To use it, run `openai_api_server.py` in the `basic_demo` directory and set the environment variable `USE_API=1`. This decouples the inference server from the demo server, so the two can be deployed on different machines.
+
 If you need to customize the Jupyter kernel, you can specify it by `export IPYKERNEL=`.

 ## Usage

diff --git a/composite_demo/src/client.py b/composite_demo/src/client.py
index bde3b5b..2ef2b45 100644
--- a/composite_demo/src/client.py
+++ b/composite_demo/src/client.py
@@ -21,6 +21,7 @@ from tools.tool_registry import ALL_TOOLS
 class ClientType(Enum):
     HF = auto()
     VLLM = auto()
+    API = auto()


 class Client(Protocol):
@@ -34,7 +35,7 @@ class Client(Protocol):
     ) -> Generator[tuple[str | dict, list[dict]]]: ...


-def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
+def process_input(history: list[dict], tools: list[dict], role_name_replace: dict | None = None) -> list[dict]:
     chat_history = []
     if len(tools) > 0:
         chat_history.append(
@@ -43,6 +44,8 @@ def process_input(history: list[dict], tools: list[dict]) -> list[dict]:

     for conversation in history:
         role = str(conversation.role).removeprefix("<|").removesuffix("|>")
+        if role_name_replace:
+            role = role_name_replace.get(role, role)
         item = {
             "role": role,
             "content": conversation.content,
@@ -94,5 +97,8 @@ def get_client(model_path, typ: ClientType) -> Client:
                 e.msg += "; did you forget to install vLLM?"
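For reference, once `openai_api_server.py` is running, the endpoint the new backend talks to can be exercised directly with the `openai` Python package. Below is a minimal sketch; the `http://127.0.0.1:8000/v1/` base URL, the `EMPTY` placeholder key, and the `glm-4` model name mirror what the `APIClient` in the next file hard-codes, and the prompt text is invented for the example:

```python
from openai import OpenAI

# Same local endpoint and placeholder key that APIClient uses.
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")

response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```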
                raise
             return VLLMClient(model_path)
+        case ClientType.API:
+            from clients.openai import APIClient
+            return APIClient(model_path)

     raise NotImplementedError(f"Client type {typ} is not supported.")

diff --git a/composite_demo/src/clients/openai.py b/composite_demo/src/clients/openai.py
new file mode 100644
index 0000000..82f4c12
--- /dev/null
+++ b/composite_demo/src/clients/openai.py
@@ -0,0 +1,65 @@
+"""
+OpenAI API client.
+"""
+from collections.abc import Generator
+
+from openai import OpenAI
+
+from client import Client, process_input, process_response
+from conversation import Conversation
+
+
+def format_openai_tool(origin_tools):
+    # Convert the demo's tool registry entries into the OpenAI
+    # function-calling schema.
+    openai_tools = []
+    for tool in origin_tools:
+        openai_tool = {
+            "type": "function",
+            "function": {
+                "name": tool['name'],
+                "description": tool['description'],
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        param['name']: {'type': param['type'], 'description': param['description']}
+                        for param in tool['params']
+                    },
+                    "required": [param['name'] for param in tool['params'] if param['required']]
+                }
+            }
+        }
+        openai_tools.append(openai_tool)
+    return openai_tools
+
+
+class APIClient(Client):
+    def __init__(self, model_path: str):
+        base_url = "http://127.0.0.1:8000/v1/"
+        self.client = OpenAI(api_key="EMPTY", base_url=base_url)
+        self.use_stream = False
+        # GLM-4 uses the role name "observation" for tool results; the
+        # OpenAI API expects "tool".
+        self.role_name_replace = {'observation': 'tool'}
+
+    def generate_stream(
+        self,
+        tools: list[dict],
+        history: list[Conversation],
+        **parameters,
+    ) -> Generator[tuple[str | dict, list[dict]]]:
+        # GLM-4-style history, kept around for post-processing the response.
+        chat_history = process_input(history, tools)
+        # Tool definitions go through the API's `tools` parameter, so no
+        # tool system prompt is injected into the messages here.
+        messages = process_input(history, [], role_name_replace=self.role_name_replace)
+        openai_tools = format_openai_tool(tools)
+        response = self.client.chat.completions.create(
+            model="glm-4",
+            messages=messages,
+            tools=openai_tools,
+            stream=self.use_stream,
+            max_tokens=parameters["max_new_tokens"],
+            temperature=parameters["temperature"],
+            presence_penalty=1.2,
+            top_p=parameters["top_p"],
+            tool_choice="auto"
+        )
+        output = response.choices[0].message
+        if output.tool_calls:
+            # Re-encode the tool call as GLM-4's "<name>\n<arguments>" text format.
+            glm4_output = output.tool_calls[0].function.name + '\n' + output.tool_calls[0].function.arguments
+        else:
+            glm4_output = output.content
+        yield process_response(glm4_output, chat_history)

diff --git a/composite_demo/src/main.py b/composite_demo/src/main.py
index 333eb80..0cb51b2 100644
--- a/composite_demo/src/main.py
+++ b/composite_demo/src/main.py
@@ -32,7 +32,7 @@ CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
 VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")

 USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
-
+USE_API = os.environ.get("USE_API", "0") == "1"

 class Mode(str, Enum):
     ALL_TOOLS = "🛠️ All Tools"
@@ -104,6 +104,7 @@ def build_client(mode: Mode) -> Client:
         case Mode.ALL_TOOLS:
             st.session_state.top_k = 10
             typ = ClientType.VLLM if USE_VLLM else ClientType.HF
+            typ = ClientType.API if USE_API else typ
             return get_client(CHAT_MODEL_PATH, typ)
         case Mode.LONG_CTX:
             st.session_state.top_k = 10
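To illustrate the conversion `format_openai_tool` performs, here is a sketch that feeds it one hypothetical registry entry. The `get_weather` tool and all its fields are invented for the example; only the `name`/`description`/`params` layout comes from the function above, and the import assumes the script runs from `composite_demo/src`:

```python
from clients.openai import format_openai_tool

# A hypothetical tool entry in the demo registry's format.
demo_tool = {
    'name': 'get_weather',
    'description': 'Look up the current weather for a city',
    'params': [
        {'name': 'city', 'type': 'string', 'description': 'City name', 'required': True},
    ],
}

print(format_openai_tool([demo_tool]))
# [{'type': 'function',
#   'function': {'name': 'get_weather',
#                'description': 'Look up the current weather for a city',
#                'parameters': {'type': 'object',
#                               'properties': {'city': {'type': 'string',
#                                                       'description': 'City name'}},
#                               'required': ['city']}}}]
```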