Merge pull request #91 from ztxtech/main

Add a block for setting the prompt
zR 2024-06-08 20:18:57 +08:00 committed by GitHub
commit 56426f2186
Signed with GPG key ID B5690EEEBB952194 (no known key found for this signature in the database)
1 changed file with 21 additions and 8 deletions

@@ -6,12 +6,12 @@ allowing users to interact with the model through a chat-like interface.
 """
 import os
+from pathlib import Path
+from threading import Thread
+from typing import Union
 import gradio as gr
 import torch
-from threading import Thread
-from typing import Union
-from pathlib import Path
 from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
 from transformers import (
     AutoModelForCausalLM,
@@ -99,10 +99,14 @@ def parse_text(text):
     return text


-def predict(history, max_length, top_p, temperature):
+def predict(history, prompt, max_length, top_p, temperature):
     stop = StopOnTokens()
     messages = []
+    if prompt:
+        messages.append({"role": "system", "content": prompt})
     for idx, (user_msg, model_msg) in enumerate(history):
+        if prompt and idx == 0:
+            continue
         if idx == len(history) - 1 and not model_msg:
             messages.append({"role": "user", "content": user_msg})
             break
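Seen outside Gradio, the change to predict() amounts to: prepend the prompt as a system message, and skip the first history entry because set_prompt() (added further down) stores the prompt there as a placeholder chat turn. A minimal standalone sketch of that message-assembly step, with the tail of the loop filled in the way these demos typically append completed turns (build_messages and the sample history are illustrative, not part of the diff):

def build_messages(history, prompt):
    """Turn (user, assistant) pairs plus an optional system prompt into chat messages."""
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            continue  # first entry is the placeholder turn created by set_prompt()
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages


# The placeholder turn seeded by "Set Prompt" is dropped; only the system message
# and the pending user turn reach the model.
history = [["Answer in French.", "Prompt set successfully"], ["Hello!", ""]]
print(build_messages(history, "Answer in French."))
# [{'role': 'system', 'content': 'Answer in French.'},
#  {'role': 'user', 'content': 'Hello!'}]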
@@ -140,11 +144,14 @@ with gr.Blocks() as demo:
     chatbot = gr.Chatbot()

     with gr.Row():
-        with gr.Column(scale=4):
+        with gr.Column(scale=3):
             with gr.Column(scale=12):
                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
             with gr.Column(min_width=32, scale=1):
                 submitBtn = gr.Button("Submit")
+        with gr.Column(scale=1):
+            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
+            pBtn = gr.Button("Set Prompt")
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
             max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
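Taken together, the layout change narrows the input column from scale=4 to scale=3 and adds a new scale=1 column holding the prompt textbox and its button. A UI-only sketch of the resulting row, runnable without model weights (a reconstruction under the assumption that the surrounding layout lines are unchanged):

import gradio as gr

# Three columns after the patch: chat input + Submit, the new Prompt box + "Set Prompt",
# and Clear History + the generation sliders (only the first slider shown here).
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)

demo.launch()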
@@ -156,10 +163,16 @@ with gr.Blocks() as demo:
         return "", history + [[parse_text(query), ""]]

+    def set_prompt(prompt_text):
+        return [[parse_text(prompt_text), "成功设置prompt"]]  # i.e. "Prompt set successfully"
+
+    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
+
     submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
-        predict, [chatbot, max_length, top_p, temperature], chatbot
+        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
     )
-    emptyBtn.click(lambda: None, None, chatbot, queue=False)
+    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

 demo.queue()
 demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)
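The wiring at the bottom ties it together: "Set Prompt" replaces the chat history with a single placeholder turn, Submit now threads prompt_input into predict, and "Clear History" resets both the chatbot and the prompt box, which is why the lambda returns a pair. A self-contained sketch of that event graph with the model call swapped for an echo function so it runs without weights (echo_predict and the flat component layout are illustrative, not the repo's code):

import gradio as gr


def set_prompt(prompt_text):
    # Mirror the PR: seed the chat with one placeholder turn so the prompt is visible.
    return [[prompt_text, "Prompt set successfully"]]


def echo_predict(history, prompt, max_length, top_p, temperature):
    # Stand-in for the streaming model call: echo the last user message,
    # prefixed with the active system prompt if one was set.
    user_msg = history[-1][0]
    prefix = f"[system: {prompt}] " if prompt else ""
    history[-1][1] = prefix + "echo: " + user_msg
    return history


def user(query, history):
    return "", history + [[query, ""]]


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    user_input = gr.Textbox(placeholder="Input...")
    prompt_input = gr.Textbox(placeholder="Prompt")
    submitBtn = gr.Button("Submit")
    pBtn = gr.Button("Set Prompt")
    emptyBtn = gr.Button("Clear History")
    max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length")
    top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P")
    temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature")

    # Same event graph as the diff: set the prompt, chain user -> predict, clear both boxes.
    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        echo_predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch()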