better llamaboard

Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-22 22:02:51 +08:00

* easily resume from checkpoint
* support full and freeze checkpoints
* faster UI

Former-commit-id: 80708717329b4552920dd4ce8cebc683e65d54c5
Parent: 19a3262387
Commit: 820404946e
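The resume feature rests on two pieces that recur throughout the diff below: a new CHECKPOINT_NAMES set naming every weight file a finished run can leave behind, and transformers' get_last_checkpoint, which returns the newest checkpoint-* folder inside an output dir. A minimal sketch of the detection pattern (assuming transformers and peft are installed; the output dir is a placeholder):

import os

from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

# Same set this commit adds to extras/constants.py.
CHECKPOINT_NAMES = {
    SAFE_ADAPTER_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
}

output_dir = "saves/demo-run"  # placeholder path
os.makedirs(output_dir, exist_ok=True)

last_checkpoint = get_last_checkpoint(output_dir)  # newest checkpoint-* subfolder, or None
has_final_weights = any(os.path.isfile(os.path.join(output_dir, name)) for name in CHECKPOINT_NAMES)

if last_checkpoint is not None:
    print("resume from", last_checkpoint)  # parser.py sets resume_from_checkpoint to this
elif has_final_weights:
    print("refuse to overwrite a finished run")  # parser.py raises unless overwrite_output_dir is set
else:
    print("fresh run")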
src/llamafactory/__init__.py

@@ -1,4 +1,4 @@
-# Level: api, webui > chat, eval, train > data, model > extras, hparams
+# Level: api, webui > chat, eval, train > data, model > hparams > extras
 
 from .cli import VERSION
 
src/llamafactory/extras/constants.py

@@ -2,6 +2,19 @@ from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional
 
+from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
+from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 
 
+CHECKPOINT_NAMES = {
+    SAFE_ADAPTER_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+}
+
 CHOICES = ["A", "B", "C", "D"]
 
@@ -26,9 +39,9 @@ LAYERNORM_NAMES = {"norm", "ln"}
 
 METHODS = ["full", "freeze", "lora"]
 
-MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
+MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
 
-PEFT_METHODS = ["lora"]
+PEFT_METHODS = {"lora"}
 
 RUNNING_LOG = "running_log.txt"
 
@@ -49,9 +62,9 @@ TRAINING_STAGES = {
     "Pre-Training": "pt",
 }
 
-STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
+STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
-SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
+SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
 
src/llamafactory/hparams/parser.py

@@ -11,6 +11,7 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.versions import require_version
 
+from ..extras.constants import CHECKPOINT_NAMES
 from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
@@ -255,13 +256,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and can_resume_from_checkpoint
     ):
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and any(
+            os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
+        ):
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
+
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info(
-                "Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
-                    training_args.resume_from_checkpoint
-                )
-            )
+            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
+            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
 
     if (
         finetuning_args.stage in ["rm", "ppo"]
src/llamafactory/webui/chatter.py

@@ -6,6 +6,7 @@ from numpy.typing import NDArray
 
 from ..chat import ChatModel
 from ..data import Role
+from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
 from .common import get_save_dir
@@ -44,13 +45,14 @@ class WebChatModel(ChatModel):
 
     def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
-        elif not get("top.model_name"):
+        elif not model_name:
             error = ALERTS["err_no_model"][lang]
-        elif not get("top.model_path"):
+        elif not model_path:
             error = ALERTS["err_no_path"][lang]
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]
@@ -60,21 +62,10 @@ class WebChatModel(ChatModel):
             yield error
             return
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
-            finetuning_type=get("top.finetuning_type"),
+            model_name_or_path=model_path,
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -83,8 +74,16 @@ class WebChatModel(ChatModel):
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
         )
-        super().__init__(args)
 
+        if checkpoint_path:
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+        super().__init__(args)
         yield ALERTS["info_loaded"][lang]
 
     def unload_model(self, data) -> Generator[str, None, None]:
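The checkpoint_path handling added to chatter.py above is the pattern export.py and runner.py repeat later in this diff: LoRA checkpoints arrive as a list and are joined into adapter_name_or_path on top of the base model, while a full or freeze checkpoint is a single directory that replaces model_name_or_path outright. A standalone sketch of that mapping, with get_save_dir, PEFT_METHODS and all values reduced to placeholders:

import os
from typing import Dict, List, Union

PEFT_METHODS = {"lora"}  # as in extras/constants.py after this commit

def get_save_dir(*paths: str) -> str:
    # stand-in for the webui helper; "saves" mimics its default save root
    return os.path.join("saves", *paths)

def apply_checkpoint(args: Dict[str, str], model_name: str, finetuning_type: str,
                     checkpoint_path: Union[str, List[str]]) -> None:
    if not checkpoint_path:
        return
    if finetuning_type in PEFT_METHODS:  # list of adapter dirs, stacked on the base model
        args["adapter_name_or_path"] = ",".join(
            get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoint_path
        )
    else:  # single dir of full/freeze weights, used instead of the base model
        args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

args = {"model_name_or_path": "meta-llama/Llama-2-7b-hf"}  # placeholder base model
apply_checkpoint(args, "LLaMA2-7B", "lora", ["train_a", "train_b"])
print(args)  # adapter_name_or_path="saves/LLaMA2-7B/lora/train_a,saves/LLaMA2-7B/lora/train_b"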
src/llamafactory/webui/common.py

@@ -1,12 +1,12 @@
 import json
 import os
 from collections import defaultdict
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
-from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
 from yaml import safe_dump, safe_load
 
 from ..extras.constants import (
+    CHECKPOINT_NAMES,
     DATA_CONFIG,
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
@@ -29,7 +29,6 @@ if is_gradio_available():
 logger = get_logger(__name__)
 
 
-ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_CONFIG_DIR = "config"
 DEFAULT_DATA_DIR = "data"
@@ -38,19 +37,31 @@ USER_CONFIG = "user_config.yaml"
 
 
 def get_save_dir(*paths: str) -> os.PathLike:
+    r"""
+    Gets the path to saved model checkpoints.
+    """
     paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
     return os.path.join(DEFAULT_SAVE_DIR, *paths)
 
 
 def get_config_path() -> os.PathLike:
+    r"""
+    Gets the path to user config.
+    """
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
-def get_save_path(config_path: str) -> os.PathLike:
+def get_arg_save_path(config_path: str) -> os.PathLike:
+    r"""
+    Gets the path to saved arguments.
+    """
     return os.path.join(DEFAULT_CONFIG_DIR, config_path)
 
 
 def load_config() -> Dict[str, Any]:
+    r"""
+    Loads user config if exists.
+    """
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return safe_load(f)
@@ -59,6 +70,9 @@ def load_config() -> Dict[str, Any]:
 
 
 def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+    r"""
+    Saves user config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
@@ -69,23 +83,10 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
         safe_dump(user_config, f)
 
 
-def load_args(config_path: str) -> Optional[Dict[str, Any]]:
-    try:
-        with open(get_save_path(config_path), "r", encoding="utf-8") as f:
-            return safe_load(f)
-    except Exception:
-        return None
-
-
-def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
-    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
-    with open(get_save_path(config_path), "w", encoding="utf-8") as f:
-        safe_dump(config_dict, f)
-
-    return str(get_save_path(config_path))
-
-
-def get_model_path(model_name: str) -> str:
+def get_model_path(model_name: str) -> Optional[str]:
+    r"""
+    Gets the model path according to the model name.
+    """
     user_config = load_config()
     path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
     model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
@@ -99,40 +100,71 @@ def get_model_path(model_name: str) -> str:
 
 
 def get_prefix(model_name: str) -> str:
+    r"""
+    Gets the prefix of the model name to obtain the model family.
+    """
     return model_name.split("-")[0]
 
 
+def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+    r"""
+    Gets the necessary information of this model.
+
+    Returns:
+        model_path (str)
+        template (str)
+        visual (bool)
+    """
+    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+
+
 def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
+    r"""
+    Gets the LoRA modules of this model.
+    """
+    return DEFAULT_MODULE.get(get_prefix(model_name), "all")
 
 
 def get_template(model_name: str) -> str:
+    r"""
+    Gets the template name if the model is a chat model.
+    """
     if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
         return DEFAULT_TEMPLATE[get_prefix(model_name)]
     return "default"
 
 
 def get_visual(model_name: str) -> bool:
+    r"""
+    Judges if the model is a vision language model.
+    """
     return get_prefix(model_name) in VISION_MODELS
 
 
-def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
-    if finetuning_type not in PEFT_METHODS:
-        return gr.Dropdown(value=[], choices=[], interactive=False)
-    adapters = []
-    if model_name and finetuning_type == "lora":
+def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+    r"""
+    Lists all available checkpoints.
+    """
+    checkpoints = []
+    if model_name:
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
-            for adapter in os.listdir(save_dir):
-                if os.path.isdir(os.path.join(save_dir, adapter)) and any(
-                    os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
+            for checkpoint in os.listdir(save_dir):
+                if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
+                    os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
                 ):
-                    adapters.append(adapter)
+                    checkpoints.append(checkpoint)
 
-    return gr.Dropdown(value=[], choices=adapters, interactive=True)
+    if finetuning_type in PEFT_METHODS:
+        return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
+    else:
+        return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
 
 
 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
+    r"""
+    Loads dataset_info.json.
+    """
     if dataset_dir == "ONLINE":
         logger.info("dataset_dir is ONLINE, using online dataset.")
         return {}
@@ -145,12 +177,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
         return {}
 
 
-def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+    r"""
+    Lists all available datasets in the dataset dir for the training stage.
+    """
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
     ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
     datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
-    return gr.Dropdown(value=[], choices=datasets)
-
-
-def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
-    return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))
+    return gr.Dropdown(choices=datasets)
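list_checkpoints above is also where "support full and freeze checkpoints" lands in the UI: the same dropdown switches between multi-select (LoRA adapters can be stacked) and single-select (a full or freeze checkpoint stands alone). A small sketch of just that branch, assuming gradio is installed and with the checkpoint names as placeholders:

from typing import List

import gradio as gr

PEFT_METHODS = {"lora"}  # as in extras/constants.py after this commit

def checkpoint_dropdown(finetuning_type: str, checkpoints: List[str]) -> "gr.Dropdown":
    # LoRA adapters can be combined, so allow picking several;
    # a full or freeze checkpoint is exclusive, so allow exactly one.
    if finetuning_type in PEFT_METHODS:
        return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
    else:
        return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)

lora_dd = checkpoint_dropdown("lora", ["train_a", "train_b"])  # multi-select dropdown
full_dd = checkpoint_dropdown("full", ["train_c"])             # single-select dropdown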
src/llamafactory/webui/components/eval.py

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 
 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, list_dataset
+from ..common import DEFAULT_DATA_DIR, list_datasets
 from .data import create_preview_box
 
 
@@ -74,6 +74,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
 
-    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+    dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
 
     return elem_dict
src/llamafactory/webui/components/export.py

@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING, Dict, Generator, List
+from typing import TYPE_CHECKING, Dict, Generator, List, Union
 
+from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model
@@ -24,8 +25,8 @@ def save_model(
     lang: str,
     model_name: str,
     model_path: str,
-    adapter_path: List[str],
     finetuning_type: str,
+    checkpoint_path: Union[str, List[str]],
     template: str,
     visual_inputs: bool,
     export_size: int,
@@ -45,9 +46,9 @@ def save_model(
         error = ALERTS["err_no_export_dir"][lang]
     elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
         error = ALERTS["err_no_dataset"][lang]
-    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+    elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
        error = ALERTS["err_no_adapter"][lang]
-    elif export_quantization_bit in GPTQ_BITS and adapter_path:
+    elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
         error = ALERTS["err_gptq_lora"][lang]
 
     if error:
@@ -55,16 +56,8 @@ def save_model(
         yield error
         return
 
-    if adapter_path:
-        adapter_name_or_path = ",".join(
-            [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
-        )
-    else:
-        adapter_name_or_path = None
-
     args = dict(
         model_name_or_path=model_path,
-        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         visual_inputs=visual_inputs,
@@ -77,6 +70,14 @@ def save_model(
         export_legacy_format=export_legacy_format,
     )
 
+    if checkpoint_path:
+        if finetuning_type in PEFT_METHODS:  # list
+            args["adapter_name_or_path"] = ",".join(
+                [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+            )
+        else:  # str
+            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
     yield ALERTS["info_exporting"][lang]
     export_model(args)
     torch_gc()
@@ -86,7 +87,7 @@ def save_model(
 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
-        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+        export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
         export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
         export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
         export_legacy_format = gr.Checkbox()
@@ -104,8 +105,8 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_id("top.lang"),
             engine.manager.get_elem_by_id("top.model_name"),
             engine.manager.get_elem_by_id("top.model_path"),
-            engine.manager.get_elem_by_id("top.adapter_path"),
             engine.manager.get_elem_by_id("top.finetuning_type"),
+            engine.manager.get_elem_by_id("top.checkpoint_path"),
             engine.manager.get_elem_by_id("top.template"),
             engine.manager.get_elem_by_id("top.visual_inputs"),
             export_size,
src/llamafactory/webui/components/top.py

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
-from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
+from ..common import get_model_info, list_checkpoints, save_config
 from ..utils import can_quantize
 
 
@@ -25,8 +25,7 @@ def create_top() -> Dict[str, "Component"]:
 
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
-        refresh_btn = gr.Button(scale=1)
+        checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
     with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():
@@ -36,27 +35,17 @@ def create_top() -> Dict[str, "Component"]:
             booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
             visual_inputs = gr.Checkbox(scale=1)
 
-    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
-        get_model_path, [model_name], [model_path], queue=False
-    ).then(get_template, [model_name], [template], queue=False).then(
-        get_visual, [model_name], [visual_inputs], queue=False
-    )  # do not save config since the below line will save
-
+    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
-    finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
-        can_quantize, [finetuning_type], [quantization_bit], queue=False
-    )
-
-    refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
+    finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
+    checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
 
     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
         finetuning_type=finetuning_type,
-        adapter_path=adapter_path,
-        refresh_btn=refresh_btn,
+        checkpoint_path=checkpoint_path,
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
src/llamafactory/webui/components/train.py

@@ -5,8 +5,9 @@ from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
 from ...extras.misc import get_device_count
 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
-from ..components.data import create_preview_box
+from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
+from ..utils import change_stage, check_output_dir, list_output_dirs
+from .data import create_preview_box
 
 
 if is_gradio_available():
@@ -256,11 +257,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Row():
-                output_dir = gr.Textbox()
+                initial_dir = gr.Textbox(visible=False, interactive=False)
+                output_dir = gr.Dropdown(allow_custom_value=True)
                 config_path = gr.Textbox()
 
             with gr.Row():
-                device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
+                device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
                 ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
                 ds_offload = gr.Checkbox()
 
@@ -282,6 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             arg_load_btn=arg_load_btn,
             start_btn=start_btn,
             stop_btn=stop_btn,
+            initial_dir=initial_dir,
             output_dir=output_dir,
             config_path=config_path,
             device_count=device_count,
@@ -295,24 +298,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     )
     output_elems = [output_box, progress_bar, loss_viewer]
 
+    lang = engine.manager.get_elem_by_id("top.lang")
+    model_name = engine.manager.get_elem_by_id("top.model_name")
+    finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type")
+
     cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
     arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
     arg_load_btn.click(
-        engine.runner.load_args,
-        [engine.manager.get_elem_by_id("top.lang"), config_path],
-        list(input_elems) + [output_box],
-        concurrency_limit=None,
+        engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None
     )
     start_btn.click(engine.runner.run_train, input_elems, output_elems)
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
 
-    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
-        list_adapters,
-        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
-        [reward_model],
-        queue=False,
-    ).then(autoset_packing, [training_stage], [packing], queue=False)
+    training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
+    dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
+    reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
+    output_dir.change(
+        list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
+    ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
 
     return elem_dict
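The initial_dir/output_dir pair introduced above is what makes resuming one click: output_dir is now a dropdown whose choices are re-listed from disk on change, and picking an existing run chains into a warning that training will continue from its checkpoint. A toy reproduction of that wiring, assuming gradio is installed and with the two utils helpers (shown later in this diff) stubbed out:

import gradio as gr

# Stand-ins for list_output_dirs/check_output_dir from webui/utils.py below.
def list_output_dirs(model_name, finetuning_type, initial_dir):
    # would scan saves/<model>/<type> for folders holding a trainer checkpoint
    return gr.Dropdown(choices=[initial_dir, "train_resumable_run"])

def check_output_dir(lang, model_name, finetuning_type, output_dir):
    if output_dir == "train_resumable_run":  # placeholder for os.path.isdir(...)
        gr.Warning("Output dir already exists, will resume training from here.")

with gr.Blocks() as demo:
    lang = gr.Textbox(value="en", visible=False)
    model_name = gr.Textbox(value="LLaMA2-7B", visible=False)
    finetuning_type = gr.Textbox(value="lora", visible=False)
    initial_dir = gr.Textbox(value="train_2024-05-05-00-00-00", visible=False, interactive=False)
    output_dir = gr.Dropdown(label="output_dir", allow_custom_value=True)

    # Same chain as the train tab: refresh the choices, then warn on reuse.
    output_dir.change(
        list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir]
    ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir])

# demo.launch()  # uncomment to try the dropdown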
src/llamafactory/webui/engine.py

@@ -1,11 +1,11 @@
 from typing import TYPE_CHECKING, Any, Dict
 
 from .chatter import WebChatModel
-from .common import get_model_path, list_dataset, load_config
+from .common import load_config
 from .locales import LOCALES
 from .manager import Manager
 from .runner import Runner
-from .utils import get_time, save_ds_config
+from .utils import create_ds_config, get_time
 
 
 if TYPE_CHECKING:
@@ -20,7 +20,7 @@ class Engine:
         self.runner = Runner(self.manager, demo_mode)
         self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
         if not demo_mode:
-            save_ds_config()
+            create_ds_config()
 
     def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
         r"""
@@ -40,16 +40,15 @@ class Engine:
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
 
         if not self.pure_chat:
-            init_dict["train.dataset"] = {"choices": list_dataset().choices}
-            init_dict["eval.dataset"] = {"choices": list_dataset().choices}
-            init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
-            init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
-            init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
+            current_time = get_time()
+            init_dict["train.initial_dir"] = {"value": "train_{}".format(current_time)}
+            init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
+            init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
+            init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
             init_dict["infer.image_box"] = {"visible": False}
 
         if user_config.get("last_model", None):
             init_dict["top.model_name"] = {"value": user_config["last_model"]}
-            init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
 
         yield self._update_component(init_dict)
 
src/llamafactory/webui/locales.py

@@ -46,26 +46,15 @@ LOCALES = {
             "label": "微调方法",
         },
     },
-    "adapter_path": {
+    "checkpoint_path": {
         "en": {
-            "label": "Adapter path",
+            "label": "Checkpoint path",
         },
         "ru": {
-            "label": "Путь к адаптеру",
+            "label": "Путь контрольной точки",
         },
         "zh": {
-            "label": "适配器路径",
-        },
-    },
-    "refresh_btn": {
-        "en": {
-            "value": "Refresh adapters",
-        },
-        "ru": {
-            "value": "Обновить адаптеры",
-        },
-        "zh": {
-            "value": "刷新适配器",
+            "label": "检查点路径",
         },
     },
     "advanced_tab": {
@@ -1531,6 +1520,11 @@ ALERTS = {
         "ru": "Среда CUDA не обнаружена.",
         "zh": "未检测到 CUDA 环境。",
     },
+    "warn_output_dir_exists": {
+        "en": "Output dir already exists, will resume training from here.",
+        "ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.",
+        "zh": "输出目录已存在,将从该断点恢复训练。",
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "ru": "Прервано, ожидание завершения...",
src/llamafactory/webui/manager.py

@@ -55,7 +55,7 @@ class Manager:
             self._id_to_elem["top.model_name"],
             self._id_to_elem["top.model_path"],
             self._id_to_elem["top.finetuning_type"],
-            self._id_to_elem["top.adapter_path"],
+            self._id_to_elem["top.checkpoint_path"],
             self._id_to_elem["top.quantization_bit"],
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
src/llamafactory/webui/runner.py

@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 import psutil
 from transformers.trainer import TRAINING_ARGS_NAME
 
-from ..extras.constants import TRAINING_STAGES
+from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc
 from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args
+from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
 from .locales import ALERTS
-from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
+from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
 
 
 if is_gradio_available():
@@ -85,26 +85,16 @@ class Runner:
 
     def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             do_train=True,
             model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
-            finetuning_type=get("top.finetuning_type"),
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -134,13 +124,23 @@ class Runner:
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
             use_badam=get("train.use_badam"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
+            output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
             pure_bf16=(get("train.compute_type") == "pure_bf16"),
             plot_loss=True,
+            ddp_timeout=180000000,
         )
 
+        # checkpoints
+        if get("top.checkpoint_path"):
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
         # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@@ -156,7 +156,7 @@ class Runner:
             args["create_new_adapter"] = get("train.create_new_adapter")
             args["use_rslora"] = get("train.use_rslora")
             args["use_dora"] = get("train.use_dora")
-            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["lora_target"] = get("train.lora_target") or get_module(model_name)
             args["additional_target"] = get("train.additional_target") or None
 
         if args["use_llama_pro"]:
@@ -164,13 +164,14 @@ class Runner:
 
         # rlhf config
         if args["stage"] == "ppo":
-            args["reward_model"] = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("train.reward_model")
-                ]
-            )
-            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
+            if finetuning_type in PEFT_METHODS:
+                args["reward_model"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
+                )
+            else:
+                args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))
+
+            args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
             args["ppo_score_norm"] = get("train.ppo_score_norm")
             args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
             args["top_k"] = 0
@@ -211,25 +212,15 @@ class Runner:
 
     def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         args = dict(
             stage="sft",
             model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
-            finetuning_type=get("top.finetuning_type"),
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -245,7 +236,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
+            output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
         )
 
         if get("eval.predict"):
@@ -253,6 +244,14 @@ class Runner:
         else:
             args["do_eval"] = True
 
+        if get("top.checkpoint_path"):
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
         return args
 
     def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
@@ -296,9 +295,7 @@ class Runner:
         self.running = True
 
         get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
-        model_name = get("top.model_name")
-        finetuning_type = get("top.finetuning_type")
+        lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
         output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
         output_path = get_save_dir(model_name, finetuning_type, output_dir)
 
@@ -356,7 +353,7 @@ class Runner:
         config_dict: Dict[str, Any] = {}
         lang = data[self.manager.get_elem_by_id("top.lang")]
         config_path = data[self.manager.get_elem_by_id("train.config_path")]
-        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
+        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
         for elem, value in data.items():
             elem_id = self.manager.get_id_by_elem(elem)
             if elem_id not in skip_ids:
src/llamafactory/webui/utils.py

@@ -3,12 +3,13 @@ import os
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple
 
-from yaml import safe_dump
+from transformers.trainer_utils import get_last_checkpoint
+from yaml import safe_dump, safe_load
 
-from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
+from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
 from ..extras.packages import is_gradio_available, is_matplotlib_available
 from ..extras.ploting import gen_loss_plot
-from .common import DEFAULT_CACHE_DIR
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
 from .locales import ALERTS
 
 
@@ -17,13 +18,26 @@ if is_gradio_available():
 
 
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
-    if finetuning_type != "lora":
+    r"""
+    Judges if the quantization is available in this finetuning type.
+    """
+    if finetuning_type not in PEFT_METHODS:
         return gr.Dropdown(value="none", interactive=False)
     else:
         return gr.Dropdown(interactive=True)
 
 
+def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
+    r"""
+    Modifies states after changing the training stage.
+    """
+    return [], TRAINING_STAGES[training_stage] == "pt"
+
+
 def check_json_schema(text: str, lang: str) -> None:
+    r"""
+    Checks if the json schema is valid.
+    """
     try:
         tools = json.loads(text)
         if tools:
@@ -38,11 +52,17 @@ def check_json_schema(text: str, lang: str) -> None:
 
 
 def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
+    r"""
+    Removes args with NoneType or False or empty string value.
+    """
     no_skip_keys = ["packing"]
     return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Generates arguments for previewing.
+    """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
         cmd_lines.append(" --{} {} ".format(k, str(v)))
@@ -52,17 +72,39 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     return cmd_text
 
 
+def save_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments to launch training.
+    """
+    output_dir = args["output_dir"]
+    os.makedirs(output_dir, exist_ok=True)
+
+    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
+        safe_dump(clean_cmd(args), f)
+
+    return os.path.join(output_dir, TRAINER_CONFIG)
+
+
 def get_eval_results(path: os.PathLike) -> str:
+    r"""
+    Gets scores after evaluation.
+    """
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
     return "```json\n{}\n```\n".format(result)
 
 
 def get_time() -> str:
+    r"""
+    Gets current date and time.
+    """
     return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
 
 
 def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
+    r"""
+    Gets training information for monitor.
+    """
     running_log = ""
     running_progress = gr.Slider(visible=False)
     running_loss = None
@@ -96,17 +138,56 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
     return running_log, running_progress, running_loss
 
 
-def save_cmd(args: Dict[str, Any]) -> str:
-    output_dir = args["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-
-    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
-        safe_dump(clean_cmd(args), f)
-
-    return os.path.join(output_dir, TRAINER_CONFIG)
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+    r"""
+    Loads saved arguments.
+    """
+    try:
+        with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
+            return safe_load(f)
+    except Exception:
+        return None
 
 
-def save_ds_config() -> None:
+def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments.
+    """
+    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
+    with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
+        safe_dump(config_dict, f)
+
+    return str(get_arg_save_path(config_path))
+
+
+def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
+    r"""
+    Lists all the directories that can resume from.
+    """
+    output_dirs = [initial_dir]
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for folder in os.listdir(save_dir):
+                output_dir = os.path.join(save_dir, folder)
+                if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
+                    output_dirs.append(folder)
+
+    return gr.Dropdown(choices=output_dirs)
+
+
+def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
+    r"""
+    Checks if output dir exists.
+    """
+    if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
+        gr.Warning(ALERTS["warn_output_dir_exists"][lang])
+
+
+def create_ds_config() -> None:
+    r"""
+    Creates deepspeed config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     ds_config = {
         "train_batch_size": "auto",
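load_args and save_args, relocated here from common.py, persist the whole train-tab form as YAML under the config dir via get_arg_save_path. A self-contained round trip of the same idea (the directory name and keys are placeholders):

import os
from typing import Any, Dict, Optional

from yaml import safe_dump, safe_load

CONFIG_DIR = "config"  # placeholder for DEFAULT_CONFIG_DIR

def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    os.makedirs(CONFIG_DIR, exist_ok=True)
    path = os.path.join(CONFIG_DIR, config_path)
    with open(path, "w", encoding="utf-8") as f:
        safe_dump(config_dict, f)
    return path

def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    try:
        with open(os.path.join(CONFIG_DIR, config_path), "r", encoding="utf-8") as f:
            return safe_load(f)
    except Exception:  # missing or unreadable file -> no saved arguments
        return None

print(save_args("demo.yaml", {"train.learning_rate": "5e-5", "train.num_train_epochs": "3.0"}))
print(load_args("demo.yaml"))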