diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 7f7ce08e..7398d424 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -136,7 +136,7 @@ class LogCallback(TrainerCallback):
         )
         if self.runner is not None:
             logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
-                logs["loss"], logs["learning_rate"], logs["epoch"]
+                logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
             ))
 
         os.makedirs(args.output_dir, exist_ok=True)
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index aa04ffe3..c8d1530e 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -67,6 +67,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         stop_btn = gr.Button()
 
     with gr.Row():
+        resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
         process_bar = gr.Slider(visible=False, interactive=False)
 
     with gr.Box():
@@ -74,11 +75,13 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     output_elems = [output_box, process_bar]
 
     elem_dict.update(dict(
-        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box
+        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
+        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
     ))
 
     cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
     start_btn.click(engine.runner.run_eval, input_elems, output_elems)
     stop_btn.click(engine.runner.set_abort, queue=False)
+    resume_btn.change(engine.runner.monitor, outputs=output_elems)
 
     return elem_dict
diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py
index 8f9cdd86..8e6dc683 100644
--- a/src/llmtuner/webui/components/train.py
+++ b/src/llmtuner/webui/components/train.py
@@ -141,7 +141,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     output_elems = [output_box, process_bar]
     elem_dict.update(dict(
         cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
-        resume_btn=resume_btn, output_box=output_box, loss_viewer=loss_viewer, process_bar=process_bar
+        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
     ))
 
     cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py
index 0819c735..03d06144 100644
--- a/src/llmtuner/webui/engine.py
+++ b/src/llmtuner/webui/engine.py
@@ -39,9 +39,12 @@ class Engine:
         yield self._form_dict(resume_dict)
 
-        if self.runner.alive: # TODO: resume eval
+        if self.runner.alive:
             yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
-            resume_dict = {"train.resume_btn": {"value": True}}
+            if self.runner.do_train:
+                resume_dict = {"train.resume_btn": {"value": True}}
+            else:
+                resume_dict = {"eval.resume_btn": {"value": True}}
         else:
             resume_dict = {"train.output_dir": {"value": get_time()}}
 
         yield self._form_dict(resume_dict)
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 42acd95b..89cb56a0 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -185,29 +185,16 @@ class Runner:
         return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
 
-    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, model_path, dataset, _, args = self._parse_train_args(data)
+    def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        parse_func = self._parse_train_args if do_train else self._parse_eval_args
+        lang, model_name, model_path, dataset, _, args = parse_func(data)
         error = self._initialize(lang, model_name, model_path, dataset)
         if error:
             yield error, gr.update(visible=False)
         else:
             yield gen_cmd(args), gr.update(visible=False)
 
-    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data)
-        error = self._initialize(lang, model_name, model_path, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-        else:
-            yield gen_cmd(args), gr.update(visible=False)
-
-    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        yield from self.prepare(data, do_train=True)
-
-    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        yield from self.prepare(data, do_train=False)
-
-    def prepare(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
         parse_func = self._parse_train_args if do_train else self._parse_eval_args
         lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
         self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
         error = self._initialize(lang, model_name, model_path, dataset)
@@ -221,6 +208,18 @@ class Runner:
         self.thread.start()
         yield from self.monitor()
 
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self._preview(data, do_train=True)
+
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self._preview(data, do_train=False)
+
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self._launch(data, do_train=True)
+
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        yield from self._launch(data, do_train=False)
+
     def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
         lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
         while self.thread.is_alive():
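A note on the `callbacks.py` change: a numeric format spec cannot render `None`, so a log record whose `loss`, `learning_rate`, or `epoch` value is `None` (the webui presumably builds `logs` with `.get(..., None)`-style defaults, so the keys exist but the values may be missing) would raise a `TypeError` mid-training; `x or 0` substitutes a printable zero. A quick illustration with made-up values:

```python
# A log record with a missing metric, as can happen on steps that log
# something other than the training loss.
logs = {"loss": None, "learning_rate": 5e-5, "epoch": 0.25}

# "{:.4f}".format(None) raises:
#   TypeError: unsupported format string passed to NoneType.__format__
# so each value is coerced to 0 first. Note that `x or 0` also maps a
# legitimate 0.0 to 0, which is harmless for display purposes.
print("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
))
# -> {'loss': 0.0000, 'learning_rate': 5.0000e-05, 'epoch': 0.25}
```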
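On the webui changes: the eval tab now gets the same hidden-checkbox trick the train tab already uses. `Engine.resume()` flips `eval.resume_btn` to `True` when it finds a live runner after a page reload, and the checkbox's `change` event re-attaches `Runner.monitor`, so the reloaded page resumes streaming output from the still-running job. Below is a minimal, self-contained sketch of that pattern, assuming Gradio 3.x (where `.change` also fires on programmatic value updates); the component and function names are illustrative, not the project's:

```python
import time

import gradio as gr


def fake_monitor():
    # Stand-in for Runner.monitor: a generator that streams progress
    # updates into the output components while a background job runs.
    for step in range(1, 4):
        time.sleep(1)
        yield "running, step {} of 3".format(step), gr.update(visible=True, value=step, maximum=3)


with gr.Blocks() as demo:
    # The hidden checkbox serves purely as an event trigger, as in the diff.
    resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
    process_bar = gr.Slider(visible=False, interactive=False)
    output_box = gr.Markdown()

    # Flipping the checkbox value (which Engine.resume() does when it finds
    # a live runner) fires .change and re-attaches the monitor generator.
    resume_btn.change(fake_monitor, outputs=[output_box, process_bar])

    # Simulate Engine.resume() detecting a live job on page load.
    demo.load(lambda: True, outputs=[resume_btn])

# Generator handlers need the queue enabled to stream partial results.
demo.queue().launch()
```

This also explains the `resume_btn=resume_btn, process_bar=process_bar` additions to `elem_dict` in eval.py: `Engine.resume()` can only target components that are registered in the manager's element dictionary.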