mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
6a3ee1edd9
commit
1ce9db5654
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ..chat import ChatModel
|
from ..chat import ChatModel
|
||||||
from ..data import Role
|
from ..data import Role
|
||||||
@ -17,7 +17,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
if is_gradio_available():
|
if is_gradio_available():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
|
||||||
|
|
||||||
|
|
||||||
class WebChatModel(ChatModel):
|
class WebChatModel(ChatModel):
|
||||||
@ -38,7 +37,7 @@ class WebChatModel(ChatModel):
|
|||||||
def loaded(self) -> bool:
|
def loaded(self) -> bool:
|
||||||
return self.engine is not None
|
return self.engine is not None
|
||||||
|
|
||||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
def load_model(self, data) -> Generator[str, None, None]:
|
||||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||||
lang = get("top.lang")
|
lang = get("top.lang")
|
||||||
error = ""
|
error = ""
|
||||||
@ -82,7 +81,7 @@ class WebChatModel(ChatModel):
|
|||||||
|
|
||||||
yield ALERTS["info_loaded"][lang]
|
yield ALERTS["info_loaded"][lang]
|
||||||
|
|
||||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
def unload_model(self, data) -> Generator[str, None, None]:
|
||||||
lang = data[self.manager.get_elem_by_id("top.lang")]
|
lang = data[self.manager.get_elem_by_id("top.lang")]
|
||||||
|
|
||||||
if self.demo_mode:
|
if self.demo_mode:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from typing import Any, Dict, Generator
|
from typing import TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
from ..extras.packages import is_gradio_available
|
|
||||||
from .chatter import WebChatModel
|
from .chatter import WebChatModel
|
||||||
from .common import get_model_path, list_dataset, load_config
|
from .common import get_model_path, list_dataset, load_config
|
||||||
from .locales import LOCALES
|
from .locales import LOCALES
|
||||||
@ -9,8 +8,8 @@ from .runner import Runner
|
|||||||
from .utils import get_time
|
from .utils import get_time
|
||||||
|
|
||||||
|
|
||||||
if is_gradio_available():
|
if TYPE_CHECKING:
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
from gradio.components import Component
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
@ -32,7 +31,7 @@ class Engine:
|
|||||||
|
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
def resume(self) -> Generator[Dict[Component, Component], None, None]:
|
def resume(self):
|
||||||
user_config = load_config() if not self.demo_mode else {}
|
user_config = load_config() if not self.demo_mode else {}
|
||||||
lang = user_config.get("lang", None) or "en"
|
lang = user_config.get("lang", None) or "en"
|
||||||
|
|
||||||
@ -58,7 +57,7 @@ class Engine:
|
|||||||
else:
|
else:
|
||||||
yield self._update_component({"eval.resume_btn": {"value": True}})
|
yield self._update_component({"eval.resume_btn": {"value": True}})
|
||||||
|
|
||||||
def change_lang(self, lang: str) -> Dict[Component, Component]:
|
def change_lang(self, lang: str):
|
||||||
return {
|
return {
|
||||||
elem: elem.__class__(**LOCALES[elem_name][lang])
|
elem: elem.__class__(**LOCALES[elem_name][lang])
|
||||||
for elem_name, elem in self.manager.get_elem_iter()
|
for elem_name, elem in self.manager.get_elem_iter()
|
||||||
|
@ -21,10 +21,11 @@ from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
|||||||
|
|
||||||
if is_gradio_available():
|
if is_gradio_available():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from gradio.components import Component
|
||||||
|
|
||||||
from .manager import Manager
|
from .manager import Manager
|
||||||
|
|
||||||
|
|
||||||
@ -243,7 +244,7 @@ class Runner:
|
|||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
|
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
|
||||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||||
error = self._initialize(data, do_train, from_preview=True)
|
error = self._initialize(data, do_train, from_preview=True)
|
||||||
if error:
|
if error:
|
||||||
@ -253,7 +254,7 @@ class Runner:
|
|||||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||||
yield {output_box: gen_cmd(args)}
|
yield {output_box: gen_cmd(args)}
|
||||||
|
|
||||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
|
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
|
||||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||||
error = self._initialize(data, do_train, from_preview=False)
|
error = self._initialize(data, do_train, from_preview=False)
|
||||||
if error:
|
if error:
|
||||||
@ -267,19 +268,19 @@ class Runner:
|
|||||||
self.thread.start()
|
self.thread.start()
|
||||||
yield from self.monitor()
|
yield from self.monitor()
|
||||||
|
|
||||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
|
def preview_train(self, data):
|
||||||
yield from self._preview(data, do_train=True)
|
yield from self._preview(data, do_train=True)
|
||||||
|
|
||||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
|
def preview_eval(self, data):
|
||||||
yield from self._preview(data, do_train=False)
|
yield from self._preview(data, do_train=False)
|
||||||
|
|
||||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
|
def run_train(self, data):
|
||||||
yield from self._launch(data, do_train=True)
|
yield from self._launch(data, do_train=True)
|
||||||
|
|
||||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
|
def run_eval(self, data):
|
||||||
yield from self._launch(data, do_train=False)
|
yield from self._launch(data, do_train=False)
|
||||||
|
|
||||||
def monitor(self) -> Generator[Dict[Component, Any], None, None]:
|
def monitor(self):
|
||||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||||
self.aborted = False
|
self.aborted = False
|
||||||
self.running = True
|
self.running = True
|
||||||
@ -336,7 +337,7 @@ class Runner:
|
|||||||
|
|
||||||
yield return_dict
|
yield return_dict
|
||||||
|
|
||||||
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
|
def save_args(self, data):
|
||||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||||
error = self._initialize(data, do_train=True, from_preview=True)
|
error = self._initialize(data, do_train=True, from_preview=True)
|
||||||
if error:
|
if error:
|
||||||
@ -355,7 +356,7 @@ class Runner:
|
|||||||
save_path = save_args(config_path, config_dict)
|
save_path = save_args(config_path, config_dict)
|
||||||
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||||
|
|
||||||
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
|
def load_args(self, lang: str, config_path: str):
|
||||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||||
config_dict = load_args(config_path)
|
config_dict = load_args(config_path)
|
||||||
if config_dict is None:
|
if config_dict is None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user