parent 7fcaeba6cc
commit abe93e093d
@@ -15,7 +15,7 @@ def function_chat():
             "type": "function",
             "function": {
                 "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
+                "description": "Get the current weather",
                 "parameters": {
                     "type": "object",
                     "properties": {
@@ -23,15 +23,19 @@ def function_chat():
                             "type": "string",
                             "description": "The city and state, e.g. San Francisco, CA",
                         },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location.",
+                        },
                     },
-                    "required": ["location"],
+                    "required": ["location", "format"],
                 },
             },
         }
     ]
 
-    # All Tools 能力: 绘图
+    # # All Tools 能力: 绘图
     # messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
     # tools = [{"type": "cogview"}]
     #
 
@@ -43,33 +47,34 @@ def function_chat():
         model="glm-4",
         messages=messages,
         tools=tools,
         stream=False,  # must use False
         tool_choice="auto",  # use "auto" to let the model choose the tool automatically
         # tool_choice={"type": "function", "function": {"name": "my_function"}},
     )
     if response:
-        content = response.choices[0].message.content
-        print(content)
+        print(response.choices[0].message)
     else:
         print("Error:", response.status_code)
 
 
 def simple_chat(use_stream=False):
     messages = [
         {
             "role": "system",
-            "content": "你是 GLM-4,请你热情回答用户的问题。",
+            "content": "请在你输出的时候都带上“喵喵喵”三个字,放在开头。",
         },
         {
             "role": "user",
-            "content": "你好,请你用生动的话语给我讲一个小故事吧"
+            "content": "你好,你是谁"
         }
     ]
     response = client.chat.completions.create(
         model="glm-4",
         messages=messages,
         stream=use_stream,
-        max_tokens=1024,
-        temperature=0.8,
+        max_tokens=256,
+        temperature=0.1,
         presence_penalty=1.1,
         top_p=0.8)
     if response:
@@ -84,5 +89,5 @@ def simple_chat(use_stream=False):
 
 
 if __name__ == "__main__":
-    simple_chat()
+    # simple_chat(use_stream=False)
     function_chat()
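Taken together, the demo-side changes tighten the weather-tool schema (a required `format` enum instead of `unit`) and print the whole message object so the new `tool_calls` field is visible. A minimal sketch of the expected round trip — the endpoint and argument values are assumptions, not part of this commit:

```python
from openai import OpenAI

# Assumed local endpoint for the server changed in the second file below.
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")

response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "What's the weather in San Francisco?"}],
    tools=tools,          # the weather-tool schema defined above
    tool_choice="auto",
    stream=False,         # the demo notes streaming must stay off here
)
message = response.choices[0].message
# With the server changes below, a tool call is expected to arrive as:
#   message.content                           -> None
#   message.tool_calls[0].function.name       -> "get_current_weather"
#   message.tool_calls[0].function.arguments  -> '{"location": ..., "format": ...}'
print(message)
```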
@@ -1,4 +1,3 @@
-import os
 import time
 from asyncio.log import logger
 
@@ -17,6 +16,7 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MAX_MODEL_LENGTH = 8192
 
@@ -52,7 +52,12 @@ class ModelCard(BaseModel):
 
 class ModelList(BaseModel):
     object: str = "list"
-    data: List[ModelCard] = []
+    data: List[ModelCard] = ["glm-4"]
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
 
 
 class FunctionCallResponse(BaseModel):
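One caveat with the new `ModelList` default: `["glm-4"]` is a list of strings, not of `ModelCard` objects. Pydantic does not validate field defaults out of the box, so this slips through quietly, but serializing an unmodified `ModelList` would emit a bare string where a model object belongs. A type-correct alternative would look like this (hypothetical fix, not part of the commit; `ModelCard` trimmed to the one field used here):

```python
from typing import List
from pydantic import BaseModel

class ModelCard(BaseModel):
    id: str

class ModelList(BaseModel):
    object: str = "list"
    # Default is a real ModelCard, so serialization stays schema-correct.
    data: List[ModelCard] = [ModelCard(id="glm-4")]
```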
@@ -60,11 +65,23 @@ class FunctionCallResponse(BaseModel):
     arguments: Optional[str] = None
 
 
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    function: FunctionCall
+    type: Literal["function"]
+
+
 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system", "tool"]
-    content: str = None
     name: Optional[str] = None
+    content: Optional[str] = None
     function_call: Optional[FunctionCallResponse] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class DeltaMessage(BaseModel):
@@ -73,28 +90,25 @@ class DeltaMessage(BaseModel):
     function_call: Optional[FunctionCallResponse] = None
 
 
-class EmbeddingRequest(BaseModel):
-    input: Union[List[str], str]
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length", "tool_calls"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
+    index: int
+
+
+class ChatCompletionResponse(BaseModel):
     model: str
-
-
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class EmbeddingResponse(BaseModel):
-    data: list
-    model: str
-    object: str
-    usage: CompletionUsage
-
-
-class UsageInfo(BaseModel):
-    prompt_tokens: int = 0
-    total_tokens: int = 0
-    completion_tokens: Optional[int] = 0
+    id: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    usage: Optional[UsageInfo] = None
 
 
 class ChatCompletionRequest(BaseModel):
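The embedding models are dropped, and the response models are consolidated here (they previously lived further down; see the removed block in the next hunk) with `finish_reason` now accepting `tool_calls`. A quick construction sketch, assuming these classes are imported from the server module; field values are illustrative only:

```python
resp = ChatCompletionResponse(
    model="glm-4",
    id="chatcmpl-demo",          # illustrative id
    object="chat.completion",
    choices=[ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content="hello"),
        finish_reason="stop",
    )],
    usage=UsageInfo(prompt_tokens=5, completion_tokens=2, total_tokens=7),
)
# pydantic v1: resp.json(); pydantic v2: resp.model_dump_json()
```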
@@ -109,27 +123,6 @@ class ChatCompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.1
 
 
-class ChatCompletionResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: Literal["stop", "length", "function_call"]
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length", "function_call"]]
-    index: int
-
-
-class ChatCompletionResponse(BaseModel):
-    model: str
-    id: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    usage: Optional[UsageInfo] = None
-
-
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
@@ -141,27 +134,38 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
 
 
 def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
-    content = ""
-    for response in output.split("<|assistant|>"):
-        if "\n" in response:
-            metadata, content = response.split("\n", maxsplit=1)
-        else:
-            metadata, content = "", response
-        if not metadata.strip():
-            content = content.strip()
-        else:
-            if use_tool:
-                parameters = eval(content.strip())
-                content = {
-                    "name": metadata.strip(),
-                    "arguments": json.dumps(parameters, ensure_ascii=False)
-                }
-            else:
-                content = {
-                    "name": metadata.strip(),
-                    "content": content
-                }
-    return content
+    lines = output.strip().split("\n")
+
+    if len(lines) == 2:
+        function_name = lines[0].strip()
+        arguments = lines[1].strip()
+        special_tools = ["cogview", "simple_browser"]
+
+        arguments_json = None
+        try:
+            arguments_json = json.loads(arguments)
+            is_tool_call = True
+        except json.JSONDecodeError:
+            is_tool_call = function_name in special_tools
+
+        if is_tool_call and use_tool:
+            content = {
+                "name": function_name,
+                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                        ensure_ascii=False)
+            }
+            if function_name in special_tools:
+                content["text"] = arguments
+            return content
+        elif is_tool_call:
+            content = {
+                "name": function_name,
+                "content": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                      ensure_ascii=False)
+            }
+            return content
+
+    return output.strip()
 
 
 @torch.inference_mode()
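The rewritten parser drops the `eval`-based branch in favour of `json.loads` and expects a tool call as exactly two lines — a function name, then a JSON argument object; anything else falls through as plain text. Illustrative inputs (not taken from the commit):

```python
raw = 'get_current_weather\n{"location": "San Francisco, CA", "format": "celsius"}'
process_response(raw, use_tool=True)
# -> {'name': 'get_current_weather',
#     'arguments': '{"location": "San Francisco, CA", "format": "celsius"}'}

process_response("just a normal answer", use_tool=True)
# -> 'just a normal answer'   (single line, so it falls through to output.strip())
```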
@@ -238,9 +242,8 @@ def process_messages(messages, tools=None, tool_choice="none"):
                 "tools": tools
             }
         )
-            msg_has_sys = True
+        msg_has_sys = True
 
     # add to metadata
     if isinstance(tool_choice, dict) and tools:
         messages.append(
             {
@@ -278,6 +281,12 @@ def process_messages(messages, tools=None, tool_choice="none"):
             continue
         messages.append({"role": role, "content": content})
 
+    if not tools or tool_choice == "none":
+        for m in _messages:
+            if m.role == 'system':
+                messages.insert(0, {"role": m.role, "content": m.content})
+                break
+
     return messages
 
 
@@ -348,22 +357,33 @@ async def create_chat_completion(request: ChatCompletionRequest):
         response["text"] = response["text"].strip()
 
     usage = UsageInfo()
 
     function_call, finish_reason = None, "stop"
+    tool_calls = None
     if request.tools:
         try:
             function_call = process_response(response["text"], use_tool=True)
-        except:
-            logger.warning(
-                "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
+        except Exception as e:
+            logger.warning(f"Failed to parse tool call: {e}")
 
     if isinstance(function_call, dict):
-        finish_reason = "function_call"
-        function_call = FunctionCallResponse(**function_call)
+        finish_reason = "tool_calls"
+        function_call_response = FunctionCallResponse(**function_call)
+        function_call_instance = FunctionCall(
+            name=function_call_response.name,
+            arguments=function_call_response.arguments
+        )
+        tool_calls = [
+            ChatCompletionMessageToolCall(
+                id=f"call_{int(time.time() * 1000)}",
+                function=function_call_instance,
+                type="function")]
 
     message = ChatMessage(
         role="assistant",
-        content=response["text"],
-        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+        content=None if tool_calls else response["text"],
+        function_call=None,
+        tool_calls=tool_calls,
     )
 
     logger.debug(f"==== message ====\n{message}")
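The net effect on the non-streaming wire format: a parsed tool call now surfaces as an OpenAI-style `tool_calls` array with `content` nulled out, instead of the legacy `function_call` field. An illustrative response fragment (values made up; the id is timestamp-based per the code above):

```python
{
    "choices": [{
        "index": 0,
        "finish_reason": "tool_calls",
        "message": {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": "call_1718000000000",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "San Francisco, CA", "format": "celsius"}'
                }
            }]
        }
    }]
}
```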
@@ -402,11 +422,11 @@ async def predict(model_id: str, params: dict):
             previous_text = decoded_unicode
 
             finish_reason = new_response["finish_reason"]
-            if len(delta_text) == 0 and finish_reason != "function_call":
+            if len(delta_text) == 0 and finish_reason != "tool_calls":
                 continue
 
             function_call = None
-            if finish_reason == "function_call":
+            if finish_reason == "tool_calls":
                 try:
                     function_call = process_response(decoded_unicode, use_tool=True)
                 except:
@@ -417,9 +437,15 @@ async def predict(model_id: str, params: dict):
                 function_call = FunctionCallResponse(**function_call)
 
             delta = DeltaMessage(
-                content=delta_text,
+                content=None,
                 role="assistant",
-                function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+                function_call=None,
+                tool_calls=[{
+                    "id": f"call_{int(time.time() * 1000)}",
+                    "index": 0,
+                    "type": "function",
+                    "function": function_call
+                }] if isinstance(function_call, FunctionCallResponse) else None,
             )
 
             choice_data = ChatCompletionResponseStreamChoice(
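The streaming path mirrors this: when `finish_reason` is `tool_calls`, the delta carries the call as a dict, with the extra `index` key that streamed tool calls require. An illustrative chunk fragment after serialization (values assumed):

```python
{
    "delta": {
        "role": "assistant",
        "content": None,
        "function_call": None,
        "tool_calls": [{
            "id": "call_1718000000000",   # illustrative, timestamp-based
            "index": 0,
            "type": "function",
            "function": {"name": "get_current_weather",
                         "arguments": '{"location": "San Francisco, CA", "format": "celsius"}'}
        }]
    },
    "finish_reason": "tool_calls"
}
```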
@@ -454,16 +480,15 @@ async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
-        lines = output.strip().split("\n")
-        if not is_function_call and len(lines) >= 2:
-            is_function_call = True
-
+        if not is_function_call and len(output) > 7:
+            is_function_call = output and 'get_' in output
+            if is_function_call:
+                continue
 
         finish_reason = new_response["finish_reason"]
         if not has_send_first_chunk:
             message = DeltaMessage(
@@ -538,9 +563,9 @@ if __name__ == "__main__":
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
-        gpu_memory_utilization=0.9,
+        gpu_memory_utilization=0.3,
         enforce_eager=True,
-        worker_use_ray=True,
+        worker_use_ray=False,
         engine_use_ray=False,
         disable_log_requests=True,
         max_model_len=MAX_MODEL_LENGTH,