From b062d980c840981d9817fa3d6c62576f47599b95 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 18 Jul 2023 17:21:16 +0800 Subject: [PATCH] add web demo Former-commit-id: b447fa85aa563b6105cfa64c9ec802c5bd63af56 --- README.md | 8 +++++ src/llmtuner/webui/chat.py | 4 ++- src/llmtuner/webui/components/chatbot.py | 7 ++-- src/web_demo.py | 44 ++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 4 deletions(-) create mode 100644 src/web_demo.py diff --git a/README.md b/README.md index ddbd148d..e35fdac6 100644 --- a/README.md +++ b/README.md @@ -291,6 +291,14 @@ python src/cli_demo.py \ --checkpoint_dir path_to_checkpoint ``` +### Web Demo + +```bash +python src/web_demo.py \ + --model_name_or_path path_to_your_model \ + --checkpoint_dir path_to_checkpoint +``` + ### Export model ```bash diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py index c889eca5..19b74a5a 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chat.py @@ -11,10 +11,12 @@ from llmtuner.webui.locales import ALERTS class WebChatModel(ChatModel): - def __init__(self): + def __init__(self, *args): self.model = None self.tokenizer = None self.generating_args = GeneratingArguments() + if len(args) != 0: + super().__init__(*args) def load_model( self, diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index d56dd592..d4a20282 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import gradio as gr from gradio.blocks import Block @@ -8,9 +8,10 @@ from llmtuner.webui.chat import WebChatModel def create_chat_box( - chat_model: WebChatModel + chat_model: WebChatModel, + visible: Optional[bool] = False ) -> Tuple[Block, Component, Component, Dict[str, Component]]: - with gr.Box(visible=False) as chat_box: + with gr.Box(visible=visible) as chat_box: chatbot = gr.Chatbot() with gr.Row(): diff --git a/src/web_demo.py b/src/web_demo.py new file mode 100644 index 00000000..2a05577d --- /dev/null +++ b/src/web_demo.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Implements user interface in browser for fine-tuned models. +# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint + +import gradio as gr +from transformers.utils.versions import require_version + +from llmtuner import get_infer_args +from llmtuner.webui.chat import WebChatModel +from llmtuner.webui.components.chatbot import create_chat_box +from llmtuner.webui.manager import Manager + + +require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") + + +def main(): + chat_model = WebChatModel(*get_infer_args()) + + with gr.Blocks(title="Web Demo") as demo: + lang = gr.Dropdown(choices=["en", "zh"], value="en") + + _, _, _, chat_elems = create_chat_box(chat_model, visible=True) + + manager = Manager([{"lang": lang}, chat_elems]) + + demo.load( + manager.gen_label, + [lang], + [lang] + [elem for elem in chat_elems.values()], + ) + + lang.change( + manager.gen_label, + [lang], + [lang] + [elem for elem in chat_elems.values()], + ) + + demo.queue() + demo.launch(server_name="0.0.0.0", share=False, inbrowser=True) + + +if __name__ == "__main__": + main()