mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 10:10:35 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -14,9 +14,10 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_npu_available
|
||||
@@ -51,17 +52,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Runner:
|
||||
r"""
|
||||
A class to manage the running status of the trainers.
|
||||
"""
|
||||
r"""A class to manage the running status of the trainers."""
|
||||
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
|
||||
r"""Init a runner."""
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.trainer: Optional["Popen"] = None
|
||||
self.trainer: Optional[Popen] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
self.running_data: dict[Component, Any] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
@@ -71,10 +71,8 @@ class Runner:
|
||||
if self.trainer is not None:
|
||||
abort_process(self.trainer.pid)
|
||||
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
r"""
|
||||
Validates the configuration.
|
||||
"""
|
||||
def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
r"""Validate the configuration."""
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
@@ -116,9 +114,7 @@ class Runner:
|
||||
return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
r"""
|
||||
Cleans the cached memory and resets the runner.
|
||||
"""
|
||||
r"""Clean the cached memory and resets the runner."""
|
||||
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||
gr.Info(finish_info)
|
||||
self.trainer = None
|
||||
@@ -128,10 +124,8 @@ class Runner:
|
||||
torch_gc()
|
||||
return finish_info
|
||||
|
||||
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
r"""
|
||||
Builds and validates the training arguments.
|
||||
"""
|
||||
def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
|
||||
r"""Build and validate the training arguments."""
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
|
||||
user_config = load_config()
|
||||
@@ -291,10 +285,8 @@ class Runner:
|
||||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
r"""
|
||||
Builds and validates the evaluation arguments.
|
||||
"""
|
||||
def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]:
|
||||
r"""Build and validate the evaluation arguments."""
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
|
||||
user_config = load_config()
|
||||
@@ -345,10 +337,8 @@ class Runner:
|
||||
|
||||
return args
|
||||
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
|
||||
r"""
|
||||
Previews the training commands.
|
||||
"""
|
||||
def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]:
|
||||
r"""Preview the training commands."""
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||
error = self._initialize(data, do_train, from_preview=True)
|
||||
if error:
|
||||
@@ -358,10 +348,8 @@ class Runner:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield {output_box: gen_cmd(args)}
|
||||
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
|
||||
r"""
|
||||
Starts the training process.
|
||||
"""
|
||||
def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]:
|
||||
r"""Start the training process."""
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||
error = self._initialize(data, do_train, from_preview=False)
|
||||
if error:
|
||||
@@ -383,10 +371,8 @@ class Runner:
|
||||
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
|
||||
yield from self.monitor()
|
||||
|
||||
def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
r"""
|
||||
Builds a dictionary containing the current training configuration.
|
||||
"""
|
||||
def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
|
||||
r"""Build a dictionary containing the current training configuration."""
|
||||
config_dict = {}
|
||||
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
|
||||
for elem, value in data.items():
|
||||
@@ -409,9 +395,7 @@ class Runner:
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self):
|
||||
r"""
|
||||
Monitors the training progress and logs.
|
||||
"""
|
||||
r"""Monitorgit the training progress and logs."""
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
|
||||
@@ -469,9 +453,7 @@ class Runner:
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data):
|
||||
r"""
|
||||
Saves the training configuration to config path.
|
||||
"""
|
||||
r"""Save the training configuration to config path."""
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
@@ -487,27 +469,23 @@ class Runner:
|
||||
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||
|
||||
def load_args(self, lang: str, config_path: str):
|
||||
r"""
|
||||
Loads the training configuration from config path.
|
||||
"""
|
||||
r"""Load the training configuration from config path."""
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
|
||||
if config_dict is None:
|
||||
gr.Warning(ALERTS["err_config_not_found"][lang])
|
||||
return {output_box: ALERTS["err_config_not_found"][lang]}
|
||||
|
||||
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
|
||||
output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]}
|
||||
for elem_id, value in config_dict.items():
|
||||
output_dict[self.manager.get_elem_by_id(elem_id)] = value
|
||||
|
||||
return output_dict
|
||||
|
||||
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
|
||||
r"""
|
||||
Restore the training status if output_dir exists.
|
||||
"""
|
||||
r"""Restore the training status if output_dir exists."""
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
|
||||
output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]}
|
||||
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
|
||||
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
|
||||
output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
|
||||
|
||||
Reference in New Issue
Block a user