openai demo update

#64 #36
This commit is contained in:
zR 2024-06-07 22:14:00 +08:00
parent 7fcaeba6cc
commit abe93e093d
2 changed files with 130 additions and 100 deletions

View File

@@ -15,7 +15,7 @@ def function_chat():
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
@@ -23,15 +23,19 @@ def function_chat():
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location"],
"required": ["location", "format"],
},
},
}
}
},
]
# All Tools capability: drawing
# # All Tools capability: drawing
# messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
# tools = [{"type": "cogview"}]
#
@@ -43,33 +47,34 @@ def function_chat():
model="glm-4",
messages=messages,
tools=tools,
stream=False, # must be False
tool_choice="auto", # use "auto" to let the model choose the tool automatically
# tool_choice={"type": "function", "function": {"name": "my_function"}},
)
if response:
content = response.choices[0].message.content
print(content)
print(response.choices[0].message)
else:
print("Error:", response.status_code)
def simple_chat(use_stream=False):
messages = [
{
"role": "system",
"content": "你是 GLM-4请你热情回答用户的问题",
"content": "请在你输出的时候都带上“喵喵喵”三个字,放在开头",
},
{
"role": "user",
"content": "你好,请你用生动的话语给我讲一个小故事吧"
"content": "你好,你是谁"
}
]
response = client.chat.completions.create(
model="glm-4",
messages=messages,
stream=use_stream,
max_tokens=1024,
temperature=0.8,
max_tokens=256,
temperature=0.1,
presence_penalty=1.1,
top_p=0.8)
if response:
@@ -84,5 +89,5 @@ def simple_chat(use_stream=False):
if __name__ == "__main__":
simple_chat()
# simple_chat(use_stream=False)
function_chat()
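
For context, the demo assumes a client object built with the standard OpenAI SDK and pointed at the locally served endpoint; that setup sits above the diffed hunks, so the following is only a sketch with assumed host, port, and key values, not code from this file:

# Sketch only: the address and the dummy API key are assumptions, not part of this diff.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:8000/v1",  # assumed address of the local glm-4 server
    api_key="EMPTY",                      # placeholder; the local demo server is assumed to ignore it
)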

View File

@@ -1,4 +1,3 @@
import os
import time
from asyncio.log import logger
@@ -17,6 +16,7 @@ from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
MODEL_PATH = 'THUDM/glm-4-9b-chat'
MAX_MODEL_LENGTH = 8192
@@ -52,7 +52,12 @@ class ModelCard(BaseModel):
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
data: List[ModelCard] = ["glm-4"]
class FunctionCall(BaseModel):
name: str
arguments: str
class FunctionCallResponse(BaseModel):
@@ -60,11 +65,23 @@ class FunctionCallResponse(BaseModel):
arguments: Optional[str] = None
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionMessageToolCall(BaseModel):
id: str
function: FunctionCall
type: Literal["function"]
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "tool"]
content: str = None
name: Optional[str] = None
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class DeltaMessage(BaseModel):
@@ -73,28 +90,25 @@ class DeltaMessage(BaseModel):
function_call: Optional[FunctionCallResponse] = None
class EmbeddingRequest(BaseModel):
input: Union[List[str], str]
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "tool_calls"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class EmbeddingResponse(BaseModel):
data: list
model: str
object: str
usage: CompletionUsage
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
class ChatCompletionRequest(BaseModel):
@@ -109,27 +123,6 @@ class ChatCompletionRequest(BaseModel):
repetition_penalty: Optional[float] = 1.1
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length", "function_call"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
@@ -141,27 +134,38 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
content = ""
for response in output.split("<|assistant|>"):
if "\n" in response:
metadata, content = response.split("\n", maxsplit=1)
else:
metadata, content = "", response
if not metadata.strip():
content = content.strip()
else:
if use_tool:
parameters = eval(content.strip())
content = {
"name": metadata.strip(),
"arguments": json.dumps(parameters, ensure_ascii=False)
}
else:
content = {
"name": metadata.strip(),
"content": content
}
return content
lines = output.strip().split("\n")
if len(lines) == 2:
function_name = lines[0].strip()
arguments = lines[1].strip()
special_tools = ["cogview", "simple_browser"]
arguments_json = None
try:
arguments_json = json.loads(arguments)
is_tool_call = True
except json.JSONDecodeError:
is_tool_call = function_name in special_tools
if is_tool_call and use_tool:
content = {
"name": function_name,
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
ensure_ascii=False)
}
if function_name in special_tools:
content["text"] = arguments
return content
elif is_tool_call:
content = {
"name": function_name,
"content": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
ensure_ascii=False)
}
return content
return output.strip()
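
The rewritten parser assumes a tool call arrives as exactly two lines, the function name followed by a JSON argument string, and falls back to returning the stripped raw text otherwise. A minimal standalone sketch of that contract (the sample output string is illustrative, not captured from the model):

import json

# Hypothetical raw model output for a tool call: line 1 is the function name,
# line 2 is a JSON argument string (values made up for illustration).
raw_output = 'get_current_weather\n{"location": "Beijing", "format": "celsius"}'

name, arguments = (part.strip() for part in raw_output.strip().split("\n", maxsplit=1))
parsed = json.loads(arguments)  # a plain-text reply would raise json.JSONDecodeError instead
print({"name": name, "arguments": json.dumps(parsed, ensure_ascii=False)})
# -> {'name': 'get_current_weather', 'arguments': '{"location": "Beijing", "format": "celsius"}'}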
@torch.inference_mode()
@@ -238,9 +242,8 @@ def process_messages(messages, tools=None, tool_choice="none"):
"tools": tools
}
)
msg_has_sys = True
msg_has_sys = True
# add to metadata
if isinstance(tool_choice, dict) and tools:
messages.append(
{
@@ -278,6 +281,12 @@ def process_messages(messages, tools=None, tool_choice="none"):
continue
messages.append({"role": role, "content": content})
if not tools or tool_choice == "none":
for m in _messages:
if m.role == 'system':
messages.insert(0, {"role": m.role, "content": m.content})
break
return messages
@@ -348,22 +357,33 @@ async def create_chat_completion(request: ChatCompletionRequest):
response["text"] = response["text"].strip()
usage = UsageInfo()
function_call, finish_reason = None, "stop"
tool_calls = None
if request.tools:
try:
function_call = process_response(response["text"], use_tool=True)
except:
logger.warning(
"Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
except Exception as e:
logger.warning(f"Failed to parse tool call: {e}")
if isinstance(function_call, dict):
finish_reason = "function_call"
function_call = FunctionCallResponse(**function_call)
finish_reason = "tool_calls"
function_call_response = FunctionCallResponse(**function_call)
function_call_instance = FunctionCall(
name=function_call_response.name,
arguments=function_call_response.arguments
)
tool_calls = [
ChatCompletionMessageToolCall(
id=f"call_{int(time.time() * 1000)}",
function=function_call_instance,
type="function")]
message = ChatMessage(
role="assistant",
content=response["text"],
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
content=None if tool_calls else response["text"],
function_call=None,
tool_calls=tool_calls,
)
logger.debug(f"==== message ====\n{message}")
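
The net effect of this hunk: a detected tool call is now surfaced through the OpenAI-style tool_calls field with finish_reason "tool_calls" and content set to None, instead of the legacy function_call field. An abridged, illustrative response body built from the models above (argument values are made up; id, created, and usage omitted):

import json

# Illustrative shape of a non-streaming tool-call completion; values are examples only.
example = {
    "model": "glm-4",
    "object": "chat.completion",
    "choices": [{
        "index": 0,
        "finish_reason": "tool_calls",
        "message": {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": "call_1717769640000",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": "{\"location\": \"Beijing\", \"format\": \"celsius\"}",
                },
            }],
        },
    }],
}
print(json.dumps(example, indent=2, ensure_ascii=False))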
@@ -402,11 +422,11 @@ async def predict(model_id: str, params: dict):
previous_text = decoded_unicode
finish_reason = new_response["finish_reason"]
if len(delta_text) == 0 and finish_reason != "function_call":
if len(delta_text) == 0 and finish_reason != "tool_calls":
continue
function_call = None
if finish_reason == "function_call":
if finish_reason == "tool_calls":
try:
function_call = process_response(decoded_unicode, use_tool=True)
except:
@@ -417,9 +437,15 @@ async def predict(model_id: str, params: dict):
function_call = FunctionCallResponse(**function_call)
delta = DeltaMessage(
content=delta_text,
content=None,
role="assistant",
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
function_call=None,
tool_calls=[{
"id": f"call_{int(time.time() * 1000)}",
"index": 0,
"type": "function",
"function": function_call
}] if isinstance(function_call, FunctionCallResponse) else None,
)
choice_data = ChatCompletionResponseStreamChoice(
@@ -454,16 +480,15 @@ async def predict_stream(model_id, gen_params):
output = ""
is_function_call = False
has_send_first_chunk = False
async for new_response in generate_stream_glm4(gen_params):
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
output = decoded_unicode
lines = output.strip().split("\n")
if not is_function_call and len(lines) >= 2:
is_function_call = True
if not is_function_call and len(output) > 7:
is_function_call = output and 'get_' in output
if is_function_call:
continue
finish_reason = new_response["finish_reason"]
if not has_send_first_chunk:
message = DeltaMessage(
@ -538,9 +563,9 @@ if __name__ == "__main__":
tensor_parallel_size=1,
dtype="bfloat16",
trust_remote_code=True,
gpu_memory_utilization=0.9,
gpu_memory_utilization=0.3,
enforce_eager=True,
worker_use_ray=True,
worker_use_ray=False,
engine_use_ray=False,
disable_log_requests=True,
max_model_len=MAX_MODEL_LENGTH,
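
A rough sizing note on the lowered gpu_memory_utilization (back-of-the-envelope only; the 80 GB card is an assumption, not something stated in the commit): vLLM pre-allocates this fraction of total GPU memory for weights plus KV cache, so 0.3 still has to cover the roughly 18 GB of bfloat16 weights of a 9B-parameter model.

# Illustrative arithmetic; the GPU size is an assumption, not from the diff.
params_billion = 9        # glm-4-9b-chat
bytes_per_param = 2       # bfloat16
weights_gb = params_billion * bytes_per_param     # ~18 GB of weights
budget_gb = 0.3 * 80                              # gpu_memory_utilization * assumed 80 GB GPU
print(f"weights ~= {weights_gb} GB, budget ~= {budget_gb:.0f} GB, "
      f"KV-cache headroom ~= {budget_gb - weights_gb:.0f} GB")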