add demo mode for web UI

This commit is contained in:
hiyouga
2023-11-15 23:51:26 +08:00
parent 1e19cf242a
commit 8350bcf85d
10 changed files with 185 additions and 19 deletions

View File

@@ -4,7 +4,7 @@ import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
@@ -24,8 +24,9 @@ if TYPE_CHECKING:
class Runner:
def __init__(self, manager: "Manager") -> None:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.thread: "Thread" = None
self.do_train = True
@@ -46,9 +47,8 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
self.running = False
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -65,6 +65,9 @@ class Runner:
if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang]
if self.demo_mode and (not from_preview):
return ALERTS["err_demo"][lang]
self.aborted = False
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
@@ -196,7 +199,7 @@ class Runner:
return args
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train)
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
@@ -205,14 +208,13 @@ class Runner:
yield gen_cmd(args), gr.update(visible=False)
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train)
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.running = True
self.do_train, self.running_data = do_train, data
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
@@ -232,6 +234,7 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
self.running = True
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
while self.thread.is_alive():
time.sleep(2)