diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index d399106f..8db5c2ba 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Dict from transformers.trainer_utils import SchedulerType from ...extras.constants import TRAINING_STAGES +from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset from ..components.data import create_preview_box @@ -258,6 +259,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: output_dir = gr.Textbox() config_path = gr.Textbox() + with gr.Row(): + device_count = gr.Textbox(value=str(get_device_count()), interactive=False) + ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none") + ds_offload = gr.Checkbox() + with gr.Row(): resume_btn = gr.Checkbox(visible=False, interactive=False) progress_bar = gr.Slider(visible=False, interactive=False) @@ -268,6 +274,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Column(scale=1): loss_viewer = gr.Plot() + input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload}) elem_dict.update( dict( cmd_preview_btn=cmd_preview_btn, @@ -277,14 +284,15 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: stop_btn=stop_btn, output_dir=output_dir, config_path=config_path, + device_count=device_count, + ds_stage=ds_stage, + ds_offload=ds_offload, resume_btn=resume_btn, progress_bar=progress_bar, output_box=output_box, loss_viewer=loss_viewer, ) ) - - input_elems.update({output_dir, config_path}) output_elems = [output_box, progress_bar, loss_viewer] cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index 964d65a2..fb568737 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -5,7 +5,7 @@ from .common import get_model_path, list_dataset, load_config from .locales import LOCALES from .manager import Manager from .runner import Runner -from .utils import get_time +from .utils import get_time, save_ds_config if TYPE_CHECKING: @@ -19,6 +19,8 @@ class Engine: self.manager = Manager() self.runner = Runner(self.manager, demo_mode) self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) + if not demo_mode: + save_ds_config() def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: r""" diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 570a8b42..4657f9a3 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1103,6 +1103,48 @@ LOCALES = { "info": "保存训练参数的配置文件路径。", }, }, + "device_count": { + "en": { + "label": "Device count", + "info": "Number of devices available.", + }, + "ru": { + "label": "Количество устройств", + "info": "Количество доступных устройств.", + }, + "zh": { + "label": "设备数量", + "info": "当前可用的运算设备数。", + }, + }, + "ds_stage": { + "en": { + "label": "DeepSpeed stage", + "info": "DeepSpeed stage for distributed training.", + }, + "ru": { + "label": "Этап DeepSpeed", + "info": "Этап DeepSpeed для распределенного обучения.", + }, + "zh": { + "label": "DeepSpeed stage", + "info": "多卡训练的 DeepSpeed stage。", + }, + }, + "ds_offload": { + "en": { + "label": "Enable offload", + "info": "Enable DeepSpeed offload (slow down training).", + }, + "ru": { + "label": "Включить выгрузку", + "info": "включить выгрузку DeepSpeed (замедлит обучение).", + }, + "zh": { + "label": "使用 offload", + "info": "使用 DeepSpeed offload(会减慢速度)。", + }, + }, "output_box": { "en": { "value": "Ready.", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 1310b999..c2e46e97 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -10,7 +10,7 @@ from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.packages import is_gradio_available -from .common import get_module, get_save_dir, load_args, load_config, save_args +from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args from .locales import ALERTS from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd @@ -201,6 +201,12 @@ class Runner: args["eval_steps"] = args["save_steps"] args["per_device_eval_batch_size"] = args["per_device_train_batch_size"] + # ds config + if get("train.ds_stage") != "none": + ds_stage = get("train.ds_stage") + ds_offload = "offload_" if get("train.ds_offload") else "" + args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload)) + return args def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index ceeb9352..654d1f8d 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -8,6 +8,7 @@ from yaml import safe_dump from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import gen_loss_plot +from .common import DEFAULT_CACHE_DIR from .locales import ALERTS @@ -103,3 +104,63 @@ def save_cmd(args: Dict[str, Any]) -> str: safe_dump(clean_cmd(args), f) return os.path.join(output_dir, TRAINER_CONFIG) + + +def save_ds_config() -> None: + os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) + ds_config = { + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": True, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1, + }, + "bf16": {"enabled": "auto"}, + } + offload_config = { + "device": "cpu", + "pin_memory": True, + } + ds_config["zero_optimization"] = { + "stage": 2, + "allgather_partitions": True, + "allgather_bucket_size": 5e8, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "contiguous_gradients": True, + "round_robin_gradients": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"] = { + "stage": 3, + "overlap_comm": True, + "contiguous_gradients": True, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + ds_config["zero_optimization"]["offload_param"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2)