mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-28 00:32:48 +08:00
62 lines
2.6 KiB
Python
62 lines
2.6 KiB
Python
import gradio as gr
|
|
from gradio.components import Component # cannot use TYPE_CHECKING here
|
|
from typing import Any, Dict, Generator, Optional
|
|
|
|
from llmtuner.webui.chatter import WebChatModel
|
|
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
|
from llmtuner.webui.locales import LOCALES
|
|
from llmtuner.webui.manager import Manager
|
|
from llmtuner.webui.runner import Runner
|
|
from llmtuner.webui.utils import get_time
|
|
|
|
|
|
class Engine:
|
|
|
|
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
|
|
self.demo_mode = demo_mode
|
|
self.pure_chat = pure_chat
|
|
self.manager = Manager()
|
|
self.runner = Runner(self.manager, demo_mode=demo_mode)
|
|
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
|
|
|
|
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
|
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
|
|
|
|
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
|
user_config = load_config() if not self.demo_mode else {}
|
|
lang = user_config.get("lang", None) or "en"
|
|
|
|
init_dict = {
|
|
"top.lang": {"value": lang},
|
|
"infer.chat_box": {"visible": self.chatter.loaded}
|
|
}
|
|
|
|
if not self.pure_chat:
|
|
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
|
init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
|
|
|
if user_config.get("last_model", None):
|
|
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
|
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
|
|
|
|
yield self._form_dict(init_dict)
|
|
|
|
if not self.pure_chat:
|
|
if self.runner.alive:
|
|
yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
|
|
if self.runner.do_train:
|
|
yield self._form_dict({"train.resume_btn": {"value": True}})
|
|
else:
|
|
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
|
else:
|
|
yield self._form_dict({
|
|
"train.output_dir": {"value": "train_" + get_time()},
|
|
"eval.output_dir": {"value": "eval_" + get_time()},
|
|
})
|
|
|
|
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
|
return {
|
|
component: gr.update(**LOCALES[name][lang])
|
|
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
|
|
}
|