add web demo

Former-commit-id: b447fa85aa563b6105cfa64c9ec802c5bd63af56
This commit is contained in:
hiyouga 2023-07-18 17:21:16 +08:00
parent 35d52c4100
commit b062d980c8
4 changed files with 59 additions and 4 deletions

View File

@ -291,6 +291,14 @@ python src/cli_demo.py \
--checkpoint_dir path_to_checkpoint
```
### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
### Export model
```bash

View File

@ -11,10 +11,12 @@ from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self):
def __init__(self, *args):
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if len(args) != 0:
super().__init__(*args)
def load_model(
self,

View File

@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
import gradio as gr
from gradio.blocks import Block
@ -8,9 +8,10 @@ from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel
chat_model: WebChatModel,
visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
with gr.Box(visible=False) as chat_box:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
with gr.Row():

44
src/web_demo.py Normal file
View File

@ -0,0 +1,44 @@
# coding=utf-8
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def main():
chat_model = WebChatModel(*get_infer_args())
with gr.Blocks(title="Web Demo") as demo:
lang = gr.Dropdown(choices=["en", "zh"], value="en")
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
manager = Manager([{"lang": lang}, chat_elems])
demo.load(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
lang.change(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
if __name__ == "__main__":
main()