[v1] add DeepSpeed ZeRO-3 trigger for low-memory weight loading (#10300)

This commit is contained in:
jiaqiw09
2026-04-21 14:09:52 +08:00
committed by GitHub
parent f5d739b132
commit 28a6ea1cdc
13 changed files with 160 additions and 18 deletions

View File

@@ -29,7 +29,6 @@ micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 2.0e-5
bf16: false
max_steps: 10
### sample

View File

@@ -19,5 +19,4 @@ output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

View File

@@ -21,7 +21,6 @@ output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample

View File

@@ -29,7 +29,6 @@ output_dir: ./outputs/test_lora
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample

View File

@@ -34,7 +34,6 @@ output_dir: outputs/test_quantization
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample

View File

@@ -123,6 +123,8 @@ class DistributedInterface:
if self._initialized:
return
self.dist_config = config
helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()

View File

@@ -54,7 +54,7 @@ class TrainingArguments:
metadata={"help": "Maximum gradient norm for training."},
)
bf16: bool = field(
default=False,
default=True,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
@@ -66,7 +66,7 @@ class TrainingArguments:
metadata={"help": "Number of workers for batching."},
)
enable_activation_checkpointing: bool = field(
default=False,
default=True,
metadata={"help": "Enable activation checkpointing for training."},
)
dist_config: PluginConfig | None = field(

View File

@@ -178,6 +178,7 @@ class BaseTrainer:
self.model = DistributedPlugin(self.args.dist_config.name)(
self.model,
self.args.dist_config,
bf16=self.args.bf16,
)
def _init_optimizer(self) -> None:

View File

@@ -52,7 +52,11 @@ class ModelEngine:
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
def __init__(
self,
model_args: ModelArguments,
is_train: bool = False,
) -> None:
self.args = model_args
"""Model arguments."""
self.is_train = is_train
@@ -63,8 +67,26 @@ class ModelEngine:
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self._dist_config = DistributedInterface().dist_config
self._deepspeed_zero3_plugin = None
self._deepspeed_zero3_enabled = False
if self.is_train and self._dist_config is not None and self._dist_config.get("name") == "deepspeed":
from ..plugins.model_plugins.deepspeed_utils import (
setup_deepspeed_zero3_model_loading,
teardown_deepspeed_zero3_model_loading,
)
try:
self._deepspeed_zero3_plugin = setup_deepspeed_zero3_model_loading(self.is_train, self._dist_config)
self._deepspeed_zero3_enabled = self._deepspeed_zero3_plugin is not None
self.model = self._init_model()
finally:
teardown_deepspeed_zero3_model_loading(self._deepspeed_zero3_plugin)
self._deepspeed_zero3_plugin = None
self._deepspeed_zero3_enabled = False
else:
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor.
@@ -97,7 +119,7 @@ class ModelEngine:
else:
init_device = DistributedInterface().current_device
init_kwargs = {"device_map": init_device}
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin

View File

@@ -0,0 +1,123 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from copy import deepcopy
from typing import Any
def _normalize_precision_enabled(value: Any) -> bool | str:
if isinstance(value, str):
value_lower = value.lower()
if value_lower == "true":
return True
if value_lower == "false":
return False
if value_lower == "auto":
return "auto"
return value
def infer_deepspeed_mixed_precision(ds_config: dict[str, Any]) -> str:
    """Resolve the mixed-precision mode implied by a DeepSpeed config.

    Mutates ``ds_config`` in place: ensures "fp16"/"bf16" sections exist and
    rewrites their "enabled" flags to concrete booleans (fp16 always False;
    bf16 True iff the resolved mode is "bf16").

    Returns:
        "bf16" or "no".

    Raises:
        ValueError: If fp16 is explicitly enabled (unsupported here).
    """
    fp16_section = ds_config.setdefault("fp16", {})
    bf16_section = ds_config.setdefault("bf16", {})
    fp16_flag = _normalize_precision_enabled(fp16_section.get("enabled", "auto"))
    bf16_flag = _normalize_precision_enabled(bf16_section.get("enabled", "auto"))

    # This project only supports DeepSpeed bf16 or no mixed precision.
    if fp16_flag is True:
        raise ValueError("DeepSpeed only supports bf16 mixed precision for now, fp16 is not supported.")

    if bf16_flag is True:
        mixed_precision = "bf16"
    elif bf16_flag is False or fp16_flag is False:
        # Either flag explicitly disabled (with bf16 not enabled) means no mixed precision.
        mixed_precision = "no"
    else:
        # Both bf16/fp16 left as auto (or absent): default to bf16.
        mixed_precision = "bf16"

    fp16_section["enabled"] = False
    bf16_section["enabled"] = mixed_precision == "bf16"
    return mixed_precision
def _unset_hf_deepspeed_config() -> None:
    """Clear transformers' global DeepSpeed config state.

    The helper moved between modules across transformers versions, so fall
    back to the legacy import path when the modern one is unavailable.
    """
    try:
        from transformers.integrations import unset_hf_deepspeed_config as _unset
    except ImportError:
        # Older transformers releases exposed it under transformers.deepspeed.
        from transformers.deepspeed import unset_hf_deepspeed_config as _unset
    _unset()
def _load_deepspeed_config(config_file: str) -> dict[str, Any]:
with open(config_file, encoding="utf-8") as f:
return json.load(f)
def setup_deepspeed_zero3_model_loading(is_train: bool, dist_config: dict[str, Any] | None):
    """Enable transformers' ZeRO-3-aware model loading for the current thread.

    Builds an accelerate ``DeepSpeedPlugin`` from the DeepSpeed JSON config named
    in ``dist_config`` and installs it as transformers' global DeepSpeed config
    (via ``set_deepspeed_weakref``) so that ``from_pretrained`` can shard weights
    at load time instead of materializing the full model on every rank.

    Args:
        is_train: Accepted for caller symmetry; not consulted here (callers only
            invoke this on the training path).
        dist_config: Distributed plugin config; must contain a "config_file" key.

    Returns:
        The configured ``DeepSpeedPlugin`` when the config requests ZeRO stage 3,
        otherwise ``None`` (no global state is installed in that case).

    Raises:
        ValueError: If ``dist_config`` is None or has no "config_file".
        RuntimeError: If transformers still reports zero3 disabled afterwards.
    """
    # Fix: the annotation admits None, but the original immediately called
    # dist_config.get(...) and would fail with AttributeError; raise the same
    # explicit ValueError for both missing-config shapes instead.
    if dist_config is None or not dist_config.get("config_file"):
        raise ValueError("DeepSpeed config_file is required in dist_config")
    config_file = dist_config["config_file"]

    from accelerate.utils import DeepSpeedPlugin

    try:
        from transformers.integrations import is_deepspeed_zero3_enabled
    except ImportError:
        from transformers.deepspeed import is_deepspeed_zero3_enabled

    # DeepSpeed configs often use "auto" placeholders that only make sense once
    # we know the current runtime batch settings and precision mode.
    # _load_deepspeed_config returns a freshly parsed dict, so the original
    # deepcopy of it was redundant and has been dropped.
    ds_config = _load_deepspeed_config(config_file)
    if ds_config.get("gradient_accumulation_steps", "auto") == "auto":
        ds_config["gradient_accumulation_steps"] = 1
    if ds_config.get("train_micro_batch_size_per_gpu", "auto") == "auto":
        ds_config["train_micro_batch_size_per_gpu"] = 1
    if ds_config.get("train_batch_size") == "auto":
        ds_config.pop("train_batch_size")

    zero_stage = ds_config.get("zero_optimization", {}).get("stage")
    if zero_stage != 3:
        return None

    # ZeRO-3 model loading needs concrete fp16/bf16 flags, not "auto".
    mixed_precision = infer_deepspeed_mixed_precision(ds_config)
    plugin = DeepSpeedPlugin(hf_ds_config=ds_config, zero3_init_flag=True)
    if not plugin.hf_ds_config.is_zero3():
        return None

    # Reuse the same precision inference rule as the training-time DeepSpeed path
    # so both model-loading and engine setup stay aligned.
    plugin.set_mixed_precision(mixed_precision)
    plugin.set_deepspeed_weakref()
    if not is_deepspeed_zero3_enabled():
        raise RuntimeError(
            "DeepSpeed ZeRO-3 model-loading bootstrap failed: transformers still reports zero3 disabled "
            "after constructing HfDeepSpeedConfig. This usually means the runtime is using a different transformers "
            "installation than expected, or the DeepSpeed global state was not established correctly."
        )
    return plugin
def teardown_deepspeed_zero3_model_loading(plugin) -> None:
    """Undo the ZeRO-3 model-loading bootstrap; no-op when *plugin* is None."""
    if plugin is None:
        return
    _unset_hf_deepspeed_config()

View File

@@ -28,6 +28,7 @@ from accelerate.utils import DeepSpeedPlugin
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
from ...model_plugins.deepspeed_utils import infer_deepspeed_mixed_precision
logger = get_logger(__name__)
@@ -51,6 +52,7 @@ class DeepSpeedEngine:
raise ValueError("DeepSpeed config_file is required in dist_config")
ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
ds_plugin.set_mixed_precision(infer_deepspeed_mixed_precision(ds_plugin.deepspeed_config))
self.accelerator = Accelerator(
deepspeed_plugin=ds_plugin,

View File

@@ -119,12 +119,12 @@ def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir:
class FSDP2Engine:
def __init__(self, dist_config: dict):
def __init__(self, dist_config: dict, bf16: bool = False):
self.dist_interface = DistributedInterface()
self.rank = self.dist_interface.get_rank()
self.local_rank = self.dist_interface.get_local_rank()
self.world_size = self.dist_interface.get_world_size()
self.mixed_precision = dist_config.get("mixed_precision", "bf16")
self.mixed_precision = "bf16" if bf16 else "fp32"
self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True)
@@ -147,10 +147,7 @@ class FSDP2Engine:
if self.mixed_precision == "bf16":
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
elif self.mixed_precision == "fp16":
param_dtype = torch.float16
reduce_dtype = torch.float32
else:
elif self.mixed_precision == "fp32":
param_dtype = torch.float32
reduce_dtype = torch.float32

View File

@@ -35,7 +35,7 @@ class DistributedPlugin(BasePlugin):
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model)
return FSDP2Engine(dist_config, bf16=bool(kwargs.get("bf16"))).shard_model(model)
@DistributedPlugin("fsdp2").register("save_model")