修正role

This commit is contained in:
Tianxiang Zhan 2024-06-08 16:12:58 +08:00 committed by GitHub
parent f4fc0a316e
commit 76fff757b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 13 additions and 25 deletions

View File

@ -6,12 +6,12 @@ allowing users to interact with the model through a chat-like interface.
"""
import os
from pathlib import Path
from threading import Thread
from typing import Union
import gradio as gr
import torch
from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
@ -99,10 +99,11 @@ def parse_text(text):
return text
def predict(history, max_length, top_p, temperature):
def predict(history, prompt, max_length, top_p, temperature):
stop = StopOnTokens()
messages = []
if prompt:
messages.append({"role": "system", "content": prompt})
for idx, (user_msg, model_msg) in enumerate(history):
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
@ -160,29 +161,16 @@ with gr.Blocks() as demo:
return "", history + [[parse_text(query), ""]]
def set_prompt(prompt_text, history):
"""
Sets the initial prompt for the chat session.
Parameters:
- prompt_text (str): The text of the prompt from the prompt textbox.
- history (list): Current chat history.
Returns:
- list: Updated chat history with the new prompt at the beginning.
"""
# Clear any existing history and add the prompt as the first message
return [[parse_text(prompt_text), ""]]
def set_prompt(prompt_text):
return [[parse_text(prompt_text), "成功设置prompt"]]
# Connect the 'Set Prompt' button click event to the 'set_prompt' function
pBtn.click(set_prompt, inputs=[prompt_input, chatbot], outputs=chatbot)
pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
pBtn.click()
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot, max_length, top_p, temperature], chatbot
predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
)
emptyBtn.click(lambda: None, None, chatbot, queue=False)
emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
demo.queue()
demo.launch(server_name="127.0.0.1", server_port=8002, inbrowser=True, share=True)
demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)