Merge pull request #155 from qq332982511/add_api_backend
composite_demo: add OpenAI API backend
commit bab384d193
@@ -18,8 +18,8 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse

 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+import os
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
 MAX_MODEL_LENGTH = 8192
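The only functional change in this hunk is that the server's checkpoint path can now be overridden from the environment instead of being hard-coded. A minimal illustration of the lookup (the `/path/to/...` value is a placeholder, not from the repository):

```python
import os

# Default when MODEL_PATH is not exported: the hard-coded Hugging Face repo id.
print(os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat'))   # THUDM/glm-4-9b-chat

# Once MODEL_PATH is set (e.g. to a local checkpoint), it takes precedence.
os.environ['MODEL_PATH'] = '/path/to/local/glm-4-9b-chat'    # placeholder path
print(os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat'))   # /path/to/local/glm-4-9b-chat
```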
@@ -82,6 +82,8 @@ pnpm install

 The Chat model supports inference with [vLLM](https://github.com/vllm-project/vllm). To use it, install vLLM and set the environment variable `USE_VLLM=1`.

+The Chat model also supports inference via the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). To use it, start `openai_api_server.py` in the `basic_demo` directory and set the environment variable `USE_API=1`. This decouples the inference server from the demo server.
+
 If you need a custom Jupyter kernel, you can specify it with `export IPYKERNEL=<kernel_name>`.

 ## Usage
@@ -91,6 +91,8 @@ by `export *_MODEL_PATH=/path/to/model`. The models that can be specified includ

 The Chat model supports reasoning using [vLLM](https://github.com/vllm-project/vllm). To use it, please install vLLM and
 set the environment variable `USE_VLLM=1`.

+The Chat model also supports reasoning using the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction). To use it, please run `openai_api_server.py` in `basic_demo` and set the environment variable `USE_API=1`. This allows the inference server and the demo server to run on different machines.
+
 If you need to customize the Jupyter kernel, you can specify it by `export IPYKERNEL=<kernel_name>`.

 ## Usage
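For context, this is roughly what talking to the decoupled backend looks like once `openai_api_server.py` is running: a minimal sketch using the `openai` Python package, with the endpoint, dummy API key, and model name taken from the new `clients/openai.py` further down (adjust them if your server is configured differently):

```python
from openai import OpenAI

# Mirrors the new APIClient: dummy key plus the local server's base URL.
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")

response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello, who are you?"}],
    max_tokens=256,
    temperature=0.8,
)
print(response.choices[0].message.content)
```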
@@ -21,6 +21,7 @@ from tools.tool_registry import ALL_TOOLS
 class ClientType(Enum):
     HF = auto()
     VLLM = auto()
+    API = auto()


 class Client(Protocol):
@@ -34,7 +35,7 @@ class Client(Protocol):
     ) -> Generator[tuple[str | dict, list[dict]]]: ...


-def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
+def process_input(history: list[dict], tools: list[dict], role_name_replace: dict = None) -> list[dict]:
     chat_history = []
     if len(tools) > 0:
         chat_history.append(
@@ -43,6 +44,8 @@ def process_input(history: list[dict], tools: list[dict]) -> list[dict]:

     for conversation in history:
         role = str(conversation.role).removeprefix("<|").removesuffix("|>")
+        if role_name_replace:
+            role = role_name_replace.get(role, role)
         item = {
             "role": role,
             "content": conversation.content,
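The new `role_name_replace` parameter exists because GLM-4's history uses an `observation` role for tool results, while the OpenAI chat format expects `tool`. A small, self-contained illustration of the remapping performed above (the role list is hypothetical):

```python
role_name_replace = {"observation": "tool"}  # the mapping the new APIClient passes in

for role in ("system", "user", "assistant", "observation"):
    # Roles without an entry pass through unchanged; "observation" becomes "tool".
    print(role, "->", role_name_replace.get(role, role))
```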
@@ -94,5 +97,8 @@ def get_client(model_path, typ: ClientType) -> Client:
                 e.msg += "; did you forget to install vLLM?"
                 raise
             return VLLMClient(model_path)
+        case ClientType.API:
+            from clients.openai import APIClient
+            return APIClient(model_path)

     raise NotImplementedError(f"Client type {typ} is not supported.")
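With the extra `case`, the demo obtains an API-backed client through the same factory as before. A hedged usage sketch (it assumes you are running inside `composite_demo/src`, where `client.py` and the new `clients/openai.py` live; the model path is only forwarded to the constructor):

```python
from client import ClientType, get_client

# Returns clients.openai.APIClient; the OpenAI-compatible server must already be running.
client = get_client("THUDM/glm-4-9b-chat", ClientType.API)
```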
@@ -0,0 +1,65 @@
+"""
+OpenAI API client.
+"""
+from openai import OpenAI
+from collections.abc import Generator
+
+from client import Client, process_input, process_response
+from conversation import Conversation
+
+def format_openai_tool(origin_tools):
+    openai_tools = []
+    for tool in origin_tools:
+        openai_param = {}
+        for param in tool['params']:
+            openai_param[param['name']] = {}
+        openai_tool = {
+            "type": "function",
+            "function": {
+                "name": tool['name'],
+                "description": tool['description'],
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        param['name']: {'type': param['type'], 'description': param['description']} for param in tool['params']
+                    },
+                    "required": [param['name'] for param in tool['params'] if param['required']]
+                }
+            }
+        }
+        openai_tools.append(openai_tool)
+    return openai_tools
+
+class APIClient(Client):
+    def __init__(self, model_path: str):
+        base_url = "http://127.0.0.1:8000/v1/"
+        self.client = OpenAI(api_key="EMPTY", base_url=base_url)
+        self.use_stream = False
+        self.role_name_replace = {'observation': 'tool'}
+
+    def generate_stream(
+        self,
+        tools: list[dict],
+        history: list[Conversation],
+        **parameters,
+    ) -> Generator[tuple[str | dict, list[dict]]]:
+        chat_history = process_input(history, tools)
+        messages = process_input(history, '', role_name_replace=self.role_name_replace)
+        openai_tools = format_openai_tool(tools)
+        response = self.client.chat.completions.create(
+            model="glm-4",
+            messages=messages,
+            tools=openai_tools,
+            stream=self.use_stream,
+            max_tokens=parameters["max_new_tokens"],
+            temperature=parameters["temperature"],
+            presence_penalty=1.2,
+            top_p=parameters["top_p"],
+            tool_choice="auto"
+        )
+        output = response.choices[0].message
+        if output.tool_calls:
+            glm4_output = output.tool_calls[0].function.name + '\n' + output.tool_calls[0].function.arguments
+        else:
+            glm4_output = output.content
+        yield process_response(glm4_output, chat_history)
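To make the tool-schema conversion concrete, here is a hypothetical registry entry pushed through `format_openai_tool`. The field names `name`, `description`, and `params` are the ones the function reads; the weather tool itself is illustrative, not from the repository's registry, and the import assumes the script runs from `composite_demo/src`:

```python
from pprint import pprint

from clients.openai import format_openai_tool  # the module added above

weather_tool = {
    "name": "get_weather",
    "description": "Query the current weather for a city.",
    "params": [
        {"name": "city", "type": "string", "description": "City name.", "required": True},
    ],
}

pprint(format_openai_tool([weather_tool]))
# -> [{'type': 'function',
#      'function': {'name': 'get_weather',
#                   'description': 'Query the current weather for a city.',
#                   'parameters': {'type': 'object',
#                                  'properties': {'city': {'type': 'string',
#                                                          'description': 'City name.'}},
#                                  'required': ['city']}}}]
```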
@@ -32,7 +32,7 @@ CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
 VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")

 USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
+USE_API = os.environ.get("USE_API", "0") == "1"

 class Mode(str, Enum):
     ALL_TOOLS = "🛠️ All Tools"
@@ -104,6 +104,7 @@ def build_client(mode: Mode) -> Client:
         case Mode.ALL_TOOLS:
             st.session_state.top_k = 10
             typ = ClientType.VLLM if USE_VLLM else ClientType.HF
+            typ = ClientType.API if USE_API else typ
             return get_client(CHAT_MODEL_PATH, typ)
         case Mode.LONG_CTX:
             st.session_state.top_k = 10
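Taken together with the `USE_VLLM` line above it, the added line gives `USE_API` the final say. A standalone sketch of the resulting precedence (the enum mirrors the one extended in `client.py` earlier in this diff):

```python
from enum import Enum, auto


class ClientType(Enum):  # mirrors client.py
    HF = auto()
    VLLM = auto()
    API = auto()


def pick_backend(use_vllm: bool, use_api: bool) -> ClientType:
    # Same two-step selection as build_client: vLLM over HF, then API overrides both.
    typ = ClientType.VLLM if use_vllm else ClientType.HF
    return ClientType.API if use_api else typ


assert pick_backend(False, False) is ClientType.HF
assert pick_backend(True, False) is ClientType.VLLM
assert pick_backend(False, True) is ClientType.API
assert pick_backend(True, True) is ClientType.API
```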