diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 234d924c..5a6c16d3 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -73,14 +73,15 @@ def get_template(model_name: str) -> str:
 
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     checkpoints = []
-    save_dir = get_save_dir(model_name, finetuning_type)
-    if save_dir and os.path.isdir(save_dir):
-        for checkpoint in os.listdir(save_dir):
-            if (
-                os.path.isdir(os.path.join(save_dir, checkpoint))
-                and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
-            ):
-                checkpoints.append(checkpoint)
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for checkpoint in os.listdir(save_dir):
+                if (
+                    os.path.isdir(os.path.join(save_dir, checkpoint))
+                    and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
+                ):
+                    checkpoints.append(checkpoint)
     return gr.update(value=[], choices=checkpoints)
 
 
diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py
index b163282b..661dfb48 100644
--- a/src/llmtuner/webui/engine.py
+++ b/src/llmtuner/webui/engine.py
@@ -19,7 +19,7 @@ class Engine:
         self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
 
     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
-        return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
+        return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
 
     def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
         user_config = load_config()
@@ -42,7 +42,7 @@ class Engine:
 
         if not self.pure_chat:
             if self.runner.alive:
-                yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
+                yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
                 if self.runner.do_train:
                     yield self._form_dict({"train.resume_btn": {"value": True}})
                 else:
diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py
index 118bd0a9..ca067aea 100644
--- a/src/llmtuner/webui/manager.py
+++ b/src/llmtuner/webui/manager.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Dict, List, Set
 
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -9,14 +9,14 @@ class Manager:
     def __init__(self) -> None:
         self.all_elems: Dict[str, Dict[str, "Component"]] = {}
 
-    def get_elem(self, name: str) -> "Component":
+    def get_elem_by_name(self, name: str) -> "Component":
         r"""
         Example: top.lang, train.dataset
         """
         tab_name, elem_name = name.split(".")
         return self.all_elems[tab_name][elem_name]
 
-    def get_base_elems(self):
+    def get_base_elems(self) -> Set["Component"]:
         return {
             self.all_elems["top"]["lang"],
             self.all_elems["top"]["model_name"],
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 7ce77836..ab9e9ffc 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -26,12 +26,15 @@ class Runner:
 
     def __init__(self, manager: "Manager") -> None:
         self.manager = manager
+        """ Resume """
        self.thread: "Thread" = None
-        self.data: Dict["Component", Any] = None
         self.do_train = True
+        self.running_data: Dict["Component", Any] = None
         self.monitor_inputs: Dict[str, str] = None
+        """ State """
         self.aborted = False
         self.running = False
+        """ Handler """
         self.logger_handler = LoggerHandler()
         self.logger_handler.setLevel(logging.INFO)
         logging.root.addHandler(self.logger_handler)
@@ -45,7 +48,11 @@ class Runner:
         self.aborted = True
         self.running = False
 
-    def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
+    def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        dataset = get("train.dataset") if do_train else get("eval.dataset")
+
         if self.running:
             return ALERTS["err_conflict"][lang]
 
@@ -72,8 +79,8 @@ class Runner:
         else:
             return finish_info
 
-    def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
-        get = lambda name: data[self.manager.get_elem(name)]
+    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
         user_config = load_config()
 
         if get("top.checkpoints"):
@@ -83,8 +90,6 @@
         else:
             checkpoint_dir = None
 
-        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             model_name_or_path=get("top.model_path"),
@@ -119,7 +124,7 @@
             lora_target=get("train.lora_target") or get_module(get("top.model_name")),
             additional_target=get("train.additional_target") if get("train.additional_target") else None,
             resume_lora_training=get("train.resume_lora_training"),
-            output_dir=output_dir
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
         )
         args[get("train.compute_type")] = True
         args["disable_tqdm"] = True
@@ -142,10 +147,10 @@
             args["eval_steps"] = get("train.save_steps")
             args["load_best_model_at_end"] = True
 
-        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args
+        return args
 
-    def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
-        get = lambda name: data[self.manager.get_elem(name)]
+    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
         user_config = load_config()
 
         if get("top.checkpoints"):
@@ -188,27 +193,28 @@
             args.pop("do_eval", None)
             args["do_predict"] = True
 
-        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
+        return args
 
     def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        parse_func = self._parse_train_args if do_train else self._parse_eval_args
-        lang, model_name, model_path, dataset, _, args = parse_func(data)
-        error = self._initialize(lang, model_name, model_path, dataset)
+        error = self._initialize(data, do_train)
         if error:
+            gr.Warning(error)
             yield error, gr.update(visible=False)
         else:
+            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             yield gen_cmd(args), gr.update(visible=False)
 
     def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        parse_func = self._parse_train_args if do_train else self._parse_eval_args
-        lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
-        self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
-        error = self._initialize(lang, model_name, model_path, dataset)
+        error = self._initialize(data, do_train)
         if error:
+            gr.Warning(error)
             yield error, gr.update(visible=False)
         else:
-            self.running = True
+            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+            self.running = True
+            self.do_train, self.running_data = do_train, data
+            self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
             self.thread = Thread(target=run_exp, kwargs=run_kwargs)
             self.thread.start()
             yield from self.monitor()