2024-06-05 10:22:16 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
This demo shows the All Tools, Long Context (document) and multimodal chat capabilities of GLM-4.
|
|
|
|
|
Please follow the Readme.md to run the demo.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
import tempfile
import traceback
from enum import Enum
from io import BytesIO
from uuid import uuid4
|
|
|
|
|
|
|
|
|
|
import streamlit as st
|
|
|
|
|
from streamlit.delta_generator import DeltaGenerator
|
|
|
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
from client import Client, ClientType, get_client
|
|
|
|
|
from conversation import (
|
|
|
|
|
FILE_TEMPLATE,
|
|
|
|
|
Conversation,
|
|
|
|
|
Role,
|
|
|
|
|
postprocess_text,
|
|
|
|
|
response_to_str,
|
|
|
|
|
)
|
|
|
|
|
from tools.tool_registry import dispatch_tool, get_tools
|
|
|
|
|
from utils import extract_pdf, extract_docx, extract_pptx, extract_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model checkpoints, overridable through the environment.
CHAT_MODEL_PATH = os.getenv("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
VLM_MODEL_PATH = os.getenv("VLM_MODEL_PATH", "THUDM/glm-4v-9b")

# Backend switches: a switch is on only when its variable is exactly "1".
USE_VLLM = os.getenv("USE_VLLM", "0") == "1"
USE_API = os.getenv("USE_API", "0") == "1"
|
2024-06-05 10:22:16 +08:00
|
|
|
|
|
|
|
|
|
class Mode(str, Enum):
    """Demo modes offered by the selector radio at the top of the page.

    Each value doubles as the label shown in the UI. Because the enum mixes in
    ``str``, the label string selected in the radio compares equal to its member
    (e.g. ``page == Mode.LONG_CTX``).
    """

    # Chat in which the model may call the registered tools.
    ALL_TOOLS = "🛠️ All Tools"
    # Document reading: files uploaded in the first round are prepended to the prompt.
    LONG_CTX = "📝 文档解读"
    # Multimodal chat with an image uploaded in the first round.
    VLM = "🖼️ 多模态"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Record a message in the chat history and render it.

    Args:
        conversation: The message to add.
        history: The running message list; it is mutated in place.
        placeholder: Streamlit container to render into. It is handed to
            ``Conversation.show`` unchanged (``None`` included).
    """
    history.append(conversation)
    # Render after recording so the history reflects what is on screen.
    conversation.show(placeholder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-wide settings; the sidebar starts expanded so the sampling sliders are visible.
st.set_page_config(
    layout="centered",
    page_title="GLM-4 Demo",
    page_icon=":robot:",
    initial_sidebar_state="expanded",
)

# Header shown in every mode, with a link to the online documentation.
st.title("GLM-4 Demo")
st.markdown(
    "<sub>智谱AI 公开在线技术文档: https://zhipu-ai.feishu.cn/wiki/RuMswanpkiRh3Ok4z5acOABBnjf </sub> \n\n <sub> 更多 GLM-4 开源模型的使用方法请参考文档。</sub>",
    unsafe_allow_html=True,
)
|
|
|
|
|
|
|
|
|
|
with st.sidebar:
    # Sampling parameters; main() forwards all of them to client.generate_stream.
    top_p = st.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
    # Keyed so build_client() can reset it through st.session_state.top_k.
    top_k = st.slider("top_k", 1, 20, 10, step=1, key="top_k")
    temperature = st.slider("temperature", 0.0, 1.5, 0.95, step=0.01)
    repetition_penalty = st.slider("repetition_penalty", 0.0, 2.0, 1.0, step=0.01)
    max_new_tokens = st.slider("max_new_tokens", 1, 4096, 2048, step=1)

    # Two side-by-side buttons: Retry on the left, Clear on the right.
    retry_col, clear_col = st.columns(2)
    clear_history = clear_col.button("Clear", use_container_width=True)
    retry = retry_col.button("Retry", use_container_width=True)
|
|
|
|
|
|
|
|
|
|
if clear_history:
    # Reset the whole session, but carry over the selected mode and the existing
    # client object. Either may be missing: e.g. "client" is only set by the
    # mode-change callback, so clicking Clear before choosing a mode used to
    # raise AttributeError here.
    _kept_state = {
        key: st.session_state[key]
        for key in ("page", "client")
        if key in st.session_state
    }
    st.session_state.clear()
    for key, value in _kept_state.items():
        st.session_state[key] = value
    st.session_state.files_uploaded = False
    st.session_state.uploaded_texts = ""
    st.session_state.uploaded_file_nums = 0
    st.session_state.history = []
|
|
|
|
|
|
|
|
|
|
# Per-session defaults. A key is only filled in when it is absent, so existing
# values are kept across reruns of this script.
for _state_key, _make_default in (
    ("files_uploaded", lambda: False),
    ("session_id", uuid4),
    ("history", list),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# No stored messages means the next prompt opens the conversation.
first_round = not st.session_state.history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_client(mode: Mode) -> Client:
    """Create the model client for *mode* and set that mode's default ``top_k``.

    - VLM: ``top_k`` 1, ``VLM_MODEL_PATH`` on ``ClientType.HF``.
    - All Tools / document mode: ``top_k`` 10, ``CHAT_MODEL_PATH`` on
      ``ClientType.VLLM`` when ``USE_VLLM`` is set, else ``ClientType.HF``.
      All Tools alone switches to ``ClientType.API`` when ``USE_API`` is set.

    Returns ``None`` for a value matching none of the modes.
    """
    if mode == Mode.VLM:
        st.session_state.top_k = 1
        # vLLM is not available for VLM mode
        return get_client(VLM_MODEL_PATH, ClientType.HF)

    if mode == Mode.ALL_TOOLS or mode == Mode.LONG_CTX:
        st.session_state.top_k = 10
        typ = ClientType.VLLM if USE_VLLM else ClientType.HF
        if mode == Mode.ALL_TOOLS and USE_API:
            typ = ClientType.API
        return get_client(CHAT_MODEL_PATH, typ)

    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Callback function for page change
def page_changed() -> None:
    """Start a fresh conversation and rebuild the client for the newly chosen mode.

    Registered as the ``on_change`` callback of the mode radio (key ``"page"``).
    """
    new_page: str = st.session_state.page
    st.session_state.history.clear()
    # Pending uploads belong to the conversation being discarded. Without this
    # reset, `files_uploaded` stays True after an earlier document-chat round,
    # and files uploaded after switching back to document mode are silently
    # ignored. This mirrors what the "Clear" button resets.
    st.session_state.files_uploaded = False
    st.session_state.uploaded_texts = ""
    st.session_state.uploaded_file_nums = 0
    st.session_state.uploaded_image = None
    st.session_state.client = build_client(Mode(new_page))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mode selector; changing it runs page_changed() (new client, empty history).
page = st.radio(
    "选择功能",
    [mode.value for mode in Mode],
    key="page",
    horizontal=True,
    index=None,
    label_visibility="hidden",
    on_change=page_changed,
)

HELP = """
### 🎉 欢迎使用 GLM-4!

请在上方选取一个功能。每次切换功能时,将会重新加载模型并清空对话历史。

文档解读模式与 VLM 模式仅支持在第一轮传入文档或图像。
""".strip()

if page is None:
    # No mode chosen yet: show the help text and end this run.
    # st.stop() is Streamlit's supported way to halt a script run; the
    # previous exit() raised SystemExit inside the runner and depends on the
    # `site` module providing the builtin.
    st.markdown(HELP)
    st.stop()
|
|
|
|
|
|
|
|
|
|
def _extract_uploaded_file(uploaded_file) -> str:
    """Return the extracted text of one file received from ``st.file_uploader``.

    The extract_* helpers take a filesystem path, so the upload is first written
    to a temporary file. That file keeps the upload's suffix and is always removed,
    even if extraction raises. The extractor is chosen by the case-insensitive
    file extension: PDF, DOCX and PPTX have dedicated extractors, and anything
    else goes through ``extract_text``.
    """
    file_name: str = uploaded_file.name
    lowered_name = file_name.lower()
    fd, file_path = tempfile.mkstemp(suffix=os.path.splitext(file_name)[1])
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(uploaded_file.getbuffer())
        if lowered_name.endswith(".pdf"):
            return extract_pdf(file_path)
        if lowered_name.endswith(".docx"):
            return extract_docx(file_path)
        if lowered_name.endswith(".pptx"):
            return extract_pptx(file_path)
        return extract_text(file_path)
    finally:
        os.remove(file_path)


if page == Mode.LONG_CTX:
    # Documents can only be attached before the first message is sent.
    if first_round:
        uploaded_files = st.file_uploader(
            "上传文件",
            type=["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"],
            accept_multiple_files=True,
        )
        if uploaded_files and not st.session_state.files_uploaded:
            # Stash the extracted text; main() prepends it to the first prompt.
            st.session_state.uploaded_texts = "\n\n".join(
                FILE_TEMPLATE.format(
                    file_name=uploaded_file.name,
                    file_content=_extract_uploaded_file(uploaded_file),
                )
                for uploaded_file in uploaded_files
            )
            st.session_state.uploaded_file_nums = len(uploaded_files)
        else:
            st.session_state.uploaded_texts = ""
            st.session_state.uploaded_file_nums = 0
elif page == Mode.VLM:
    # Images can only be attached before the first message is sent.
    if first_round:
        uploaded_image = st.file_uploader(
            "上传图片",
            type=["png", "jpg", "jpeg", "bmp", "tiff", "webp"],
            accept_multiple_files=False,
        )
        if uploaded_image:
            # getvalue() returns the whole payload regardless of the read position.
            data: bytes = uploaded_image.getvalue()
            image = Image.open(BytesIO(data)).convert("RGB")
            st.session_state.uploaded_image = image
        else:
            st.session_state.uploaded_image = None
|
|
|
|
|
|
|
|
|
|
prompt_text = st.chat_input("Chat with GLM-4!", key="chat_input")

# An explicitly empty submission (and no retry) wipes the history and ends the
# run. st.stop() is Streamlit's supported way to halt the script; exit()
# raised SystemExit inside the runner.
if prompt_text == "" and not retry:
    print("\n== Clean ==\n")
    st.session_state.history = []
    st.stop()
|
|
|
|
|
|
|
|
|
|
history: list[Conversation] = st.session_state.history

if retry:
    print("\n== Retry ==\n")
    # Index of the most recent user message, scanning from the end.
    last_user_conversation_idx = next(
        (
            idx
            for idx in range(len(history) - 1, -1, -1)
            if history[idx].role.value == Role.USER.value
        ),
        None,
    )
    if last_user_conversation_idx is not None:
        # Re-submit that prompt and drop it together with everything after it;
        # main() appends it again as a new user message.
        prompt_text = history[last_user_conversation_idx].content
        print(f"New prompt: {prompt_text}, idx = {last_user_conversation_idx}")
        del history[last_user_conversation_idx:]

# Render the stored conversation.
for past_conversation in history:
    past_conversation.show()
|
|
|
|
|
|
|
|
|
|
# Only All Tools mode exposes the tool registry to the model.
if page == Mode.ALL_TOOLS:
    tools = get_tools()
else:
    tools = []

# Client created by the mode-change callback (and kept across "Clear").
client: Client = st.session_state.client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(prompt_text: str | None):
    """Handle one user turn: record the prompt and stream the model's reply.

    Does nothing when *prompt_text* is empty or ``None``. Otherwise:

    - In document mode on the first round, the extracted text of the uploaded
      files is prepended to the prompt. In VLM mode, the uploaded image (if any)
      is attached to the message.
    - The user message is appended to ``history`` and the reply is streamed
      into a new assistant chat bubble.
    - In All Tools mode, a final response that names a tool is dispatched via
      ``dispatch_tool``. Its observations are committed and the model is queried
      again, for at most 10 generation rounds in total.

    Reads the module-level ``client``, ``history``, ``page``, ``first_round``,
    ``tools`` and the sidebar sampling parameters.
    """
    assert client is not None

    if not prompt_text:
        return
    prompt_text = prompt_text.strip()

    # Append uploaded files (document mode, first round only).
    uploaded_texts = st.session_state.get("uploaded_texts")
    if page == Mode.LONG_CTX and uploaded_texts and first_round:
        meta_msg = "{} files uploaded.\n".format(
            st.session_state.uploaded_file_nums
        )
        prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text
        # Clear after first use
        st.session_state.files_uploaded = True
        st.session_state.uploaded_texts = ""
        st.session_state.uploaded_file_nums = 0

    image = st.session_state.get("uploaded_image")
    if page == Mode.VLM and image and first_round:
        # Consume the image so later runs do not attach it again.
        st.session_state.uploaded_image = None

    role = Role.USER
    append_conversation(Conversation(role, prompt_text, image=image), history)

    placeholder = st.container()
    message_placeholder = placeholder.chat_message(
        name="assistant", avatar="assistant"
    )
    markdown_placeholder = message_placeholder.empty()

    def add_new_block():
        """Open a fresh assistant chat bubble for the next piece of output."""
        nonlocal message_placeholder, markdown_placeholder
        message_placeholder = placeholder.chat_message(
            name="assistant", avatar="assistant"
        )
        markdown_placeholder = message_placeholder.empty()

    def commit_conversation(
        role: Role,
        text: str,
        metadata: str | None = None,
        image: str | None = None,
        new: bool = False,
    ):
        """Append a finished message to ``history`` and render it.

        With ``new=True`` the message is rendered into the current chat bubble
        itself, rather than into the markdown slot inside it (used for tool
        observations).
        """
        processed_text = postprocess_text(text, role.value == Role.ASSISTANT.value)
        conversation = Conversation(role, text, processed_text, metadata, image)

        # Use different placeholder for new block
        placeholder = message_placeholder if new else markdown_placeholder

        append_conversation(
            conversation,
            history,
            placeholder,
        )

    response = ""
    for _ in range(10):
        last_response = None
        history_len = None

        try:
            for response, chat_history in client.generate_stream(
                tools=tools,
                history=history,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
            ):
                # A change in chat_history length means a new message started:
                # commit the one streamed so far and open a new bubble.
                if history_len is None:
                    history_len = len(chat_history)
                elif history_len != len(chat_history):
                    commit_conversation(Role.ASSISTANT, last_response)
                    add_new_block()
                    history_len = len(chat_history)
                last_response = response
                replace_quote = chat_history[-1]["role"] == "assistant"
                markdown_placeholder.markdown(
                    postprocess_text(
                        str(response) + "●", replace_quote=replace_quote
                    )
                )

            # Stream finished: a dict response with a "name" is a tool call.
            metadata = (
                page == Mode.ALL_TOOLS
                and isinstance(response, dict)
                and response.get("name")
                or None
            )
            role = Role.TOOL if metadata else Role.ASSISTANT
            text = (
                response.get("content")
                if metadata
                else response_to_str(response)
            )
            commit_conversation(role, text, metadata)
            if not metadata:
                break

            add_new_block()
            try:
                with markdown_placeholder:
                    with st.spinner(f"Calling tool {metadata}..."):
                        observations = dispatch_tool(
                            metadata, text, str(st.session_state.session_id)
                        )
            except Exception as e:
                traceback.print_exc()
                st.error(f'Uncaught exception in `"{metadata}"`: {e}')
                break

            for observation in observations:
                commit_conversation(
                    Role.OBSERVATION,
                    observation.text,
                    observation.role_metadata,
                    observation.image_url,
                    new=True,
                )
                add_new_block()
            # Fall through to the next round so the model can use the observations.
        except Exception:
            traceback.print_exc()
            st.error(f"Uncaught exception: {traceback.format_exc()}")
            # Stop here. Without this break, every failure re-ran the whole
            # generation up to 10 times (one error box each) and finally
            # reported the misleading "Too many chaining function calls!".
            break
    else:
        st.error("Too many chaining function calls!")


main(prompt_text)
|