From eadb7c61af83dab6f10a525611842cfc5beaae81 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 15 Oct 2023 03:41:58 +0800 Subject: [PATCH] fix bugs in webui Former-commit-id: fde05cacfc8e669909858587ce5e84380b2e35fb --- src/llmtuner/webui/components/top.py | 6 ------ src/llmtuner/webui/engine.py | 20 +++++++++-------- src/llmtuner/webui/interface.py | 32 +++++++++++++++++++--------- src/llmtuner/webui/locales.py | 8 ------- src/llmtuner/webui/runner.py | 10 ++++----- 5 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index ec6fb91e..c831850f 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -37,12 +37,6 @@ def create_top(engine: "Engine") -> Dict[str, "Component"]: shift_attn = gr.Checkbox(value=False) rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none") - lang.change( - engine.change_lang, [lang], engine.manager.list_elems(), queue=False - ).then( - save_config, inputs=[config, lang, model_name, model_path] - ) - model_name.change( list_checkpoint, [model_name, finetuning_type], [checkpoints] ).then( diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py index 90beb5e2..77c3373a 100644 --- a/src/llmtuner/webui/engine.py +++ b/src/llmtuner/webui/engine.py @@ -12,25 +12,27 @@ from llmtuner.webui.utils import get_time class Engine: - def __init__(self, init_chat: Optional[bool] = False) -> None: + def __init__(self, pure_chat: Optional[bool] = False) -> None: + self.pure_chat = pure_chat self.manager: "Manager" = Manager() self.runner: "Runner" = Runner(self.manager) - self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat)) + self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat)) def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]: lang = config.get("lang", None) or "en" resume_dict = { - "top.config": {"value": config}, "top.lang": {"value": lang}, - "train.dataset": {"choices": list_dataset()["choices"]}, - "eval.dataset": {"choices": list_dataset()["choices"]}, "infer.chat_box": {"visible": self.chatter.loaded} } - if config.get("last_model", None): - resume_dict["top.model_name"] = {"value": config["last_model"]} - resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])} + if not self.pure_chat: + resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]} + resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]} + + if config.get("last_model", None): + resume_dict["top.model_name"] = {"value": config["last_model"]} + resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])} yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()} @@ -42,5 +44,5 @@ class Engine: def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: return { component: gr.update(**LOCALES[name][lang]) - for elems in self.manager.all_elems.values() for name, component in elems.items() + for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES } diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 85f44040..68344151 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -18,10 +18,14 @@ require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") def create_ui() -> gr.Blocks: - engine = Engine(init_chat=False) + engine = Engine(pure_chat=False) with gr.Blocks(title="Web Tuner", css=CSS) as demo: engine.manager.all_elems["top"] = create_top(engine) + lang: "gr.Dropdown" = engine.manager.get_elem("top.lang") + config = engine.manager.get_elem("top.config") + model_name = engine.manager.get_elem("top.model_name") + model_path = engine.manager.get_elem("top.model_path") with gr.Tab("Train"): engine.manager.all_elems["train"] = create_train_tab(engine) @@ -35,29 +39,37 @@ def create_ui() -> gr.Blocks: with gr.Tab("Export"): engine.manager.all_elems["export"] = create_export_tab(engine) - demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems()) + demo.load(engine.resume, [config], engine.manager.list_elems()) + + lang.change( + engine.change_lang, [lang], engine.manager.list_elems(), queue=False + ).then( + save_config, inputs=[config, lang, model_name, model_path] + ) return demo def create_web_demo() -> gr.Blocks: - engine = Engine(init_chat=True) + engine = Engine(pure_chat=True) with gr.Blocks(title="Web Demo", css=CSS) as demo: - lang = gr.Dropdown(choices=["en", "zh"]) config = gr.State(value=load_config()) + lang = gr.Dropdown(choices=["en", "zh"]) + + engine.manager.all_elems["top"] = dict(config=config, lang=lang) + + chat_box, _, _, chat_elems = create_chat_box(engine, visible=True) + engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems) + + demo.load(engine.resume, [config], engine.manager.list_elems()) + lang.change( engine.change_lang, [lang], engine.manager.list_elems(), queue=False ).then( save_config, inputs=[config, lang] ) - engine.manager.all_elems["top"] = dict(lang=lang) - - _, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True) - - demo.load(engine.resume, [config], engine.manager.list_elems()) - return demo diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 3bfa5329..93005e52 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -1,8 +1,4 @@ LOCALES = { - "config": { - "en": {}, - "zh": {} - }, "lang": { "en": { "label": "Lang" @@ -447,10 +443,6 @@ LOCALES = { "label": "保存预测结果" } }, - "chat_box": { - "en": {}, - "zh": {} - }, "load_btn": { "en": { "value": "Load model" diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 5c629790..c287fe20 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -199,10 +199,10 @@ class Runner: yield gen_cmd(args), gr.update(visible=False) def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - self.prepare(data, self._parse_train_args) + yield from self.prepare(data, self._parse_train_args) def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - self.prepare(data, self._parse_eval_args) + yield from self.prepare(data, self._parse_eval_args) def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: parse_func = self._parse_train_args if is_training else self._parse_eval_args @@ -213,9 +213,9 @@ class Runner: else: self.running = True run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) - thread = Thread(target=run_exp, kwargs=run_kwargs) - thread.start() - yield self.monitor(lang, output_dir, is_training) + self.thread = Thread(target=run_exp, kwargs=run_kwargs) + self.thread.start() + yield from self.monitor(lang, output_dir, is_training) def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: while self.thread.is_alive():