From 28a6ea1cdc6fba34760c236504fece8df5e642f4 Mon Sep 17 00:00:00 2001 From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:09:52 +0800 Subject: [PATCH] [v1] add deepspeed zero3 trigger for low memory usage weight loading (#10300) --- .../v1/train_freeze/train_freeze_sft.yaml | 1 - .../v1/train_full/train_full_deepspeed.yaml | 1 - examples/v1/train_full/train_full_fsdp2.yaml | 1 - examples/v1/train_lora/train_lora_sft.yaml | 1 - examples/v1/train_qlora/quantization.yaml | 1 - src/llamafactory/v1/accelerator/interface.py | 2 + src/llamafactory/v1/config/training_args.py | 4 +- src/llamafactory/v1/core/base_trainer.py | 1 + src/llamafactory/v1/core/model_engine.py | 30 ++++- .../plugins/model_plugins/deepspeed_utils.py | 123 ++++++++++++++++++ .../trainer_plugins/distributed/deepspeed.py | 2 + .../trainer_plugins/distributed/fsdp2.py | 9 +- .../trainer_plugins/distributed/hub.py | 2 +- 13 files changed, 160 insertions(+), 18 deletions(-) create mode 100644 src/llamafactory/v1/plugins/model_plugins/deepspeed_utils.py diff --git a/examples/v1/train_freeze/train_freeze_sft.yaml b/examples/v1/train_freeze/train_freeze_sft.yaml index 29233d8e1..11c3a56d6 100644 --- a/examples/v1/train_freeze/train_freeze_sft.yaml +++ b/examples/v1/train_freeze/train_freeze_sft.yaml @@ -29,7 +29,6 @@ micro_batch_size: 1 global_batch_size: 4 cutoff_len: 2048 learning_rate: 2.0e-5 -bf16: false max_steps: 10 ### sample diff --git a/examples/v1/train_full/train_full_deepspeed.yaml b/examples/v1/train_full/train_full_deepspeed.yaml index 0a9851147..665a24f7b 100644 --- a/examples/v1/train_full/train_full_deepspeed.yaml +++ b/examples/v1/train_full/train_full_deepspeed.yaml @@ -19,5 +19,4 @@ output_dir: outputs/Qwen3-0.6B-deepspeed micro_batch_size: 1 cutoff_len: 2048 learning_rate: 1.0e-4 -bf16: true max_steps: 10 diff --git a/examples/v1/train_full/train_full_fsdp2.yaml b/examples/v1/train_full/train_full_fsdp2.yaml index 1378ec30b..f6a7384ca 100644 --- 
a/examples/v1/train_full/train_full_fsdp2.yaml +++ b/examples/v1/train_full/train_full_fsdp2.yaml @@ -21,7 +21,6 @@ output_dir: outputs/test_fsdp2 micro_batch_size: 1 cutoff_len: 2048 learning_rate: 1.0e-4 -bf16: false max_steps: 10 ### sample diff --git a/examples/v1/train_lora/train_lora_sft.yaml b/examples/v1/train_lora/train_lora_sft.yaml index 653b1df7f..070b14570 100644 --- a/examples/v1/train_lora/train_lora_sft.yaml +++ b/examples/v1/train_lora/train_lora_sft.yaml @@ -29,7 +29,6 @@ output_dir: ./outputs/test_lora micro_batch_size: 1 cutoff_len: 2048 learning_rate: 1.0e-4 -bf16: true max_steps: 10 ### sample diff --git a/examples/v1/train_qlora/quantization.yaml b/examples/v1/train_qlora/quantization.yaml index 6edc9745f..8ad676435 100644 --- a/examples/v1/train_qlora/quantization.yaml +++ b/examples/v1/train_qlora/quantization.yaml @@ -34,7 +34,6 @@ output_dir: outputs/test_quantization micro_batch_size: 1 cutoff_len: 2048 learning_rate: 1.0e-4 -bf16: false max_steps: 10 ### sample diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index 20b02b225..f32b4dc6b 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -123,6 +123,8 @@ class DistributedInterface: if self._initialized: return + self.dist_config = config + helper.set_device_index() self._is_distributed = helper.is_distributed() self._rank = helper.get_rank() diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 30b95f99e..8ede106f8 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -54,7 +54,7 @@ class TrainingArguments: metadata={"help": "Maximum gradient norm for training."}, ) bf16: bool = field( - default=False, + default=True, metadata={"help": "Use bf16 for training."}, ) batching_strategy: BatchingStrategy = field( @@ -66,7 +66,7 @@ class TrainingArguments: 
metadata={"help": "Number of workers for batching."}, ) enable_activation_checkpointing: bool = field( - default=False, + default=True, metadata={"help": "Enable activation checkpointing for training."}, ) dist_config: PluginConfig | None = field( diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index 11cec3c65..c3f1d2c69 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -178,6 +178,7 @@ class BaseTrainer: self.model = DistributedPlugin(self.args.dist_config.name)( self.model, self.args.dist_config, + bf16=self.args.bf16, ) def _init_optimizer(self) -> None: diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index fbdbd6b0e..8e16e5363 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -52,7 +52,11 @@ class ModelEngine: is_train: Whether to train the model. """ - def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None: + def __init__( + self, + model_args: ModelArguments, + is_train: bool = False, + ) -> None: self.args = model_args """Model arguments.""" self.is_train = is_train @@ -63,8 +67,26 @@ class ModelEngine: """Renderer.""" self.model_config = self._init_model_config() """Model configuration.""" - self.model = self._init_model() - """HF model.""" + self._dist_config = DistributedInterface().dist_config + self._deepspeed_zero3_plugin = None + self._deepspeed_zero3_enabled = False + + if self.is_train and self._dist_config is not None and self._dist_config.get("name") == "deepspeed": + from ..plugins.model_plugins.deepspeed_utils import ( + setup_deepspeed_zero3_model_loading, + teardown_deepspeed_zero3_model_loading, + ) + + try: + self._deepspeed_zero3_plugin = setup_deepspeed_zero3_model_loading(self.is_train, self._dist_config) + self._deepspeed_zero3_enabled = self._deepspeed_zero3_plugin is not None + self.model = self._init_model() 
+ finally: + teardown_deepspeed_zero3_model_loading(self._deepspeed_zero3_plugin) + self._deepspeed_zero3_plugin = None + self._deepspeed_zero3_enabled = False + else: + self.model = self._init_model() def _init_processor(self) -> Processor: """Init processor. @@ -97,7 +119,7 @@ class ModelEngine: else: init_device = DistributedInterface().current_device - init_kwargs = {"device_map": init_device} + init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device} if self.args.quant_config is not None: from ..plugins.model_plugins.quantization import QuantizationPlugin diff --git a/src/llamafactory/v1/plugins/model_plugins/deepspeed_utils.py b/src/llamafactory/v1/plugins/model_plugins/deepspeed_utils.py new file mode 100644 index 000000000..de737ed6f --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/deepspeed_utils.py @@ -0,0 +1,123 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from copy import deepcopy +from typing import Any + + +def _normalize_precision_enabled(value: Any) -> bool | str: + if isinstance(value, str): + value_lower = value.lower() + if value_lower == "true": + return True + if value_lower == "false": + return False + if value_lower == "auto": + return "auto" + return value + + +def infer_deepspeed_mixed_precision(ds_config: dict[str, Any]) -> str: + ds_config.setdefault("fp16", {}) + ds_config.setdefault("bf16", {}) + + fp16_enabled = _normalize_precision_enabled(ds_config["fp16"].get("enabled", "auto")) + bf16_enabled = _normalize_precision_enabled(ds_config["bf16"].get("enabled", "auto")) + + # This project only supports DeepSpeed bf16 or no mixed precision. + if fp16_enabled is True: + raise ValueError("DeepSpeed only supports bf16 mixed precision for now, fp16 is not supported.") + + if bf16_enabled is True: + mixed_precision = "bf16" + elif bf16_enabled is False: + mixed_precision = "no" + elif fp16_enabled is False: + mixed_precision = "no" + else: + # When neither bf16 nor fp16 is explicitly enabled or disabled (both "auto", absent, or unrecognized), default to bf16.
+ mixed_precision = "bf16" + + ds_config["fp16"]["enabled"] = False + ds_config["bf16"]["enabled"] = mixed_precision == "bf16" + return mixed_precision + + +def _unset_hf_deepspeed_config() -> None: + try: + from transformers.integrations import unset_hf_deepspeed_config + except ImportError: + from transformers.deepspeed import unset_hf_deepspeed_config + + unset_hf_deepspeed_config() + + +def _load_deepspeed_config(config_file: str) -> dict[str, Any]: + with open(config_file, encoding="utf-8") as f: + return json.load(f) + + +def setup_deepspeed_zero3_model_loading(is_train: bool, dist_config: dict[str, Any] | None): + """Enable transformers' ZeRO-3-aware model loading for the current process (transformers keeps the active DeepSpeed config in module-level global state, so this must be undone via teardown afterwards).""" + config_file = dist_config.get("config_file") + if not config_file: + raise ValueError("DeepSpeed config_file is required in dist_config") + + from accelerate.utils import DeepSpeedPlugin + + try: + from transformers.integrations import is_deepspeed_zero3_enabled + except ImportError: + from transformers.deepspeed import is_deepspeed_zero3_enabled + + # DeepSpeed configs often use "auto" placeholders that only make sense once + # we know the current runtime batch settings and precision mode. + ds_config = deepcopy(_load_deepspeed_config(config_file)) + if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto": + ds_config["gradient_accumulation_steps"] = 1 + if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto": + ds_config["train_micro_batch_size_per_gpu"] = 1 + if ds_config.get("train_batch_size") == "auto": + ds_config.pop("train_batch_size") + + zero_stage = ds_config.get("zero_optimization", {}).get("stage") + if zero_stage != 3: + return None + + # ZeRO-3 model loading needs concrete fp16/bf16 flags, not "auto".
+ mixed_precision = infer_deepspeed_mixed_precision(ds_config) + + plugin = DeepSpeedPlugin(hf_ds_config=ds_config, zero3_init_flag=True) + + if not plugin.hf_ds_config.is_zero3(): + return None + + # Reuse the same precision inference rule as the training-time DeepSpeed path + # so both model-loading and engine setup stay aligned. + plugin.set_mixed_precision(mixed_precision) + plugin.set_deepspeed_weakref() + + if not is_deepspeed_zero3_enabled(): + raise RuntimeError( + "DeepSpeed ZeRO-3 model-loading bootstrap failed: transformers still reports zero3 disabled " + "after constructing HfDeepSpeedConfig. This usually means the runtime is using a different transformers " + "installation than expected, or the DeepSpeed global state was not established correctly." + ) + return plugin + + +def teardown_deepspeed_zero3_model_loading(plugin) -> None: + if plugin is not None: + _unset_hf_deepspeed_config() diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py index a68b1f8ab..105478431 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py @@ -28,6 +28,7 @@ from accelerate.utils import DeepSpeedPlugin from ....utils.logging import get_logger from ....utils.types import HFModel, Processor +from ...model_plugins.deepspeed_utils import infer_deepspeed_mixed_precision logger = get_logger(__name__) @@ -51,6 +52,7 @@ class DeepSpeedEngine: raise ValueError("DeepSpeed config_file is required in dist_config") ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file) + ds_plugin.set_mixed_precision(infer_deepspeed_mixed_precision(ds_plugin.deepspeed_config)) self.accelerator = Accelerator( deepspeed_plugin=ds_plugin, diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index ddc9fb7f8..32b424443 
100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -119,12 +119,12 @@ def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: class FSDP2Engine: - def __init__(self, dist_config: dict): + def __init__(self, dist_config: dict, bf16: bool = False): self.dist_interface = DistributedInterface() self.rank = self.dist_interface.get_rank() self.local_rank = self.dist_interface.get_local_rank() self.world_size = self.dist_interface.get_world_size() - self.mixed_precision = dist_config.get("mixed_precision", "bf16") + self.mixed_precision = "bf16" if bf16 else "fp32" self.reshard_after_forward = dist_config.get("reshard_after_forward", True) self.offload_params = dist_config.get("offload_params", False) self.pin_memory = dist_config.get("pin_memory", True) @@ -147,10 +147,7 @@ class FSDP2Engine: if self.mixed_precision == "bf16": param_dtype = torch.bfloat16 reduce_dtype = torch.float32 - elif self.mixed_precision == "fp16": - param_dtype = torch.float16 - reduce_dtype = torch.float32 - else: + elif self.mixed_precision == "fp32": param_dtype = torch.float32 reduce_dtype = torch.float32 diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py index 47f64a3a6..f7389b28d 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py @@ -35,7 +35,7 @@ class DistributedPlugin(BasePlugin): def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: from .fsdp2 import FSDP2Engine - return FSDP2Engine(dist_config).shard_model(model) + return FSDP2Engine(dist_config, bf16=bool(kwargs.get("bf16"))).shard_model(model) @DistributedPlugin("fsdp2").register("save_model")