better LlamaBoard

* easily resume training from a checkpoint
* support full and freeze checkpoints
* faster UI


Former-commit-id: 80708717329b4552920dd4ce8cebc683e65d54c5
hiyouga 2024-05-29 23:55:38 +08:00
parent 19a3262387
commit 820404946e
14 changed files with 303 additions and 193 deletions

View File

@@ -1,4 +1,4 @@
-# Level: api, webui > chat, eval, train > data, model > extras, hparams
+# Level: api, webui > chat, eval, train > data, model > hparams > extras
 from .cli import VERSION

View File

@@ -2,6 +2,19 @@ from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional
 
+from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
+from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+
+
+CHECKPOINT_NAMES = {
+    SAFE_ADAPTER_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+}
+
 CHOICES = ["A", "B", "C", "D"]

@@ -26,9 +39,9 @@ LAYERNORM_NAMES = {"norm", "ln"}
 METHODS = ["full", "freeze", "lora"]
 
-MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
+MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
 
-PEFT_METHODS = ["lora"]
+PEFT_METHODS = {"lora"}
 
 RUNNING_LOG = "running_log.txt"

@@ -49,9 +62,9 @@ TRAINING_STAGES = {
     "Pre-Training": "pt",
 }
 
-STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
+STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
-SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
+SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
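
The imported names resolve to the standard weight filenames — "adapter_model.safetensors" / "adapter_model.bin" from PEFT, and "model.safetensors" / "pytorch_model.bin" plus their ".index.json" shard indexes from Transformers — which is what lets full and freeze checkpoints be detected the same way as LoRA adapters. A minimal sketch of a directory test built on this set (hypothetical helper, not part of the commit; import path assumed):

```python
import os

from llamafactory.extras.constants import CHECKPOINT_NAMES  # assumed import path


def is_checkpoint_dir(path: str) -> bool:
    """True if `path` contains adapter weights or full/freeze model weights."""
    return os.path.isdir(path) and any(
        os.path.isfile(os.path.join(path, name)) for name in CHECKPOINT_NAMES
    )
```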

View File

@@ -11,6 +11,7 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.versions import require_version
 
+from ..extras.constants import CHECKPOINT_NAMES
 from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments

@@ -255,13 +256,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and can_resume_from_checkpoint
     ):
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and any(
+            os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
+        ):
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
+
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info(
-                "Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
-                    training_args.resume_from_checkpoint
-                )
-            )
+            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
+            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
 
     if (
         finetuning_args.stage in ["rm", "ppo"]
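
In effect the new guard distinguishes three states of the output directory; a sketch of the decision, with the behavior inferred from the hunk above (the run directory name is made up):

```python
# Behavior of the resume logic for a given output_dir (sketch):
#   1. empty or missing dir          -> train from scratch
#   2. has a checkpoint-* subfolder  -> resume_from_checkpoint is set to it
#   3. has only final weights (any name in CHECKPOINT_NAMES) -> ValueError,
#      since there is nothing to resume and overwriting requires
#      `overwrite_output_dir`
import os

from transformers.trainer_utils import get_last_checkpoint

output_dir = "saves/demo"  # hypothetical run directory
if os.path.isdir(output_dir):
    print(get_last_checkpoint(output_dir))  # e.g. "saves/demo/checkpoint-500" or None
```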

View File

@@ -6,6 +6,7 @@ from numpy.typing import NDArray
 from ..chat import ChatModel
 from ..data import Role
+from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
 from .common import get_save_dir

@@ -44,13 +45,14 @@ class WebChatModel(ChatModel):
     def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
-        elif not get("top.model_name"):
+        elif not model_name:
             error = ALERTS["err_no_model"][lang]
-        elif not get("top.model_path"):
+        elif not model_path:
             error = ALERTS["err_no_path"][lang]
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]

@@ -60,21 +62,10 @@ class WebChatModel(ChatModel):
             yield error
             return
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
-            finetuning_type=get("top.finetuning_type"),
+            model_name_or_path=model_path,
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",

@@ -83,8 +74,16 @@ class WebChatModel(ChatModel):
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
         )
-        super().__init__(args)
 
+        if checkpoint_path:
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+        super().__init__(args)
         yield ALERTS["info_loaded"][lang]
 
     def unload_model(self, data) -> Generator[str, None, None]:
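
The same branching recurs in the runner and export code below: LoRA checkpoints are adapters stacked on top of the base model, whereas full and freeze checkpoints are complete models that replace it. A standalone sketch of the rule (the helper name is made up; the dict keys match this diff):

```python
from typing import Any, Dict, List, Union

PEFT_METHODS = {"lora"}  # mirrors extras.constants


def apply_checkpoint(args: Dict[str, Any], finetuning_type: str,
                     checkpoint_dirs: Union[str, List[str]]) -> None:
    if finetuning_type in PEFT_METHODS:
        # Multiselect dropdown -> list of adapter dirs, joined for PEFT.
        args["adapter_name_or_path"] = ",".join(checkpoint_dirs)
    else:
        # Single-select dropdown -> one dir holding the full/freeze weights.
        args["model_name_or_path"] = checkpoint_dirs
```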

View File

@@ -1,12 +1,12 @@
 import json
 import os
 from collections import defaultdict
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
-from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
 from yaml import safe_dump, safe_load
 
 from ..extras.constants import (
+    CHECKPOINT_NAMES,
     DATA_CONFIG,
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,

@@ -29,7 +29,6 @@ if is_gradio_available():
 logger = get_logger(__name__)
 
-ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_CONFIG_DIR = "config"
 DEFAULT_DATA_DIR = "data"
@@ -38,19 +37,31 @@ USER_CONFIG = "user_config.yaml"
 
 def get_save_dir(*paths: str) -> os.PathLike:
+    r"""
+    Gets the path to saved model checkpoints.
+    """
     paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
     return os.path.join(DEFAULT_SAVE_DIR, *paths)
 
 
 def get_config_path() -> os.PathLike:
+    r"""
+    Gets the path to user config.
+    """
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
-def get_save_path(config_path: str) -> os.PathLike:
+def get_arg_save_path(config_path: str) -> os.PathLike:
+    r"""
+    Gets the path to saved arguments.
+    """
     return os.path.join(DEFAULT_CONFIG_DIR, config_path)
 
 
 def load_config() -> Dict[str, Any]:
+    r"""
+    Loads user config if exists.
+    """
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return safe_load(f)

@@ -59,6 +70,9 @@ def load_config() -> Dict[str, Any]:
 
 def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+    r"""
+    Saves user config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
@@ -69,23 +83,10 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
         safe_dump(user_config, f)
 
 
-def load_args(config_path: str) -> Optional[Dict[str, Any]]:
-    try:
-        with open(get_save_path(config_path), "r", encoding="utf-8") as f:
-            return safe_load(f)
-    except Exception:
-        return None
-
-
-def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
-    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
-    with open(get_save_path(config_path), "w", encoding="utf-8") as f:
-        safe_dump(config_dict, f)
-
-    return str(get_save_path(config_path))
-
-
-def get_model_path(model_name: str) -> str:
+def get_model_path(model_name: str) -> Optional[str]:
+    r"""
+    Gets the model path according to the model name.
+    """
     user_config = load_config()
     path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
     model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
@@ -99,40 +100,71 @@ def get_model_path(model_name: str) -> str:
 
 def get_prefix(model_name: str) -> str:
+    r"""
+    Gets the prefix of the model name to obtain the model family.
+    """
     return model_name.split("-")[0]
 
 
+def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+    r"""
+    Gets the necessary information of this model.
+
+    Returns:
+        model_path (str)
+        template (str)
+        visual (bool)
+    """
+    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+
+
 def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
+    r"""
+    Gets the LoRA modules of this model.
+    """
+    return DEFAULT_MODULE.get(get_prefix(model_name), "all")
 
 
 def get_template(model_name: str) -> str:
+    r"""
+    Gets the template name if the model is a chat model.
+    """
     if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
         return DEFAULT_TEMPLATE[get_prefix(model_name)]
     return "default"
 
 
 def get_visual(model_name: str) -> bool:
+    r"""
+    Judges if the model is a vision language model.
+    """
     return get_prefix(model_name) in VISION_MODELS
 
 
-def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
-    if finetuning_type not in PEFT_METHODS:
-        return gr.Dropdown(value=[], choices=[], interactive=False)
-
-    adapters = []
-    if model_name and finetuning_type == "lora":
+def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+    r"""
+    Lists all available checkpoints.
+    """
+    checkpoints = []
+    if model_name:
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
-            for adapter in os.listdir(save_dir):
-                if os.path.isdir(os.path.join(save_dir, adapter)) and any(
-                    os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
+            for checkpoint in os.listdir(save_dir):
+                if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
+                    os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
                 ):
-                    adapters.append(adapter)
-
-    return gr.Dropdown(value=[], choices=adapters, interactive=True)
+                    checkpoints.append(checkpoint)
+
+    if finetuning_type in PEFT_METHODS:
+        return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
+    else:
+        return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
@@ -145,12 +177,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
         return {}
 
 
-def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+    r"""
+    Lists all available datasets in the dataset dir for the training stage.
+    """
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
     ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
     datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
-    return gr.Dropdown(value=[], choices=datasets)
-
-
-def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
-    return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))
+    return gr.Dropdown(choices=datasets)
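
All of these helpers walk one on-disk layout, roughly saves/<model_name>/<finetuning_type>/<run_name>/ (DEFAULT_SAVE_DIR is assumed to be "saves"; the model and run names below are invented for illustration):

```python
# saves/Llama-3-8B/lora/train_2024-05-29/adapter_model.safetensors  <- LoRA run
# saves/Llama-3-8B/full/train_2024-05-28/model.safetensors          <- full run
#
# list_checkpoints("Llama-3-8B", "lora")  -> multiselect dropdown, since
# several adapters can be applied at once; for "full"/"freeze" it is
# single-select, since only one complete model can be loaded.
```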

View File

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 
 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, list_dataset
+from ..common import DEFAULT_DATA_DIR, list_datasets
 from .data import create_preview_box

@@ -74,6 +74,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
 
-    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+    dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
 
     return elem_dict
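
The switch from dataset_dir.change to dataset.focus is part of the "faster UI" goal: the dropdown is repopulated lazily when the user opens it, instead of eagerly on every upstream change. A minimal self-contained Gradio sketch of the pattern (stubbed list_datasets, not the real helper):

```python
import gradio as gr


def list_datasets(dataset_dir: str) -> "gr.Dropdown":
    # Stub: the real helper reads dataset_info.json under dataset_dir.
    return gr.Dropdown(choices=["alpaca_en_demo", "c4_demo"])


with gr.Blocks() as demo:
    dataset_dir = gr.Textbox(value="data")
    dataset = gr.Dropdown(multiselect=True)
    # Refresh only when the user focuses the dropdown; queue=False keeps it snappy.
    dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)

demo.launch()
```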

View File

@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING, Dict, Generator, List
+from typing import TYPE_CHECKING, Dict, Generator, List, Union
 
+from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model

@@ -24,8 +25,8 @@ def save_model(
     lang: str,
     model_name: str,
     model_path: str,
-    adapter_path: List[str],
     finetuning_type: str,
+    checkpoint_path: Union[str, List[str]],
     template: str,
     visual_inputs: bool,
     export_size: int,

@@ -45,9 +46,9 @@ def save_model(
         error = ALERTS["err_no_export_dir"][lang]
     elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
         error = ALERTS["err_no_dataset"][lang]
-    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+    elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
         error = ALERTS["err_no_adapter"][lang]
-    elif export_quantization_bit in GPTQ_BITS and adapter_path:
+    elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
         error = ALERTS["err_gptq_lora"][lang]
 
     if error:

@@ -55,16 +56,8 @@ def save_model(
         yield error
         return
 
-    if adapter_path:
-        adapter_name_or_path = ",".join(
-            [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
-        )
-    else:
-        adapter_name_or_path = None
-
     args = dict(
         model_name_or_path=model_path,
-        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         visual_inputs=visual_inputs,
@@ -77,6 +70,14 @@ def save_model(
         export_legacy_format=export_legacy_format,
     )
 
+    if checkpoint_path:
+        if finetuning_type in PEFT_METHODS:  # list
+            args["adapter_name_or_path"] = ",".join(
+                [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+            )
+        else:  # str
+            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
     yield ALERTS["info_exporting"][lang]
     export_model(args)
     torch_gc()

@@ -86,7 +87,7 @@ def save_model(
 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
-        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+        export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
         export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
         export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
         export_legacy_format = gr.Checkbox()

@@ -104,8 +105,8 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_id("top.lang"),
             engine.manager.get_elem_by_id("top.model_name"),
             engine.manager.get_elem_by_id("top.model_path"),
-            engine.manager.get_elem_by_id("top.adapter_path"),
             engine.manager.get_elem_by_id("top.finetuning_type"),
+            engine.manager.get_elem_by_id("top.checkpoint_path"),
             engine.manager.get_elem_by_id("top.template"),
             engine.manager.get_elem_by_id("top.visual_inputs"),
             export_size,

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
-from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
+from ..common import get_model_info, list_checkpoints, save_config
 from ..utils import can_quantize

@@ -25,8 +25,7 @@ def create_top() -> Dict[str, "Component"]:
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
-        refresh_btn = gr.Button(scale=1)
+        checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
     with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():

@@ -36,27 +35,17 @@ def create_top() -> Dict[str, "Component"]:
             booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
             visual_inputs = gr.Checkbox(scale=1)
 
-    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
-        get_model_path, [model_name], [model_path], queue=False
-    ).then(get_template, [model_name], [template], queue=False).then(
-        get_visual, [model_name], [visual_inputs], queue=False
-    )  # do not save config since the below line will save
+    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
-    finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
-        can_quantize, [finetuning_type], [quantization_bit], queue=False
-    )
-    refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
+    finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
+    checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
 
     return dict(
         lang=lang,
         model_name=model_name,
        model_path=model_path,
         finetuning_type=finetuning_type,
-        adapter_path=adapter_path,
-        refresh_btn=refresh_btn,
+        checkpoint_path=checkpoint_path,
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,

View File

@@ -5,8 +5,9 @@ from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
 from ...extras.misc import get_device_count
 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
-from ..components.data import create_preview_box
+from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
+from ..utils import change_stage, check_output_dir, list_output_dirs
+from .data import create_preview_box
 
 if is_gradio_available():

@@ -256,11 +257,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Row():
             with gr.Column(scale=3):
                 with gr.Row():
-                    output_dir = gr.Textbox()
+                    initial_dir = gr.Textbox(visible=False, interactive=False)
+                    output_dir = gr.Dropdown(allow_custom_value=True)
                     config_path = gr.Textbox()
 
                 with gr.Row():
-                    device_count = gr.Textbox(value=str(get_device_count()), interactive=False)
+                    device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
                     ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
                     ds_offload = gr.Checkbox()

@@ -282,6 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             arg_load_btn=arg_load_btn,
             start_btn=start_btn,
             stop_btn=stop_btn,
+            initial_dir=initial_dir,
             output_dir=output_dir,
             config_path=config_path,
             device_count=device_count,
@@ -295,24 +298,24 @@
     )
 
     output_elems = [output_box, progress_bar, loss_viewer]
 
+    lang = engine.manager.get_elem_by_id("top.lang")
+    model_name = engine.manager.get_elem_by_id("top.model_name")
+    finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type")
+
     cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
     arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
     arg_load_btn.click(
-        engine.runner.load_args,
-        [engine.manager.get_elem_by_id("top.lang"), config_path],
-        list(input_elems) + [output_box],
-        concurrency_limit=None,
+        engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None
     )
     start_btn.click(engine.runner.run_train, input_elems, output_elems)
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
 
-    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
-        list_adapters,
-        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
-        [reward_model],
-        queue=False,
-    ).then(autoset_packing, [training_stage], [packing], queue=False)
+    training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
+    dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
+    reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
+    output_dir.change(
+        list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None
+    ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
 
     return elem_dict
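
change_stage above returns a tuple that Gradio fans out positionally over the [dataset, packing] outputs, clearing the dataset selection and pre-ticking packing for pre-training. A minimal sketch of the pattern (simplified stand-in for the real handler):

```python
import gradio as gr


def change_stage(training_stage: str):
    # Returns (new dataset value, new packing value); Gradio maps the tuple
    # onto the outputs list in order.
    return [], training_stage == "Pre-Training"


with gr.Blocks() as demo:
    training_stage = gr.Dropdown(
        choices=["Supervised Fine-Tuning", "Pre-Training"], value="Supervised Fine-Tuning"
    )
    dataset = gr.Dropdown(choices=["alpaca_en_demo", "c4_demo"], multiselect=True)
    packing = gr.Checkbox()
    training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)

demo.launch()
```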

View File

@@ -1,11 +1,11 @@
 from typing import TYPE_CHECKING, Any, Dict
 
 from .chatter import WebChatModel
-from .common import get_model_path, list_dataset, load_config
+from .common import load_config
 from .locales import LOCALES
 from .manager import Manager
 from .runner import Runner
-from .utils import get_time, save_ds_config
+from .utils import create_ds_config, get_time
 
 if TYPE_CHECKING:

@@ -20,7 +20,7 @@ class Engine:
         self.runner = Runner(self.manager, demo_mode)
         self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
         if not demo_mode:
-            save_ds_config()
+            create_ds_config()
 
     def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
         r"""

@@ -40,16 +40,15 @@ class Engine:
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
 
         if not self.pure_chat:
-            init_dict["train.dataset"] = {"choices": list_dataset().choices}
-            init_dict["eval.dataset"] = {"choices": list_dataset().choices}
-            init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
-            init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
-            init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
+            current_time = get_time()
+            init_dict["train.initial_dir"] = {"value": "train_{}".format(current_time)}
+            init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
+            init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
+            init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
             init_dict["infer.image_box"] = {"visible": False}
 
             if user_config.get("last_model", None):
                 init_dict["top.model_name"] = {"value": user_config["last_model"]}
-                init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
 
         yield self._update_component(init_dict)

View File

@@ -46,26 +46,15 @@ LOCALES = {
             "label": "微调方法",
         },
     },
-    "adapter_path": {
+    "checkpoint_path": {
         "en": {
-            "label": "Adapter path",
+            "label": "Checkpoint path",
         },
         "ru": {
-            "label": "Путь к адаптеру",
+            "label": "Путь контрольной точки",
         },
         "zh": {
-            "label": "适配器路径",
+            "label": "检查点路径",
         },
-    },
-    "refresh_btn": {
-        "en": {
-            "value": "Refresh adapters",
-        },
-        "ru": {
-            "value": "Обновить адаптеры",
-        },
-        "zh": {
-            "value": "刷新适配器",
-        },
     },
     "advanced_tab": {

@@ -1531,6 +1520,11 @@ ALERTS = {
         "ru": "Среда CUDA не обнаружена.",
         "zh": "未检测到 CUDA 环境。",
     },
+    "warn_output_dir_exists": {
+        "en": "Output dir already exists, will resume training from here.",
+        "ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.",
+        "zh": "输出目录已存在，将从该断点恢复训练。",
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "ru": "Прервано, ожидание завершения...",

View File

@@ -55,7 +55,7 @@ class Manager:
             self._id_to_elem["top.model_name"],
             self._id_to_elem["top.model_path"],
             self._id_to_elem["top.finetuning_type"],
-            self._id_to_elem["top.adapter_path"],
+            self._id_to_elem["top.checkpoint_path"],
             self._id_to_elem["top.quantization_bit"],
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],

View File

@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 import psutil
 from transformers.trainer import TRAINING_ARGS_NAME
 
-from ..extras.constants import TRAINING_STAGES
+from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc
 from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args
+from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
 from .locales import ALERTS
-from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
+from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
 
 if is_gradio_available():
@@ -85,26 +85,16 @@ class Runner:
     def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             do_train=True,
             model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
-            finetuning_type=get("top.finetuning_type"),
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -134,13 +124,23 @@ class Runner:
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
             use_badam=get("train.use_badam"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
+            output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
             pure_bf16=(get("train.compute_type") == "pure_bf16"),
             plot_loss=True,
+            ddp_timeout=180000000,
         )
 
+        # checkpoints
+        if get("top.checkpoint_path"):
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
         # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@@ -156,7 +156,7 @@ class Runner:
             args["create_new_adapter"] = get("train.create_new_adapter")
             args["use_rslora"] = get("train.use_rslora")
             args["use_dora"] = get("train.use_dora")
-            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["lora_target"] = get("train.lora_target") or get_module(model_name)
             args["additional_target"] = get("train.additional_target") or None
 
             if args["use_llama_pro"]:

@@ -164,13 +164,14 @@ class Runner:
         # rlhf config
         if args["stage"] == "ppo":
-            args["reward_model"] = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("train.reward_model")
-                ]
-            )
-            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
+            if finetuning_type in PEFT_METHODS:
+                args["reward_model"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
+                )
+            else:
+                args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))
+
+            args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
             args["ppo_score_norm"] = get("train.ppo_score_norm")
             args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
             args["top_k"] = 0
@@ -211,25 +212,15 @@ class Runner:
     def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         args = dict(
             stage="sft",
             model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
-            finetuning_type=get("top.finetuning_type"),
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,

@@ -245,7 +236,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
+            output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
         )
 
         if get("eval.predict"):

@@ -253,6 +244,14 @@ class Runner:
         else:
             args["do_eval"] = True
 
+        if get("top.checkpoint_path"):
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
         return args
 
     def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
@@ -296,9 +295,7 @@ class Runner:
         self.running = True
 
         get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
-        model_name = get("top.model_name")
-        finetuning_type = get("top.finetuning_type")
+        lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
         output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
         output_path = get_save_dir(model_name, finetuning_type, output_dir)

@@ -356,7 +353,7 @@ class Runner:
         config_dict: Dict[str, Any] = {}
         lang = data[self.manager.get_elem_by_id("top.lang")]
         config_path = data[self.manager.get_elem_by_id("train.config_path")]
-        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
+        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
         for elem, value in data.items():
             elem_id = self.manager.get_id_by_elem(elem)
             if elem_id not in skip_ids:

View File

@@ -3,12 +3,13 @@ import os
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple
 
-from yaml import safe_dump
+from transformers.trainer_utils import get_last_checkpoint
+from yaml import safe_dump, safe_load
 
-from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
+from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
 from ..extras.packages import is_gradio_available, is_matplotlib_available
 from ..extras.ploting import gen_loss_plot
-from .common import DEFAULT_CACHE_DIR
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
 from .locales import ALERTS

@@ -17,13 +18,26 @@ if is_gradio_available():
 
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
-    if finetuning_type != "lora":
+    r"""
+    Judges if the quantization is available in this finetuning type.
+    """
+    if finetuning_type not in PEFT_METHODS:
         return gr.Dropdown(value="none", interactive=False)
     else:
         return gr.Dropdown(interactive=True)
 
 
+def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
+    r"""
+    Modifies states after changing the training stage.
+    """
+    return [], TRAINING_STAGES[training_stage] == "pt"
+
+
 def check_json_schema(text: str, lang: str) -> None:
+    r"""
+    Checks if the json schema is valid.
+    """
     try:
         tools = json.loads(text)
         if tools:
@@ -38,11 +52,17 @@ def check_json_schema(text: str, lang: str) -> None:
 
 def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
+    r"""
+    Removes args with NoneType or False or empty string value.
+    """
     no_skip_keys = ["packing"]
     return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Generates arguments for previewing.
+    """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
         cmd_lines.append(" --{} {} ".format(k, str(v)))

@@ -52,17 +72,39 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     return cmd_text
 
 
+def save_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments to launch training.
+    """
+    output_dir = args["output_dir"]
+    os.makedirs(output_dir, exist_ok=True)
+    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
+        safe_dump(clean_cmd(args), f)
+
+    return os.path.join(output_dir, TRAINER_CONFIG)
+
+
 def get_eval_results(path: os.PathLike) -> str:
+    r"""
+    Gets scores after evaluation.
+    """
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
 
     return "```json\n{}\n```\n".format(result)
 
 
 def get_time() -> str:
+    r"""
+    Gets current date and time.
+    """
     return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
 
 
 def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
+    r"""
+    Gets training information for the monitor.
+    """
     running_log = ""
     running_progress = gr.Slider(visible=False)
     running_loss = None
@@ -96,17 +138,56 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
     return running_log, running_progress, running_loss
 
 
-def save_cmd(args: Dict[str, Any]) -> str:
-    output_dir = args["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-
-    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
-        safe_dump(clean_cmd(args), f)
-
-    return os.path.join(output_dir, TRAINER_CONFIG)
-
-
-def save_ds_config() -> None:
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+    r"""
+    Loads saved arguments.
+    """
+    try:
+        with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
+            return safe_load(f)
+    except Exception:
+        return None
+
+
+def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments.
+    """
+    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
+    with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
+        safe_dump(config_dict, f)
+
+    return str(get_arg_save_path(config_path))
+
+
+def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
+    r"""
+    Lists all the directories that can resume from.
+    """
+    output_dirs = [initial_dir]
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for folder in os.listdir(save_dir):
+                output_dir = os.path.join(save_dir, folder)
+                if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
+                    output_dirs.append(folder)
+
+    return gr.Dropdown(choices=output_dirs)
+
+
+def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
+    r"""
+    Checks if the output dir exists.
+    """
+    if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
+        gr.Warning(ALERTS["warn_output_dir_exists"][lang])
+
+
+def create_ds_config() -> None:
+    r"""
+    Creates deepspeed config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     ds_config = {
         "train_batch_size": "auto",