diff --git a/src/api_demo.py b/src/api_demo.py index f7649e7b..c0ca9760 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -1,8 +1,3 @@ -# coding=utf-8 -# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat) -# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint -# Visit http://localhost:8000/docs for document. - import uvicorn from llmtuner import ChatModel, create_app @@ -12,6 +7,7 @@ def main(): chat_model = ChatModel() app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) + # Visit http://localhost:8000/docs for document. if __name__ == "__main__": diff --git a/src/cli_demo.py b/src/cli_demo.py index 2d0aff7e..05e8c8eb 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -1,7 +1,3 @@ -# coding=utf-8 -# Implements stream chat in command line for fine-tuned models. -# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint - from llmtuner import ChatModel diff --git a/src/export_model.py b/src/export_model.py index a0a86996..4baeb2c3 100644 --- a/src/export_model.py +++ b/src/export_model.py @@ -1,7 +1,3 @@ -# coding=utf-8 -# Exports the fine-tuned model. -# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model - from llmtuner import export_model diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 04ffbb67..e647b92b 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -3,7 +3,7 @@ from llmtuner.api import create_app from llmtuner.chat import ChatModel from llmtuner.tuner import export_model, run_exp -from llmtuner.webui import Manager, WebChatModel, create_ui, create_chat_box +from llmtuner.webui import create_ui, create_web_demo __version__ = "0.1.5" diff --git a/src/llmtuner/webui/__init__.py b/src/llmtuner/webui/__init__.py index 8544957c..a27c7f6e 100644 --- a/src/llmtuner/webui/__init__.py +++ b/src/llmtuner/webui/__init__.py @@ -1,4 +1 @@ -from llmtuner.webui.chat import WebChatModel -from llmtuner.webui.interface import create_ui -from llmtuner.webui.manager import Manager -from llmtuner.webui.components import create_chat_box +from llmtuner.webui.interface import create_ui, create_web_demo diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py index 01e6f0e4..d0eb61df 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chat.py @@ -10,11 +10,12 @@ from llmtuner.webui.locales import ALERTS class WebChatModel(ChatModel): - def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: - self.model = None - self.tokenizer = None - self.generating_args = GeneratingArguments() - if args is not None: + def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None: + if lazy_init: + self.model = None + self.tokenizer = None + self.generating_args = GeneratingArguments() + else: super().__init__(args) def load_model( diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index afc50d6e..2fb61d37 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -6,8 +6,10 @@ from llmtuner.webui.components import ( create_sft_tab, create_eval_tab, create_infer_tab, - create_export_tab + create_export_tab, + create_chat_box ) +from llmtuner.webui.chat import WebChatModel from llmtuner.webui.css import CSS from llmtuner.webui.manager import Manager from llmtuner.webui.runner import Runner @@ -53,6 +55,23 @@ def create_ui() -> gr.Blocks: return demo +def create_web_demo() -> gr.Blocks: + chat_model = WebChatModel(lazy_init=False) + + with gr.Blocks(title="Web Demo", css=CSS) as demo: + lang = gr.Dropdown(choices=["en", "zh"], value="en") + + _, _, _, chat_elems = create_chat_box(chat_model, visible=True) + + manager = Manager([{"lang": lang}, chat_elems]) + + demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values())) + + lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values())) + + return demo + + if __name__ == "__main__": demo = create_ui() demo.queue() diff --git a/src/web_demo.py b/src/web_demo.py index 112bac8b..257536ab 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -1,30 +1,8 @@ -# coding=utf-8 -# Implements user interface in browser for fine-tuned models. -# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint - -import gradio as gr -from transformers.utils.versions import require_version - -from llmtuner import Manager, WebChatModel, create_chat_box - - -require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") +from llmtuner import create_web_demo def main(): - chat_model = WebChatModel() - - with gr.Blocks(title="Web Demo") as demo: - lang = gr.Dropdown(choices=["en", "zh"], value="en") - - _, _, _, chat_elems = create_chat_box(chat_model, visible=True) - - manager = Manager([{"lang": lang}, chat_elems]) - - demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values())) - - lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values())) - + demo = create_web_demo() demo.queue() demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)