mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
fix #3366
This commit is contained in:
@@ -21,10 +21,11 @@ from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
||||
|
||||
if is_gradio_available():
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
@@ -243,7 +244,7 @@ class Runner:
|
||||
|
||||
return args
|
||||
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||
error = self._initialize(data, do_train, from_preview=True)
|
||||
if error:
|
||||
@@ -253,7 +254,7 @@ class Runner:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield {output_box: gen_cmd(args)}
|
||||
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||
error = self._initialize(data, do_train, from_preview=False)
|
||||
if error:
|
||||
@@ -267,19 +268,19 @@ class Runner:
|
||||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
|
||||
def preview_train(self, data):
|
||||
yield from self._preview(data, do_train=True)
|
||||
|
||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
|
||||
def preview_eval(self, data):
|
||||
yield from self._preview(data, do_train=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
|
||||
def run_train(self, data):
|
||||
yield from self._launch(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
|
||||
def run_eval(self, data):
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
|
||||
def monitor(self):
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
@@ -336,7 +337,7 @@ class Runner:
|
||||
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
|
||||
def save_args(self, data):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
@@ -355,7 +356,7 @@ class Runner:
|
||||
save_path = save_args(config_path, config_dict)
|
||||
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||
|
||||
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
|
||||
def load_args(self, lang: str, config_path: str):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
config_dict = load_args(config_path)
|
||||
if config_dict is None:
|
||||
|
||||
Reference in New Issue
Block a user