Former-commit-id: 4469a3237db8b40e07bec6de36c9ed5b2347f854
This commit is contained in:
hiyouga 2024-04-21 21:34:25 +08:00
parent 6a3ee1edd9
commit 1ce9db5654
3 changed files with 19 additions and 20 deletions

View File

@ -1,6 +1,6 @@
import json import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
from ..chat import ChatModel from ..chat import ChatModel
from ..data import Role from ..data import Role
@ -17,7 +17,6 @@ if TYPE_CHECKING:
if is_gradio_available(): if is_gradio_available():
import gradio as gr import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
def loaded(self) -> bool: def loaded(self) -> bool:
return self.engine is not None return self.engine is not None
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: def load_model(self, data) -> Generator[str, None, None]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang") lang = get("top.lang")
error = "" error = ""
@ -82,7 +81,7 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang] yield ALERTS["info_loaded"][lang]
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: def unload_model(self, data) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_id("top.lang")] lang = data[self.manager.get_elem_by_id("top.lang")]
if self.demo_mode: if self.demo_mode:

View File

@ -1,6 +1,5 @@
from typing import Any, Dict, Generator from typing import TYPE_CHECKING, Any, Dict
from ..extras.packages import is_gradio_available
from .chatter import WebChatModel from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES from .locales import LOCALES
@ -9,8 +8,8 @@ from .runner import Runner
from .utils import get_time from .utils import get_time
if is_gradio_available(): if TYPE_CHECKING:
from gradio.components import Component # cannot use TYPE_CHECKING here from gradio.components import Component
class Engine: class Engine:
@ -32,7 +31,7 @@ class Engine:
return output_dict return output_dict
def resume(self) -> Generator[Dict[Component, Component], None, None]: def resume(self):
user_config = load_config() if not self.demo_mode else {} user_config = load_config() if not self.demo_mode else {}
lang = user_config.get("lang", None) or "en" lang = user_config.get("lang", None) or "en"
@ -58,7 +57,7 @@ class Engine:
else: else:
yield self._update_component({"eval.resume_btn": {"value": True}}) yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str) -> Dict[Component, Component]: def change_lang(self, lang: str):
return { return {
elem: elem.__class__(**LOCALES[elem_name][lang]) elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter() for elem_name, elem in self.manager.get_elem_iter()

View File

@ -21,10 +21,11 @@ from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if is_gradio_available(): if is_gradio_available():
import gradio as gr import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component
from .manager import Manager from .manager import Manager
@ -243,7 +244,7 @@ class Runner:
return args return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]: def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True) error = self._initialize(data, do_train, from_preview=True)
if error: if error:
@ -253,7 +254,7 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)} yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]: def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False) error = self._initialize(data, do_train, from_preview=False)
if error: if error:
@ -267,19 +268,19 @@ class Runner:
self.thread.start() self.thread.start()
yield from self.monitor() yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]: def preview_train(self, data):
yield from self._preview(data, do_train=True) yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]: def preview_eval(self, data):
yield from self._preview(data, do_train=False) yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]: def run_train(self, data):
yield from self._launch(data, do_train=True) yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]: def run_eval(self, data):
yield from self._launch(data, do_train=False) yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Dict[Component, Any], None, None]: def monitor(self):
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.aborted = False self.aborted = False
self.running = True self.running = True
@ -336,7 +337,7 @@ class Runner:
yield return_dict yield return_dict
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]: def save_args(self, data):
output_box = self.manager.get_elem_by_id("train.output_box") output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True) error = self._initialize(data, do_train=True, from_preview=True)
if error: if error:
@ -355,7 +356,7 @@ class Runner:
save_path = save_args(config_path, config_dict) save_path = save_args(config_path, config_dict)
return {output_box: ALERTS["info_config_saved"][lang] + save_path} return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]: def load_args(self, lang: str, config_path: str):
output_box = self.manager.get_elem_by_id("train.output_box") output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path) config_dict = load_args(config_path)
if config_dict is None: if config_dict is None: