diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 479846ca..dac7dd67 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -1,6 +1,6 @@ import json import os -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple from ..chat import ChatModel from ..data import Role @@ -17,7 +17,6 @@ if TYPE_CHECKING: if is_gradio_available(): import gradio as gr - from gradio.components import Component # cannot use TYPE_CHECKING here class WebChatModel(ChatModel): @@ -38,7 +37,7 @@ class WebChatModel(ChatModel): def loaded(self) -> bool: return self.engine is not None - def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: + def load_model(self, data) -> Generator[str, None, None]: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] lang = get("top.lang") error = "" @@ -82,7 +81,7 @@ class WebChatModel(ChatModel): yield ALERTS["info_loaded"][lang] - def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: + def unload_model(self, data) -> Generator[str, None, None]: lang = data[self.manager.get_elem_by_id("top.lang")] if self.demo_mode: diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py index 65945533..b9ee61d2 100644 --- a/src/llmtuner/webui/engine.py +++ b/src/llmtuner/webui/engine.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, Generator +from typing import TYPE_CHECKING, Any, Dict -from ..extras.packages import is_gradio_available from .chatter import WebChatModel from .common import get_model_path, list_dataset, load_config from .locales import LOCALES @@ -9,8 +8,8 @@ from .runner import Runner from .utils import get_time -if is_gradio_available(): - from gradio.components import Component # cannot use TYPE_CHECKING here +if TYPE_CHECKING: + from gradio.components import Component class Engine: @@ -32,7 +31,7 @@ class Engine: return output_dict - def resume(self) -> Generator[Dict[Component, Component], None, None]: + def resume(self): user_config = load_config() if not self.demo_mode else {} lang = user_config.get("lang", None) or "en" @@ -58,7 +57,7 @@ class Engine: else: yield self._update_component({"eval.resume_btn": {"value": True}}) - def change_lang(self, lang: str) -> Dict[Component, Component]: + def change_lang(self, lang: str): return { elem: elem.__class__(**LOCALES[elem_name][lang]) for elem_name, elem in self.manager.get_elem_iter() diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 12307234..ec493c96 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -21,10 +21,11 @@ from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar if is_gradio_available(): import gradio as gr - from gradio.components import Component # cannot use TYPE_CHECKING here if TYPE_CHECKING: + from gradio.components import Component + from .manager import Manager @@ -243,7 +244,7 @@ class Runner: return args - def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]: + def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]: output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) error = self._initialize(data, do_train, from_preview=True) if error: @@ -253,7 +254,7 @@ class Runner: args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) yield {output_box: gen_cmd(args)} - def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]: + def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]: output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) error = self._initialize(data, do_train, from_preview=False) if error: @@ -267,19 +268,19 @@ class Runner: self.thread.start() yield from self.monitor() - def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]: + def preview_train(self, data): yield from self._preview(data, do_train=True) - def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]: + def preview_eval(self, data): yield from self._preview(data, do_train=False) - def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]: + def run_train(self, data): yield from self._launch(data, do_train=True) - def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]: + def run_eval(self, data): yield from self._launch(data, do_train=False) - def monitor(self) -> Generator[Dict[Component, Any], None, None]: + def monitor(self): get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] self.aborted = False self.running = True @@ -336,7 +337,7 @@ class Runner: yield return_dict - def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]: + def save_args(self, data): output_box = self.manager.get_elem_by_id("train.output_box") error = self._initialize(data, do_train=True, from_preview=True) if error: @@ -355,7 +356,7 @@ class Runner: save_path = save_args(config_path, config_dict) return {output_box: ALERTS["info_config_saved"][lang] + save_path} - def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]: + def load_args(self, lang: str, config_path: str): output_box = self.manager.get_elem_by_id("train.output_box") config_dict = load_args(config_path) if config_dict is None: