parent 7fcaeba6cc
commit abe93e093d
@@ -15,7 +15,7 @@ def function_chat():
             "type": "function",
             "function": {
                 "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
+                "description": "Get the current weather",
                 "parameters": {
                     "type": "object",
                     "properties": {
@@ -23,15 +23,19 @@ def function_chat():
                             "type": "string",
                             "description": "The city and state, e.g. San Francisco, CA",
                         },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location.",
+                        },
                     },
-                    "required": ["location"],
+                    "required": ["location", "format"],
                 },
             }
-        }
+        },
     ]
 
-    # All Tools 能力: 绘图
+    # # All Tools 能力: 绘图
     # messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
     # tools = [{"type": "cogview"}]
     #
@@ -43,33 +47,34 @@ def function_chat():
         model="glm-4",
         messages=messages,
         tools=tools,
+        stream=False,  # must use False
         tool_choice="auto",  # use "auto" to let the model choose the tool automatically
         # tool_choice={"type": "function", "function": {"name": "my_function"}},
     )
     if response:
-        content = response.choices[0].message.content
-        print(content)
+        print(response.choices[0].message)
     else:
         print("Error:", response.status_code)
 
 
 def simple_chat(use_stream=False):
     messages = [
         {
             "role": "system",
-            "content": "你是 GLM-4,请你热情回答用户的问题。",
+            "content": "请在你输出的时候都带上“喵喵喵”三个字,放在开头。",
         },
         {
             "role": "user",
-            "content": "你好,请你用生动的话语给我讲一个小故事吧"
+            "content": "你好,你是谁"
         }
     ]
     response = client.chat.completions.create(
         model="glm-4",
         messages=messages,
         stream=use_stream,
-        max_tokens=1024,
-        temperature=0.8,
+        max_tokens=256,
+        temperature=0.1,
         presence_penalty=1.1,
         top_p=0.8)
     if response:
@@ -84,5 +89,5 @@ def simple_chat(use_stream=False):
 
 
 if __name__ == "__main__":
-    simple_chat()
+    # simple_chat(use_stream=False)
     function_chat()
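For context, a minimal client-side sketch (not part of this commit) of how the tool_calls field produced by the server changes below could be consumed from the request script above. It assumes `response` is the object returned by client.chat.completions.create(...) in function_chat(); the local get_current_weather handler is hypothetical and only for illustration:

    import json

    def get_current_weather(location: str, format: str = "celsius") -> str:
        # Hypothetical local handler used only to illustrate dispatching a tool call.
        return json.dumps({"location": location, "temperature": 22, "unit": format})

    message = response.choices[0].message  # response from client.chat.completions.create(...)
    if message.tool_calls:  # populated when finish_reason == "tool_calls"
        call = message.tool_calls[0]
        if call.function.name == "get_current_weather":
            args = json.loads(call.function.arguments)
            print(get_current_weather(**args))
    else:
        print(message.content)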
@@ -1,4 +1,3 @@
-import os
 import time
 from asyncio.log import logger
 
@@ -17,6 +16,7 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
 
 MODEL_PATH = 'THUDM/glm-4-9b-chat'
 MAX_MODEL_LENGTH = 8192
 
@@ -52,7 +52,12 @@ class ModelCard(BaseModel):
 
 class ModelList(BaseModel):
     object: str = "list"
-    data: List[ModelCard] = []
+    data: List[ModelCard] = ["glm-4"]
 
 
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
+
+
 class FunctionCallResponse(BaseModel):
@@ -60,11 +65,23 @@ class FunctionCallResponse(BaseModel):
     arguments: Optional[str] = None
 
 
+class UsageInfo(BaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    function: FunctionCall
+    type: Literal["function"]
+
+
 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system", "tool"]
-    content: str = None
-    name: Optional[str] = None
+    content: Optional[str] = None
     function_call: Optional[FunctionCallResponse] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class DeltaMessage(BaseModel):
@@ -73,28 +90,25 @@ class DeltaMessage(BaseModel):
     function_call: Optional[FunctionCallResponse] = None
 
 
-class EmbeddingRequest(BaseModel):
-    input: Union[List[str], str]
-    model: str
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length", "tool_calls"]
 
 
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
+class ChatCompletionResponseStreamChoice(BaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
+    index: int
 
 
-class EmbeddingResponse(BaseModel):
-    data: list
-    model: str
-    object: str
-    usage: CompletionUsage
-
-
-class UsageInfo(BaseModel):
-    prompt_tokens: int = 0
-    total_tokens: int = 0
-    completion_tokens: Optional[int] = 0
+class ChatCompletionResponse(BaseModel):
+    model: str
+    id: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    usage: Optional[UsageInfo] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -109,27 +123,6 @@ class ChatCompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.1
 
 
-class ChatCompletionResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: Literal["stop", "length", "function_call"]
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length", "function_call"]]
-    index: int
-
-
-class ChatCompletionResponse(BaseModel):
-    model: str
-    id: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    usage: Optional[UsageInfo] = None
-
-
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(
             self, input_ids: torch.LongTensor, scores: torch.FloatTensor
@@ -141,28 +134,39 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
 
 
 def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
-    content = ""
-    for response in output.split("<|assistant|>"):
-        if "\n" in response:
-            metadata, content = response.split("\n", maxsplit=1)
-        else:
-            metadata, content = "", response
-        if not metadata.strip():
-            content = content.strip()
-        else:
-            if use_tool:
-                parameters = eval(content.strip())
-                content = {
-                    "name": metadata.strip(),
-                    "arguments": json.dumps(parameters, ensure_ascii=False)
-                }
-            else:
-                content = {
-                    "name": metadata.strip(),
-                    "content": content
-                }
-    return content
+    lines = output.strip().split("\n")
+    if len(lines) == 2:
+        function_name = lines[0].strip()
+        arguments = lines[1].strip()
+        special_tools = ["cogview", "simple_browser"]
+
+        arguments_json = None
+        try:
+            arguments_json = json.loads(arguments)
+            is_tool_call = True
+        except json.JSONDecodeError:
+            is_tool_call = function_name in special_tools
+
+        if is_tool_call and use_tool:
+            content = {
+                "name": function_name,
+                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                        ensure_ascii=False)
+            }
+            if function_name in special_tools:
+                content["text"] = arguments
+            return content
+        elif is_tool_call:
+            content = {
+                "name": function_name,
+                "content": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                      ensure_ascii=False)
+            }
+            return content
+
+    return output.strip()
 
 
 @torch.inference_mode()
 async def generate_stream_glm4(params):
@@ -240,7 +244,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
         )
         msg_has_sys = True
 
-    # add to metadata
     if isinstance(tool_choice, dict) and tools:
         messages.append(
             {
@@ -278,6 +281,12 @@ def process_messages(messages, tools=None, tool_choice="none"):
                 continue
         messages.append({"role": role, "content": content})
 
+    if not tools or tool_choice == "none":
+        for m in _messages:
+            if m.role == 'system':
+                messages.insert(0, {"role": m.role, "content": m.content})
+                break
+
     return messages
 
 
@@ -348,22 +357,33 @@ async def create_chat_completion(request: ChatCompletionRequest):
     response["text"] = response["text"].strip()
 
     usage = UsageInfo()
 
     function_call, finish_reason = None, "stop"
+    tool_calls = None
     if request.tools:
         try:
             function_call = process_response(response["text"], use_tool=True)
-        except:
-            logger.warning(
-                "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
+        except Exception as e:
+            logger.warning(f"Failed to parse tool call: {e}")
 
     if isinstance(function_call, dict):
-        finish_reason = "function_call"
-        function_call = FunctionCallResponse(**function_call)
+        finish_reason = "tool_calls"
+        function_call_response = FunctionCallResponse(**function_call)
+        function_call_instance = FunctionCall(
+            name=function_call_response.name,
+            arguments=function_call_response.arguments
+        )
+        tool_calls = [
+            ChatCompletionMessageToolCall(
+                id=f"call_{int(time.time() * 1000)}",
+                function=function_call_instance,
+                type="function")]
 
     message = ChatMessage(
         role="assistant",
-        content=response["text"],
-        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+        content=None if tool_calls else response["text"],
+        function_call=None,
+        tool_calls=tool_calls,
     )
 
     logger.debug(f"==== message ====\n{message}")
@@ -402,11 +422,11 @@ async def predict(model_id: str, params: dict):
         previous_text = decoded_unicode
 
         finish_reason = new_response["finish_reason"]
-        if len(delta_text) == 0 and finish_reason != "function_call":
+        if len(delta_text) == 0 and finish_reason != "tool_calls":
             continue
 
         function_call = None
-        if finish_reason == "function_call":
+        if finish_reason == "tool_calls":
             try:
                 function_call = process_response(decoded_unicode, use_tool=True)
             except:
@@ -417,9 +437,15 @@ async def predict(model_id: str, params: dict):
             function_call = FunctionCallResponse(**function_call)
 
         delta = DeltaMessage(
-            content=delta_text,
+            content=None,
             role="assistant",
-            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+            function_call=None,
+            tool_calls=[{
+                "id": f"call_{int(time.time() * 1000)}",
+                "index": 0,
+                "type": "function",
+                "function": function_call
+            }] if isinstance(function_call, FunctionCallResponse) else None,
        )
 
         choice_data = ChatCompletionResponseStreamChoice(
@@ -458,12 +484,11 @@ async def predict_stream(model_id, gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
         output = decoded_unicode
+        lines = output.strip().split("\n")
+        if not is_function_call and len(lines) >= 2:
+            is_function_call = True
 
         if not is_function_call and len(output) > 7:
-            is_function_call = output and 'get_' in output
-            if is_function_call:
-                continue
-
             finish_reason = new_response["finish_reason"]
             if not has_send_first_chunk:
                 message = DeltaMessage(
@@ -538,9 +563,9 @@ if __name__ == "__main__":
         tensor_parallel_size=1,
         dtype="bfloat16",
         trust_remote_code=True,
-        gpu_memory_utilization=0.9,
+        gpu_memory_utilization=0.3,
         enforce_eager=True,
-        worker_use_ray=True,
+        worker_use_ray=False,
         engine_use_ray=False,
         disable_log_requests=True,
         max_model_len=MAX_MODEL_LENGTH,
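As a side note on the rewritten process_response() above: it assumes the model emits a tool call as exactly two lines, the function name followed by a JSON argument string. A small standalone sketch of that contract, with a hypothetical sample output string:

    import json

    raw = 'get_current_weather\n{"location": "Beijing", "format": "celsius"}'  # hypothetical model output
    lines = raw.strip().split("\n")
    if len(lines) == 2:
        name, arguments = lines[0].strip(), lines[1].strip()
        parsed = json.loads(arguments)
        # Mirrors the shape process_response() returns when use_tool=True.
        print({"name": name, "arguments": json.dumps(parsed, ensure_ascii=False)})
    else:
        print(raw.strip())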