diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 41d657ff..255f0fd7 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -106,6 +106,7 @@ class WebChatModel(ChatModel): def predict( self, chatbot: List[Tuple[str, str]], + role: str, query: str, messages: Sequence[Tuple[str, str]], system: str, @@ -115,7 +116,7 @@ class WebChatModel(ChatModel): temperature: float, ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]: chatbot.append([query, ""]) - query_messages = messages + [{"role": Role.USER.value, "content": query}] + query_messages = messages + [{"role": role, "content": query}] response = "" for new_text in self.stream_chat( query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index fe739dd0..802954da 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple import gradio as gr +from ...data import Role from ..utils import check_json_schema @@ -22,6 +23,7 @@ def create_chat_box( with gr.Column(scale=4): system = gr.Textbox(show_label=False) tools = gr.Textbox(show_label=False, lines=2) + role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value) query = gr.Textbox(show_label=False, lines=8) submit_btn = gr.Button(variant="primary") @@ -36,7 +38,7 @@ def create_chat_box( submit_btn.click( engine.chatter.predict, - [chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature], + [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature], [chatbot, messages], show_progress=True, ).then(lambda: gr.update(value=""), outputs=[query]) @@ -50,6 +52,7 @@ def create_chat_box( dict( system=system, tools=tools, + role=role, query=query, submit_btn=submit_btn, clear_btn=clear_btn, diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 6ad5fc7c..e2a1aa81 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -861,6 +861,17 @@ LOCALES = { "placeholder": "工具列表(非必填)", }, }, + "role": { + "en": { + "label": "Role", + }, + "ru": { + "label": "Роль", + }, + "zh": { + "label": "角色", + }, + }, "query": { "en": { "placeholder": "Input...",