openai demo update

#64 #36
This commit is contained in:
zR 2024-06-07 22:14:00 +08:00
parent 7fcaeba6cc
commit abe93e093d
2 changed files with 130 additions and 100 deletions

View File

@@ -15,7 +15,7 @@ def function_chat():
             "type": "function",
             "function": {
                 "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
+                "description": "Get the current weather",
                 "parameters": {
                     "type": "object",
                     "properties": {
@@ -23,15 +23,19 @@ def function_chat():
                             "type": "string",
                             "description": "The city and state, e.g. San Francisco, CA",
                         },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location.",
+                        },
                     },
-                    "required": ["location"],
+                    "required": ["location", "format"],
                 },
-            },
-        }
+            }
+        },
     ]
-    # All Tools capability: drawing
+    # # All Tools capability: drawing
     # messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
     # tools = [{"type": "cogview"}]
     #
@@ -43,33 +47,34 @@ def function_chat():
         model="glm-4",
         messages=messages,
         tools=tools,
+        stream=False,  # must use False
         tool_choice="auto",  # use "auto" to let the model choose the tool automatically
         # tool_choice={"type": "function", "function": {"name": "my_function"}},
     )
     if response:
-        content = response.choices[0].message.content
-        print(content)
+        print(response.choices[0].message)
     else:
         print("Error:", response.status_code)


 def simple_chat(use_stream=False):
     messages = [
         {
             "role": "system",
-            "content": "你是 GLM-4，请你热情回答用户的问题",
+            "content": "请在你输出的时候都带上“喵喵喵”三个字,放在开头",
         },
         {
             "role": "user",
-            "content": "你好,请你用生动的话语给我讲一个小故事吧"
+            "content": "你好,你是谁"
         }
     ]
     response = client.chat.completions.create(
         model="glm-4",
         messages=messages,
         stream=use_stream,
-        max_tokens=1024,
-        temperature=0.8,
+        max_tokens=256,
+        temperature=0.1,
         presence_penalty=1.1,
         top_p=0.8)
     if response:
@@ -84,5 +89,5 @@ def simple_chat(use_stream=False):

 if __name__ == "__main__":
-    simple_chat()
+    # simple_chat(use_stream=False)
     function_chat()
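
Editor's note, not part of the commit: because the server below now returns OpenAI-style tool calls, the demo prints the whole message object rather than message.content, which is None for a tool-call turn. With the get_current_weather tool above, the printout looks roughly like the following; every value is illustrative, and the id is generated from a timestamp:

ChatCompletionMessage(
    content=None,
    role='assistant',
    function_call=None,
    tool_calls=[ChatCompletionMessageToolCall(
        id='call_1717768440000',
        type='function',
        function=Function(name='get_current_weather',
                          arguments='{"location": "San Francisco, CA", "format": "celsius"}'))]
)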

View File

@@ -1,4 +1,3 @@
-import os
 import time
 from asyncio.log import logger
@@ -17,6 +16,7 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse

 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

 MODEL_PATH = 'THUDM/glm-4-9b-chat'
 MAX_MODEL_LENGTH = 8192
@@ -52,7 +52,12 @@ class ModelCard(BaseModel):

 class ModelList(BaseModel):
     object: str = "list"
-    data: List[ModelCard] = []
+    data: List[ModelCard] = ["glm-4"]
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str


 class FunctionCallResponse(BaseModel):
@@ -60,11 +65,23 @@ class FunctionCallResponse(BaseModel):
     arguments: Optional[str] = None


+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    function: FunctionCall
+    type: Literal["function"]
+
+
 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system", "tool"]
-    content: str = None
+    content: Optional[str] = None
+    name: Optional[str] = None
     function_call: Optional[FunctionCallResponse] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class DeltaMessage(BaseModel):
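
Editor's note, not part of the diff: the three new models compose into the standard OpenAI tool-call shape. A minimal sketch with made-up values:

# Illustration only, values are made up.
call = ChatCompletionMessageToolCall(
    id="call_0",
    type="function",
    function=FunctionCall(name="get_current_weather", arguments='{"location": "Beijing"}'),
)
msg = ChatMessage(role="assistant", content=None, tool_calls=[call])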
@@ -73,28 +90,25 @@ class DeltaMessage(BaseModel):
     function_call: Optional[FunctionCallResponse] = None


-class EmbeddingRequest(BaseModel):
-    input: Union[List[str], str]
-    model: str
-
-
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class EmbeddingResponse(BaseModel):
-    data: list
-    model: str
-    object: str
-    usage: CompletionUsage
-
-
-class UsageInfo(BaseModel):
-    prompt_tokens: int = 0
-    total_tokens: int = 0
-    completion_tokens: Optional[int] = 0
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length", "tool_calls"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
+    index: int
+
+
+class ChatCompletionResponse(BaseModel):
+    model: str
+    id: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    usage: Optional[UsageInfo] = None


 class ChatCompletionRequest(BaseModel):
@@ -109,27 +123,6 @@ class ChatCompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.1


-class ChatCompletionResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: Literal["stop", "length", "function_call"]
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length", "function_call"]]
-    index: int
-
-
-class ChatCompletionResponse(BaseModel):
-    model: str
-    id: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    usage: Optional[UsageInfo] = None
-
-
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(
             self, input_ids: torch.LongTensor, scores: torch.FloatTensor
@@ -141,28 +134,39 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):

 def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
-    content = ""
-    for response in output.split("<|assistant|>"):
-        if "\n" in response:
-            metadata, content = response.split("\n", maxsplit=1)
-        else:
-            metadata, content = "", response
-        if not metadata.strip():
-            content = content.strip()
-        else:
-            if use_tool:
-                parameters = eval(content.strip())
-                content = {
-                    "name": metadata.strip(),
-                    "arguments": json.dumps(parameters, ensure_ascii=False)
-                }
-            else:
-                content = {
-                    "name": metadata.strip(),
-                    "content": content
-                }
-    return content
+    lines = output.strip().split("\n")
+    if len(lines) == 2:
+        function_name = lines[0].strip()
+        arguments = lines[1].strip()
+        special_tools = ["cogview", "simple_browser"]
+
+        arguments_json = None
+        try:
+            arguments_json = json.loads(arguments)
+            is_tool_call = True
+        except json.JSONDecodeError:
+            is_tool_call = function_name in special_tools
+
+        if is_tool_call and use_tool:
+            content = {
+                "name": function_name,
+                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                        ensure_ascii=False)
+            }
+            if function_name in special_tools:
+                content["text"] = arguments
+            return content
+        elif is_tool_call:
+            content = {
+                "name": function_name,
+                "content": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                      ensure_ascii=False)
+            }
+            return content
+    return output.strip()


 @torch.inference_mode()
 async def generate_stream_glm4(params):
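
Editor's note, not part of the diff: the rewritten process_response expects the tool name on the first line of the model output and a JSON argument object on the second, and otherwise falls back to returning the plain text. A quick sketch of its behavior, assuming the function body shown above and the module's json import:

# Tool-call output: first line is the function name, second line is JSON arguments.
raw = 'get_current_weather\n{"location": "Beijing", "format": "celsius"}'
print(process_response(raw, use_tool=True))
# -> {'name': 'get_current_weather', 'arguments': '{"location": "Beijing", "format": "celsius"}'}

# Ordinary text does not look like a tool call and falls through to the stripped string.
print(process_response("Hello there", use_tool=True))
# -> Hello there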
@@ -240,7 +244,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
             )
             msg_has_sys = True
-    # add to metadata
     if isinstance(tool_choice, dict) and tools:
         messages.append(
             {
@ -278,6 +281,12 @@ def process_messages(messages, tools=None, tool_choice="none"):
continue continue
messages.append({"role": role, "content": content}) messages.append({"role": role, "content": content})
if not tools or tool_choice == "none":
for m in _messages:
if m.role == 'system':
messages.insert(0, {"role": m.role, "content": m.content})
break
return messages return messages
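
Editor's note, not part of the commit: the new branch restores the caller's system prompt when the request carries no tools. A minimal sketch of the intent, with a hypothetical input list built from the ChatMessage model above:

# Illustration only: without tools, the original system message is
# re-inserted at the front of the rebuilt message list.
request_messages = [
    ChatMessage(role="system", content="You are a helpful assistant."),
    ChatMessage(role="user", content="Hi"),
]
rebuilt = process_messages(request_messages, tools=None, tool_choice="none")
# rebuilt[0] -> {"role": "system", "content": "You are a helpful assistant."}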
@@ -348,22 +357,33 @@ async def create_chat_completion(request: ChatCompletionRequest):
         response["text"] = response["text"].strip()

     usage = UsageInfo()
     function_call, finish_reason = None, "stop"
+    tool_calls = None
     if request.tools:
         try:
             function_call = process_response(response["text"], use_tool=True)
-        except:
-            logger.warning(
-                "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
+        except Exception as e:
+            logger.warning(f"Failed to parse tool call: {e}")

     if isinstance(function_call, dict):
-        finish_reason = "function_call"
-        function_call = FunctionCallResponse(**function_call)
+        finish_reason = "tool_calls"
+        function_call_response = FunctionCallResponse(**function_call)
+        function_call_instance = FunctionCall(
+            name=function_call_response.name,
+            arguments=function_call_response.arguments
+        )
+        tool_calls = [
+            ChatCompletionMessageToolCall(
+                id=f"call_{int(time.time() * 1000)}",
+                function=function_call_instance,
+                type="function")]

     message = ChatMessage(
         role="assistant",
-        content=response["text"],
-        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+        content=None if tool_calls else response["text"],
+        function_call=None,
+        tool_calls=tool_calls,
     )

     logger.debug(f"==== message ====\n{message}")
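
Editor's note, not part of the commit: since finish_reason is now reported as "tool_calls" and the arguments travel in message.tool_calls, an OpenAI-SDK caller can dispatch the call as below; response is the object returned by the demo's function_chat request, and get_current_weather is a hypothetical local implementation:

import json

message = response.choices[0].message
if message.tool_calls:
    call = message.tool_calls[0]
    args = json.loads(call.function.arguments)    # e.g. {"location": "...", "format": "celsius"}
    if call.function.name == "get_current_weather":
        result = get_current_weather(**args)      # hypothetical local function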
@@ -402,11 +422,11 @@ async def predict(model_id: str, params: dict):
         previous_text = decoded_unicode
         finish_reason = new_response["finish_reason"]
-        if len(delta_text) == 0 and finish_reason != "function_call":
+        if len(delta_text) == 0 and finish_reason != "tool_calls":
             continue

         function_call = None
-        if finish_reason == "function_call":
+        if finish_reason == "tool_calls":
             try:
                 function_call = process_response(decoded_unicode, use_tool=True)
             except:
@@ -417,9 +437,15 @@ async def predict(model_id: str, params: dict):
                 function_call = FunctionCallResponse(**function_call)

         delta = DeltaMessage(
-            content=delta_text,
+            content=None,
             role="assistant",
-            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+            function_call=None,
+            tool_calls=[{
+                "id": f"call_{int(time.time() * 1000)}",
+                "index": 0,
+                "type": "function",
+                "function": function_call
+            }] if isinstance(function_call, FunctionCallResponse) else None,
         )

         choice_data = ChatCompletionResponseStreamChoice(
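
Editor's note, not part of the commit: with this change a streamed chunk that carries a tool call serializes roughly as the SSE payload below; every value is illustrative, not captured from a real run:

data: {"object": "chat.completion.chunk", "model": "glm-4",
       "choices": [{"index": 0, "finish_reason": "tool_calls",
                    "delta": {"role": "assistant", "content": null,
                              "tool_calls": [{"id": "call_1717768440000", "index": 0, "type": "function",
                                              "function": {"name": "get_current_weather",
                                                           "arguments": "{\"location\": \"Beijing\"}"}}]}}]}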
@@ -458,12 +484,11 @@ async def predict_stream(model_id, gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
+        lines = output.strip().split("\n")
+        if not is_function_call and len(lines) >= 2:
+            is_function_call = True
+
         if not is_function_call and len(output) > 7:
-            is_function_call = output and 'get_' in output
-            if is_function_call:
-                continue
-
             finish_reason = new_response["finish_reason"]
             if not has_send_first_chunk:
                 message = DeltaMessage(
@ -538,9 +563,9 @@ if __name__ == "__main__":
tensor_parallel_size=1, tensor_parallel_size=1,
dtype="bfloat16", dtype="bfloat16",
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.3,
enforce_eager=True, enforce_eager=True,
worker_use_ray=True, worker_use_ray=False,
engine_use_ray=False, engine_use_ray=False,
disable_log_requests=True, disable_log_requests=True,
max_model_len=MAX_MODEL_LENGTH, max_model_len=MAX_MODEL_LENGTH,
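
Editor's note, not part of the commit: these keyword arguments feed vLLM's AsyncEngineArgs; gpu_memory_utilization=0.3 caps the engine at roughly 30% of GPU memory, and worker_use_ray=False keeps the workers in-process. A minimal sketch of how such an engine is typically built, assuming the vLLM 0.4-era API this demo targets; the exact surrounding code in the repo may differ:

from vllm import AsyncEngineArgs, AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model=MODEL_PATH,
    tokenizer=MODEL_PATH,
    tensor_parallel_size=1,
    dtype="bfloat16",
    trust_remote_code=True,
    gpu_memory_utilization=0.3,   # fraction of GPU memory vLLM may claim
    enforce_eager=True,           # skip CUDA graph capture
    worker_use_ray=False,         # single-process workers
    engine_use_ray=False,
    disable_log_requests=True,
    max_model_len=MAX_MODEL_LENGTH,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)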