mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix eval resuming in webui
Former-commit-id: 273745f9b9d117d4053afc1746108af95b0a51a4
This commit is contained in:
parent
99592478c9
commit
0503d45782
@ -136,7 +136,7 @@ class LogCallback(TrainerCallback):
|
|||||||
)
|
)
|
||||||
if self.runner is not None:
|
if self.runner is not None:
|
||||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||||
logs["loss"], logs["learning_rate"], logs["epoch"]
|
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||||
))
|
))
|
||||||
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
@ -67,6 +67,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
stop_btn = gr.Button()
|
stop_btn = gr.Button()
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||||
process_bar = gr.Slider(visible=False, interactive=False)
|
process_bar = gr.Slider(visible=False, interactive=False)
|
||||||
|
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
@ -74,11 +75,13 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
|
|
||||||
output_elems = [output_box, process_bar]
|
output_elems = [output_box, process_bar]
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(
|
||||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box
|
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
|
||||||
|
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
|
||||||
))
|
))
|
||||||
|
|
||||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
||||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||||
|
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||||
|
|
||||||
return elem_dict
|
return elem_dict
|
||||||
|
@ -141,7 +141,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||||||
output_elems = [output_box, process_bar]
|
output_elems = [output_box, process_bar]
|
||||||
elem_dict.update(dict(
|
elem_dict.update(dict(
|
||||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
||||||
resume_btn=resume_btn, output_box=output_box, loss_viewer=loss_viewer, process_bar=process_bar
|
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
|
||||||
))
|
))
|
||||||
|
|
||||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||||
|
@ -39,9 +39,12 @@ class Engine:
|
|||||||
|
|
||||||
yield self._form_dict(resume_dict)
|
yield self._form_dict(resume_dict)
|
||||||
|
|
||||||
if self.runner.alive: # TODO: resume eval
|
if self.runner.alive:
|
||||||
yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
|
yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
|
||||||
resume_dict = {"train.resume_btn": {"value": True}}
|
if self.runner.do_train:
|
||||||
|
resume_dict = {"train.resume_btn": {"value": True}}
|
||||||
|
else:
|
||||||
|
resume_dict = {"eval.resume_btn": {"value": True}}
|
||||||
else:
|
else:
|
||||||
resume_dict = {"train.output_dir": {"value": get_time()}}
|
resume_dict = {"train.output_dir": {"value": get_time()}}
|
||||||
yield self._form_dict(resume_dict)
|
yield self._form_dict(resume_dict)
|
||||||
|
@ -185,29 +185,16 @@ class Runner:
|
|||||||
|
|
||||||
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
|
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
|
||||||
|
|
||||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
lang, model_name, model_path, dataset, _, args = self._parse_train_args(data)
|
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
||||||
|
lang, model_name, model_path, dataset, _, args = parse_func(data)
|
||||||
error = self._initialize(lang, model_name, model_path, dataset)
|
error = self._initialize(lang, model_name, model_path, dataset)
|
||||||
if error:
|
if error:
|
||||||
yield error, gr.update(visible=False)
|
yield error, gr.update(visible=False)
|
||||||
else:
|
else:
|
||||||
yield gen_cmd(args), gr.update(visible=False)
|
yield gen_cmd(args), gr.update(visible=False)
|
||||||
|
|
||||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
lang, model_name, model_path, dataset, _, args = self._parse_eval_args(data)
|
|
||||||
error = self._initialize(lang, model_name, model_path, dataset)
|
|
||||||
if error:
|
|
||||||
yield error, gr.update(visible=False)
|
|
||||||
else:
|
|
||||||
yield gen_cmd(args), gr.update(visible=False)
|
|
||||||
|
|
||||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
|
||||||
yield from self.prepare(data, do_train=True)
|
|
||||||
|
|
||||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
|
||||||
yield from self.prepare(data, do_train=False)
|
|
||||||
|
|
||||||
def prepare(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
|
||||||
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
||||||
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
|
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
|
||||||
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
|
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
|
||||||
@ -221,6 +208,18 @@ class Runner:
|
|||||||
self.thread.start()
|
self.thread.start()
|
||||||
yield from self.monitor()
|
yield from self.monitor()
|
||||||
|
|
||||||
|
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
|
yield from self._preview(data, do_train=True)
|
||||||
|
|
||||||
|
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
|
yield from self._preview(data, do_train=False)
|
||||||
|
|
||||||
|
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
|
yield from self._launch(data, do_train=True)
|
||||||
|
|
||||||
|
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
|
yield from self._launch(data, do_train=False)
|
||||||
|
|
||||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||||
while self.thread.is_alive():
|
while self.thread.is_alive():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user