This commit is contained in:
zR 2024-06-09 16:11:23 +08:00
commit adeeb0e8e0
1 changed files with 21 additions and 8 deletions

View File

@ -6,12 +6,12 @@ allowing users to interact with the model through a chat-like interface.
""" """
import os import os
from pathlib import Path
from threading import Thread
from typing import Union
import gradio as gr import gradio as gr
import torch import torch
from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
@ -99,10 +99,14 @@ def parse_text(text):
return text return text
def predict(history, max_length, top_p, temperature): def predict(history, prompt, max_length, top_p, temperature):
stop = StopOnTokens() stop = StopOnTokens()
messages = [] messages = []
if prompt:
messages.append({"role": "system", "content": prompt})
for idx, (user_msg, model_msg) in enumerate(history): for idx, (user_msg, model_msg) in enumerate(history):
if prompt and idx == 0:
continue
if idx == len(history) - 1 and not model_msg: if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg}) messages.append({"role": "user", "content": user_msg})
break break
@ -140,11 +144,14 @@ with gr.Blocks() as demo:
chatbot = gr.Chatbot() chatbot = gr.Chatbot()
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=3):
with gr.Column(scale=12): with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False) user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
with gr.Column(min_width=32, scale=1): with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit") submitBtn = gr.Button("Submit")
with gr.Column(scale=1):
prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
pBtn = gr.Button("Set Prompt")
with gr.Column(scale=1): with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History") emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True) max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
@ -156,10 +163,16 @@ with gr.Blocks() as demo:
return "", history + [[parse_text(query), ""]] return "", history + [[parse_text(query), ""]]
def set_prompt(prompt_text):
return [[parse_text(prompt_text), "成功设置prompt"]]
pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then( submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot, max_length, top_p, temperature], chatbot predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
) )
emptyBtn.click(lambda: None, None, chatbot, queue=False) emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
demo.queue() demo.queue()
demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True) demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)