From 293ca3b83421a3f548b56fcbddf7c1e421cb26df Mon Sep 17 00:00:00 2001
From: yybear <liujunjieqq33@126.com>
Date: Wed, 12 Jun 2024 23:59:22 +0800
Subject: [PATCH] composite_demo: add OpenAI API backend

---
 basic_demo/openai_api_server.py      |  4 +-
 composite_demo/README.md             |  2 +
 composite_demo/README_en.md          |  2 +
 composite_demo/src/client.py         |  8 +++-
 composite_demo/src/clients/openai.py | 65 ++++++++++++++++++++++++++++
 composite_demo/src/main.py           |  3 +-
 6 files changed, 80 insertions(+), 4 deletions(-)
 create mode 100644 composite_demo/src/clients/openai.py

diff --git a/basic_demo/openai_api_server.py b/basic_demo/openai_api_server.py
index e758621..64df78b 100644
--- a/basic_demo/openai_api_server.py
+++ b/basic_demo/openai_api_server.py
@@ -16,8 +16,8 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
-
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+import os
+
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
 MAX_MODEL_LENGTH = 8192
 
 
diff --git a/composite_demo/README.md b/composite_demo/README.md
index 20c6de7..6e7e853 100644
--- a/composite_demo/README.md
+++ b/composite_demo/README.md
@@ -82,6 +82,8 @@ pnpm install
 
 Chat 模型支持使用 [vLLM](https://github.com/vllm-project/vllm) 推理。若要使用,请安装 vLLM 并设置环境变量 `USE_VLLM=1`。
 
+Chat 模型支持使用 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 推理。若要使用,请先启动 `basic_demo` 目录下的 `openai_api_server.py`,再设置环境变量 `USE_API=1`。该方式可以将推理服务器与 demo 服务器解耦,部署在不同的机器上。
+
 如果需要自定义 Jupyter 内核,可以通过 `export IPYKERNEL=<kernel_name>` 来指定。
 
 ## 使用
diff --git a/composite_demo/README_en.md b/composite_demo/README_en.md
index 5826227..79aa6a1 100644
--- a/composite_demo/README_en.md
+++ b/composite_demo/README_en.md
@@ -91,6 +91,8 @@ by `export *_MODEL_PATH=/path/to/model`. The models that can be specified includ
 The Chat model supports reasoning using [vLLM](https://github.com/vllm-project/vllm). To use it, please install vLLM and
 set the environment variable `USE_VLLM=1`.
 
+The Chat model also supports inference through an [OpenAI-compatible API](https://platform.openai.com/docs/api-reference/introduction). To use it, run `openai_api_server.py` in the `basic_demo` directory and set the environment variable `USE_API=1`. This decouples the inference server from the demo server, so the two can run on different machines.
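+
+For example, a minimal two-terminal setup might look like the following (assuming the API server keeps its default address of `127.0.0.1:8000`, which is where the demo's API client expects it, and that the demo itself is launched as described in the Usage section below):
+
+```bash
+# Terminal 1: start the OpenAI-compatible API server
+cd basic_demo
+python openai_api_server.py
+
+# Terminal 2: start the demo with the API backend enabled
+cd composite_demo
+export USE_API=1
+streamlit run src/main.py
+```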
+
 If you need to customize the Jupyter kernel, you can specify it by `export IPYKERNEL=<kernel_name>`.
 
 ## Usage
diff --git a/composite_demo/src/client.py b/composite_demo/src/client.py
index bde3b5b..2ef2b45 100644
--- a/composite_demo/src/client.py
+++ b/composite_demo/src/client.py
@@ -21,6 +21,7 @@ from tools.tool_registry import ALL_TOOLS
 class ClientType(Enum):
     HF = auto()
     VLLM = auto()
+    API = auto()
 
 
 class Client(Protocol):
@@ -34,7 +35,7 @@ class Client(Protocol):
     ) -> Generator[tuple[str | dict, list[dict]]]: ...
 
 
-def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
+def process_input(history: list[dict], tools: list[dict], role_name_replace: dict | None = None) -> list[dict]:
     chat_history = []
     if len(tools) > 0:
         chat_history.append(
@@ -43,6 +44,8 @@ def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
 
     for conversation in history:
         role = str(conversation.role).removeprefix("<|").removesuffix("|>")
+        if role_name_replace:
+            role = role_name_replace.get(role, role)
         item = {
             "role": role,
             "content": conversation.content,
@@ -94,5 +97,8 @@ def get_client(model_path, typ: ClientType) -> Client:
                 e.msg += "; did you forget to install vLLM?"
                 raise
             return VLLMClient(model_path)
+        case ClientType.API:
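+            # Deferred import: the `openai` package is only required when the API backend is used.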
+            from clients.openai import APIClient
+            return APIClient(model_path)
 
     raise NotImplementedError(f"Client type {typ} is not supported.")
diff --git a/composite_demo/src/clients/openai.py b/composite_demo/src/clients/openai.py
new file mode 100644
index 0000000..82f4c12
--- /dev/null
+++ b/composite_demo/src/clients/openai.py
@@ -0,0 +1,65 @@
+"""
+OpenAI API client.
+"""
+from openai import OpenAI
+from collections.abc import Generator
+
+from client import Client, process_input, process_response
+from conversation import Conversation
+
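+# Example: a hypothetical tool entry (field names match the registry format read below)
+#   {"name": "get_weather", "description": "...",
+#    "params": [{"name": "city", "type": "string", "description": "...", "required": True}]}
+# is converted into the OpenAI function-calling schema
+#   {"type": "function",
+#    "function": {"name": "get_weather", "description": "...",
+#                 "parameters": {"type": "object",
+#                                "properties": {"city": {"type": "string", "description": "..."}},
+#                                "required": ["city"]}}}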
+def format_openai_tool(origin_tools):
+    """Convert the demo's tool registry entries into OpenAI function-calling tool schemas."""
+    openai_tools = []
+    for tool in origin_tools:
+        openai_tool = {
+            "type": "function",
+            "function": {
+                "name": tool['name'],
+                "description": tool['description'],
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        param['name']: {'type': param['type'], 'description': param['description']}
+                        for param in tool['params']
+                    },
+                    "required": [param['name'] for param in tool['params'] if param['required']],
+                },
+            },
+        }
+        openai_tools.append(openai_tool)
+    return openai_tools
+
+class APIClient(Client):
+    def __init__(self, model_path: str):
+        # Talks to the OpenAI-compatible server started from basic_demo/openai_api_server.py,
+        # which this client expects at 127.0.0.1:8000; model_path is kept for interface parity.
+        base_url = "http://127.0.0.1:8000/v1/"
+        self.client = OpenAI(api_key="EMPTY", base_url=base_url)
+        self.use_stream = False
+        # Map the demo's "observation" role onto the OpenAI "tool" role.
+        self.role_name_replace = {'observation': 'tool'}
+
+    def generate_stream(
+        self,
+        tools: list[dict],
+        history: list[Conversation],
+        **parameters,
+    ) -> Generator[tuple[str | dict, list[dict]]]:
+        # chat_history keeps the demo's native roles for process_response, while
+        # messages uses OpenAI role names and skips the tool system prompt, since
+        # the tools are passed separately via the `tools` argument below.
+        chat_history = process_input(history, tools)
+        messages = process_input(history, [], role_name_replace=self.role_name_replace)
+        openai_tools = format_openai_tool(tools)
+        # Non-streaming request; the complete reply is parsed and yielded once below.
+        response = self.client.chat.completions.create(
+            model="glm-4",  # must match the model id exposed by openai_api_server.py
+            messages=messages,
+            tools=openai_tools,
+            stream=self.use_stream,
+            max_tokens=parameters["max_new_tokens"],
+            temperature=parameters["temperature"],
+            presence_penalty=1.2,
+            top_p=parameters["top_p"],
+            tool_choice="auto"
+        )
+        output = response.choices[0].message
+        if output.tool_calls:
+            # Reassemble the first tool call in GLM-4's native text format
+            # (function name, newline, JSON arguments) so process_response can parse it.
+            glm4_output = output.tool_calls[0].function.name + '\n' + output.tool_calls[0].function.arguments
+        else:
+            glm4_output = output.content
+        yield process_response(glm4_output, chat_history)
diff --git a/composite_demo/src/main.py b/composite_demo/src/main.py
index 333eb80..0cb51b2 100644
--- a/composite_demo/src/main.py
+++ b/composite_demo/src/main.py
@@ -32,7 +32,7 @@ CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
 VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
 
 USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
-
+USE_API = os.environ.get("USE_API", "0") == "1"
+
 
 class Mode(str, Enum):
     ALL_TOOLS = "🛠️ All Tools"
@@ -104,6 +104,7 @@ def build_client(mode: Mode) -> Client:
         case Mode.ALL_TOOLS:
             st.session_state.top_k = 10
             typ = ClientType.VLLM if USE_VLLM else ClientType.HF
+            typ = ClientType.API if USE_API else typ
             return get_client(CHAT_MODEL_PATH, typ)
         case Mode.LONG_CTX:
             st.session_state.top_k = 10