fix bug in latest gradio

This commit is contained in:
hiyouga
2024-04-04 00:55:31 +08:00
parent 7f6e412604
commit 5ddcecda50
8 changed files with 111 additions and 204 deletions

View File

@@ -2,7 +2,7 @@ import logging
import os
import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator
import gradio as gr
import transformers
@@ -17,7 +17,7 @@ from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if TYPE_CHECKING:
@@ -239,20 +239,22 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.Slider(visible=False)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield gen_cmd(args), gr.Slider(visible=False)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.Slider(visible=False)
yield {output_box: error}
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
@@ -261,54 +263,80 @@ class Runner:
self.thread.start()
yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(
get("top.model_name"),
get("top.finetuning_type"),
get("{}.output_dir".format("train" if self.do_train else "eval")),
)
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
output_path = get_save_dir(model_name, finetuning_type, output_dir)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.thread is not None and self.thread.is_alive():
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
yield {
output_box: ALERTS["info_aborting"][lang],
process_bar: gr.Slider(visible=False),
}
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
return_dict = {
output_box: self.logger_handler.log,
process_bar: update_process_bar(self.trainer_callback),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict
time.sleep(2)
if self.do_train:
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
if os.path.exists(os.path.join(output_path, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.Slider(visible=False)
return_dict = {
output_box: self._finalize(lang, finish_info),
process_bar: gr.Slider(visible=False),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
yield return_dict
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
gr.Warning(error)
return error, gr.Slider(visible=False)
return {output_box: error}
config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")]
@@ -320,15 +348,16 @@ class Runner:
config_dict[elem_id] = value
save_path = save_args(config_path, config_dict)
return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path)
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {self.manager.get_elem_by_id("top.lang"): lang}
return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {}
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value