composite demo add openai backend

yybear 2024-06-12 23:59:22 +08:00
parent adeeb0e8e0
commit 293ca3b834
6 changed files with 80 additions and 4 deletions

View File

@@ -16,8 +16,8 @@ from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b-chat'
import os
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
MAX_MODEL_LENGTH = 8192
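A minimal sketch of using the new override, assuming this hunk belongs to the `openai_api_server.py` referenced in the README below and that the server is launched as a plain script; the checkpoint path is hypothetical:

```python
import os
import subprocess

# Point the server at a local checkpoint; MODEL_PATH falls back to
# 'THUDM/glm-4-9b-chat' when the variable is not set.
env = dict(os.environ, MODEL_PATH="/path/to/glm-4-9b-chat")
subprocess.run(["python", "openai_api_server.py"], env=env, check=True)
```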

View File

@@ -82,6 +82,8 @@ pnpm install
The Chat model supports inference with [vLLM](https://github.com/vllm-project/vllm). To use it, install vLLM and set the environment variable `USE_VLLM=1`.
The Chat model also supports inference through the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). To use it, start `openai_api_server.py` under the `basic_demo` directory and set the environment variable `USE_API=1`. This decouples the inference server from the demo server.
If you need to customize the Jupyter kernel, you can specify it with `export IPYKERNEL=<kernel_name>`.
## Usage

View File

@@ -91,6 +91,8 @@ by `export *_MODEL_PATH=/path/to/model`. The models that can be specified includ
The Chat model supports inference using [vLLM](https://github.com/vllm-project/vllm). To use it, please install vLLM and
set the environment variable `USE_VLLM=1`.
The Chat model also supports inference through the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). To use it, please run `openai_api_server.py` in `basic_demo` and set the environment variable `USE_API=1`. This allows the inference server and the demo server to be deployed on separate machines.
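For example, once `openai_api_server.py` is up, any machine that can reach it over HTTP can drive it with the standard OpenAI SDK. A minimal sketch using the same defaults the new `APIClient` assumes (`EMPTY` API key, `http://127.0.0.1:8000/v1/`, model name `glm-4`); for a genuinely separate machine, point `base_url` at that host instead:

```python
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```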
If you need to customize the Jupyter kernel, you can specify it by `export IPYKERNEL=<kernel_name>`.
## Usage

View File

@@ -21,6 +21,7 @@ from tools.tool_registry import ALL_TOOLS
class ClientType(Enum):
    HF = auto()
    VLLM = auto()
    API = auto()
class Client(Protocol):
@@ -34,7 +35,7 @@ class Client(Protocol):
    ) -> Generator[tuple[str | dict, list[dict]]]: ...
def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
def process_input(history: list[dict], tools: list[dict], role_name_replace: dict = None) -> list[dict]:
    chat_history = []
    if len(tools) > 0:
        chat_history.append(
@@ -43,6 +44,8 @@ def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
    for conversation in history:
        role = str(conversation.role).removeprefix("<|").removesuffix("|>")
        if role_name_replace:
            role = role_name_replace.get(role, role)
        item = {
            "role": role,
            "content": conversation.content,
@@ -94,5 +97,8 @@ def get_client(model_path, typ: ClientType) -> Client:
                e.msg += "; did you forget to install vLLM?"
                raise
            return VLLMClient(model_path)
        case ClientType.API:
            from clients.openai import APIClient
            return APIClient(model_path)
    raise NotImplementedError(f"Client type {typ} is not supported.")
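For clarity, here is the new renaming step in isolation, using the mapping the API backend passes in (`{'observation': 'tool'}` in `clients/openai.py` below): GLM-4's `observation` turns become OpenAI-style `tool` messages, and other roles pass through unchanged.

```python
role_name_replace = {"observation": "tool"}

for raw in ("<|user|>", "<|assistant|>", "<|observation|>"):
    role = raw.removeprefix("<|").removesuffix("|>")   # same stripping as process_input
    role = role_name_replace.get(role, role)           # rename only when a mapping exists
    print(raw, "->", role)
# <|user|> -> user
# <|assistant|> -> assistant
# <|observation|> -> tool
```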

View File

@@ -0,0 +1,65 @@
"""
OpenAI API client.
"""
from collections.abc import Generator

from openai import OpenAI

from client import Client, process_input, process_response
from conversation import Conversation


def format_openai_tool(origin_tools):
    # Convert the demo's tool registry entries into the OpenAI function-calling schema.
    openai_tools = []
    for tool in origin_tools:
        openai_tool = {
            "type": "function",
            "function": {
                "name": tool['name'],
                "description": tool['description'],
                "parameters": {
                    "type": "object",
                    "properties": {
                        param['name']: {'type': param['type'], 'description': param['description']}
                        for param in tool['params']
                    },
                    "required": [param['name'] for param in tool['params'] if param['required']]
                }
            }
        }
        openai_tools.append(openai_tool)
    return openai_tools


class APIClient(Client):
    def __init__(self, model_path: str):
        base_url = "http://127.0.0.1:8000/v1/"
        self.client = OpenAI(api_key="EMPTY", base_url=base_url)
        self.use_stream = False
        # GLM-4 uses an "observation" role for tool results; the OpenAI schema calls it "tool".
        self.role_name_replace = {'observation': 'tool'}

    def generate_stream(
        self,
        tools: list[dict],
        history: list[Conversation],
        **parameters,
    ) -> Generator[tuple[str | dict, list[dict]]]:
        chat_history = process_input(history, tools)
        messages = process_input(history, [], role_name_replace=self.role_name_replace)
        openai_tools = format_openai_tool(tools)
        response = self.client.chat.completions.create(
            model="glm-4",
            messages=messages,
            tools=openai_tools,
            stream=self.use_stream,
            max_tokens=parameters["max_new_tokens"],
            temperature=parameters["temperature"],
            presence_penalty=1.2,
            top_p=parameters["top_p"],
            tool_choice="auto"
        )
        output = response.choices[0].message
        if output.tool_calls:
            # Re-serialize the tool call in GLM-4's "<name>\n<arguments>" convention.
            glm4_output = output.tool_calls[0].function.name + '\n' + output.tool_calls[0].function.arguments
        else:
            glm4_output = output.content
        yield process_response(glm4_output, chat_history)
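To make the conversion concrete, here is a minimal sketch of what `format_openai_tool` emits for one registry entry; the `get_weather` tool and its fields are hypothetical, shaped only by the keys the function reads above (`name`, `description`, `params`):

```python
from pprint import pprint

weather_tool = {
    "name": "get_weather",  # hypothetical tool, for illustration only
    "description": "Query the current weather for a city",
    "params": [
        {"name": "city", "type": "string", "description": "City name", "required": True},
    ],
}

pprint(format_openai_tool([weather_tool]))
# -> a single {"type": "function", "function": {...}} entry whose "parameters" object
#    has a "city" property (type/description copied over) and "required": ["city"]
```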

View File

@@ -32,7 +32,7 @@ CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
USE_API = os.environ.get("USE_API", "0") == "1"
class Mode(str, Enum):
    ALL_TOOLS = "🛠️ All Tools"
@@ -104,6 +104,7 @@ def build_client(mode: Mode) -> Client:
        case Mode.ALL_TOOLS:
            st.session_state.top_k = 10
            typ = ClientType.VLLM if USE_VLLM else ClientType.HF
            typ = ClientType.API if USE_API else typ
            return get_client(CHAT_MODEL_PATH, typ)
        case Mode.LONG_CTX:
            st.session_state.top_k = 10
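In other words, the backend choice for the All Tools mode follows a simple precedence: `USE_API=1` wins over `USE_VLLM=1`, and the Hugging Face backend is the fallback. A standalone sketch of that selection logic (plain strings stand in for `ClientType`, so it runs outside the demo):

```python
import os

# Same precedence as build_client above: API > vLLM > Hugging Face.
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
USE_API = os.environ.get("USE_API", "0") == "1"

backend = "vllm" if USE_VLLM else "hf"
backend = "api" if USE_API else backend
print(backend)  # "api" whenever USE_API=1, regardless of USE_VLLM
```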