This commit is contained in:
hiyouga
2023-10-15 18:28:45 +08:00
parent 0d63584c03
commit a6a04be2e6
9 changed files with 40 additions and 57 deletions

View File

@@ -3,7 +3,7 @@ from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional
from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
@@ -21,8 +21,9 @@ class Engine:
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
lang = config.get("lang", None) or "en"
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
user_config = load_config()
lang = user_config.get("lang", None) or "en"
resume_dict = {
"top.lang": {"value": lang},
@@ -33,9 +34,9 @@ class Engine:
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
if config.get("last_model", None):
resume_dict["top.model_name"] = {"value": config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
if user_config.get("last_model", None):
resume_dict["top.model_name"] = {"value": user_config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
yield self._form_dict(resume_dict)