Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-04-21 20:36:02 +08:00)
[v1] add deepspeed zero3 trigger for low memory usage weight loading (#10300)
@@ -29,7 +29,6 @@ micro_batch_size: 1
 global_batch_size: 4
 cutoff_len: 2048
 learning_rate: 2.0e-5
-bf16: false
 max_steps: 10
 
 ### sample
@@ -19,5 +19,4 @@ output_dir: outputs/Qwen3-0.6B-deepspeed
 micro_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
-bf16: true
 max_steps: 10
@@ -21,7 +21,6 @@ output_dir: outputs/test_fsdp2
 micro_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
-bf16: false
 max_steps: 10
 
 ### sample
@@ -29,7 +29,6 @@ output_dir: ./outputs/test_lora
 micro_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
-bf16: true
 max_steps: 10
 
 ### sample
@@ -34,7 +34,6 @@ output_dir: outputs/test_quantization
 micro_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
-bf16: false
 max_steps: 10
 
 ### sample
@@ -123,6 +123,8 @@ class DistributedInterface:
         if self._initialized:
             return
 
+        self.dist_config = config
+
         helper.set_device_index()
         self._is_distributed = helper.is_distributed()
         self._rank = helper.get_rank()
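The `config` stored here is read back later in `ModelEngine` via `DistributedInterface().dist_config`, which relies on the interface acting as a process-wide singleton (the `self._initialized` guard above points that way). A minimal sketch of that pattern, using a hypothetical `DistributedInterfaceSketch` class rather than the real one:

class DistributedInterfaceSketch:
    """Hypothetical stand-in: a process-wide singleton that caches dist_config."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.dist_config = None  # populated on first initialize()
        return cls._instance

    def initialize(self, config: dict) -> None:
        self.dist_config = config


# Whoever constructs the interface first stores the config ...
DistributedInterfaceSketch().initialize({"name": "deepspeed", "config_file": "ds_z3.json"})
# ... and any later caller (e.g. the model engine) sees the same object.
assert DistributedInterfaceSketch().dist_config["name"] == "deepspeed"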
@@ -54,7 +54,7 @@ class TrainingArguments:
         metadata={"help": "Maximum gradient norm for training."},
     )
     bf16: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Use bf16 for training."},
     )
     batching_strategy: BatchingStrategy = field(
@@ -66,7 +66,7 @@ class TrainingArguments:
         metadata={"help": "Number of workers for batching."},
     )
     enable_activation_checkpointing: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Enable activation checkpointing for training."},
     )
     dist_config: PluginConfig | None = field(
@@ -178,6 +178,7 @@ class BaseTrainer:
         self.model = DistributedPlugin(self.args.dist_config.name)(
             self.model,
             self.args.dist_config,
+            bf16=self.args.bf16,
         )
 
     def _init_optimizer(self) -> None:
@@ -52,7 +52,11 @@ class ModelEngine:
         is_train: Whether to train the model.
     """
 
-    def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
+    def __init__(
+        self,
+        model_args: ModelArguments,
+        is_train: bool = False,
+    ) -> None:
         self.args = model_args
         """Model arguments."""
         self.is_train = is_train
@@ -63,8 +67,26 @@ class ModelEngine:
         """Renderer."""
         self.model_config = self._init_model_config()
         """Model configuration."""
-        self.model = self._init_model()
-        """HF model."""
+        self._dist_config = DistributedInterface().dist_config
+        self._deepspeed_zero3_plugin = None
+        self._deepspeed_zero3_enabled = False
+
+        if self.is_train and self._dist_config is not None and self._dist_config.get("name") == "deepspeed":
+            from ..plugins.model_plugins.deepspeed_utils import (
+                setup_deepspeed_zero3_model_loading,
+                teardown_deepspeed_zero3_model_loading,
+            )
+
+            try:
+                self._deepspeed_zero3_plugin = setup_deepspeed_zero3_model_loading(self.is_train, self._dist_config)
+                self._deepspeed_zero3_enabled = self._deepspeed_zero3_plugin is not None
+                self.model = self._init_model()
+            finally:
+                teardown_deepspeed_zero3_model_loading(self._deepspeed_zero3_plugin)
+                self._deepspeed_zero3_plugin = None
+                self._deepspeed_zero3_enabled = False
+        else:
+            self.model = self._init_model()
 
     def _init_processor(self) -> Processor:
         """Init processor.
@@ -97,7 +119,7 @@ class ModelEngine:
         else:
             init_device = DistributedInterface().current_device
 
-        init_kwargs = {"device_map": init_device}
+        init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
 
         if self.args.quant_config is not None:
             from ..plugins.model_plugins.quantization import QuantizationPlugin
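Dropping `device_map` when ZeRO-3 loading is enabled matters because transformers then builds the model under DeepSpeed's ZeRO-3 init path, sharding parameters across ranks instead of materializing a full copy per device, and `from_pretrained` rejects a `device_map` in that mode. A hedged sketch of the same decision outside the engine (model id and device string are illustrative only):

from transformers import AutoModelForCausalLM


def load_for_training(model_id: str, zero3_loading: bool):
    # When ZeRO-3 loading is active, DeepSpeed owns parameter placement, so we
    # must not pass a device_map; otherwise, place the model as before.
    init_kwargs = {} if zero3_loading else {"device_map": "cuda:0"}
    return AutoModelForCausalLM.from_pretrained(model_id, **init_kwargs)


# e.g. model = load_for_training("Qwen/Qwen3-0.6B", zero3_loading=True)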
src/llamafactory/v1/plugins/model_plugins/deepspeed_utils.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from copy import deepcopy
+from typing import Any
+
+
+def _normalize_precision_enabled(value: Any) -> bool | str:
+    if isinstance(value, str):
+        value_lower = value.lower()
+        if value_lower == "true":
+            return True
+        if value_lower == "false":
+            return False
+        if value_lower == "auto":
+            return "auto"
+    return value
+
+
+def infer_deepspeed_mixed_precision(ds_config: dict[str, Any]) -> str:
+    ds_config.setdefault("fp16", {})
+    ds_config.setdefault("bf16", {})
+
+    fp16_enabled = _normalize_precision_enabled(ds_config["fp16"].get("enabled", "auto"))
+    bf16_enabled = _normalize_precision_enabled(ds_config["bf16"].get("enabled", "auto"))
+
+    # This project only supports DeepSpeed bf16 or no mixed precision.
+    if fp16_enabled is True:
+        raise ValueError("DeepSpeed only supports bf16 mixed precision for now, fp16 is not supported.")
+
+    if bf16_enabled is True:
+        mixed_precision = "bf16"
+    elif bf16_enabled is False:
+        mixed_precision = "no"
+    elif fp16_enabled is False:
+        mixed_precision = "no"
+    else:
+        # When both bf16/fp16 are left as auto (or absent), default to bf16.
+        mixed_precision = "bf16"
+
+    ds_config["fp16"]["enabled"] = False
+    ds_config["bf16"]["enabled"] = mixed_precision == "bf16"
+    return mixed_precision
+
+
+def _unset_hf_deepspeed_config() -> None:
+    try:
+        from transformers.integrations import unset_hf_deepspeed_config
+    except ImportError:
+        from transformers.deepspeed import unset_hf_deepspeed_config
+
+    unset_hf_deepspeed_config()
+
+
+def _load_deepspeed_config(config_file: str) -> dict[str, Any]:
+    with open(config_file, encoding="utf-8") as f:
+        return json.load(f)
+
+
+def setup_deepspeed_zero3_model_loading(is_train: bool, dist_config: dict[str, Any] | None):
+    """Enable transformers' ZeRO-3-aware model loading for the current thread."""
+    config_file = dist_config.get("config_file")
+    if not config_file:
+        raise ValueError("DeepSpeed config_file is required in dist_config")
+
+    from accelerate.utils import DeepSpeedPlugin
+
+    try:
+        from transformers.integrations import is_deepspeed_zero3_enabled
+    except ImportError:
+        from transformers.deepspeed import is_deepspeed_zero3_enabled
+
+    # DeepSpeed configs often use "auto" placeholders that only make sense once
+    # we know the current runtime batch settings and precision mode.
+    ds_config = deepcopy(_load_deepspeed_config(config_file))
+    if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
+        ds_config["gradient_accumulation_steps"] = 1
+    if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto":
+        ds_config["train_micro_batch_size_per_gpu"] = 1
+    if ds_config.get("train_batch_size") == "auto":
+        ds_config.pop("train_batch_size")
+
+    zero_stage = ds_config.get("zero_optimization", {}).get("stage")
+    if zero_stage != 3:
+        return None
+
+    # ZeRO-3 model loading needs concrete fp16/bf16 flags, not "auto".
+    mixed_precision = infer_deepspeed_mixed_precision(ds_config)
+
+    plugin = DeepSpeedPlugin(hf_ds_config=ds_config, zero3_init_flag=True)
+
+    if not plugin.hf_ds_config.is_zero3():
+        return None
+
+    # Reuse the same precision inference rule as the training-time DeepSpeed path
+    # so both model-loading and engine setup stay aligned.
+    plugin.set_mixed_precision(mixed_precision)
+    plugin.set_deepspeed_weakref()
+
+    if not is_deepspeed_zero3_enabled():
+        raise RuntimeError(
+            "DeepSpeed ZeRO-3 model-loading bootstrap failed: transformers still reports zero3 disabled "
+            "after constructing HfDeepSpeedConfig. This usually means the runtime is using a different transformers "
+            "installation than expected, or the DeepSpeed global state was not established correctly."
+        )
+    return plugin
+
+
+def teardown_deepspeed_zero3_model_loading(plugin) -> None:
+    if plugin is not None:
+        _unset_hf_deepspeed_config()
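A hedged usage sketch of the two public helpers defined above (the JSON content and temp-file handling are illustrative; the real caller is `ModelEngine`, and the import path simply mirrors the new file's location):

import json
import tempfile

from llamafactory.v1.plugins.model_plugins.deepspeed_utils import (
    setup_deepspeed_zero3_model_loading,
    teardown_deepspeed_zero3_model_loading,
)

# Illustrative ZeRO-3 config; "auto" placeholders are resolved by the helper.
ds_config = {
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": "auto"},
    "fp16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(ds_config, f)

plugin = None
try:
    # Returns None for non-ZeRO-3 configs; otherwise installs the transformers
    # HfDeepSpeedConfig global state so from_pretrained() shards weights per rank.
    plugin = setup_deepspeed_zero3_model_loading(True, {"name": "deepspeed", "config_file": f.name})
    # ... load the model here, e.g. AutoModelForCausalLM.from_pretrained(...) ...
finally:
    # Always clear the global state so later, non-DeepSpeed loads are unaffected.
    teardown_deepspeed_zero3_model_loading(plugin)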
@@ -28,6 +28,7 @@ from accelerate.utils import DeepSpeedPlugin
 
 from ....utils.logging import get_logger
 from ....utils.types import HFModel, Processor
+from ...model_plugins.deepspeed_utils import infer_deepspeed_mixed_precision
 
 
 logger = get_logger(__name__)
@@ -51,6 +52,7 @@ class DeepSpeedEngine:
             raise ValueError("DeepSpeed config_file is required in dist_config")
 
         ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
+        ds_plugin.set_mixed_precision(infer_deepspeed_mixed_precision(ds_plugin.deepspeed_config))
 
         self.accelerator = Accelerator(
             deepspeed_plugin=ds_plugin,
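For reference, the resolution rule `infer_deepspeed_mixed_precision` applies to the plugin's `deepspeed_config`: bf16 wins when enabled or left on "auto", an explicit false turns mixed precision off, and fp16 is rejected. The dicts below are made up to show the three outcomes and assume the module is importable under the path shown in its file header:

from llamafactory.v1.plugins.model_plugins.deepspeed_utils import infer_deepspeed_mixed_precision

cfg = {"bf16": {"enabled": "auto"}, "fp16": {"enabled": "auto"}}
assert infer_deepspeed_mixed_precision(cfg) == "bf16"  # both "auto" -> default to bf16
assert cfg["bf16"]["enabled"] is True and cfg["fp16"]["enabled"] is False  # "auto" resolved in place

cfg = {"bf16": {"enabled": "false"}}
assert infer_deepspeed_mixed_precision(cfg) == "no"  # explicit opt-out -> no mixed precision

try:
    infer_deepspeed_mixed_precision({"fp16": {"enabled": True}})
except ValueError:
    pass  # fp16 mixed precision is not supported by this path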
@@ -119,12 +119,12 @@ def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir:
 
 
 class FSDP2Engine:
-    def __init__(self, dist_config: dict):
+    def __init__(self, dist_config: dict, bf16: bool = False):
         self.dist_interface = DistributedInterface()
         self.rank = self.dist_interface.get_rank()
         self.local_rank = self.dist_interface.get_local_rank()
         self.world_size = self.dist_interface.get_world_size()
-        self.mixed_precision = dist_config.get("mixed_precision", "bf16")
+        self.mixed_precision = "bf16" if bf16 else "fp32"
         self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
         self.offload_params = dist_config.get("offload_params", False)
         self.pin_memory = dist_config.get("pin_memory", True)
@@ -147,10 +147,7 @@ class FSDP2Engine:
         if self.mixed_precision == "bf16":
             param_dtype = torch.bfloat16
             reduce_dtype = torch.float32
-        elif self.mixed_precision == "fp16":
-            param_dtype = torch.float16
-            reduce_dtype = torch.float32
-        else:
+        elif self.mixed_precision == "fp32":
             param_dtype = torch.float32
             reduce_dtype = torch.float32
 
@@ -35,7 +35,7 @@ class DistributedPlugin(BasePlugin):
 def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
     from .fsdp2 import FSDP2Engine
 
-    return FSDP2Engine(dist_config).shard_model(model)
+    return FSDP2Engine(dist_config, bf16=bool(kwargs.get("bf16"))).shard_model(model)
 
 
 @DistributedPlugin("fsdp2").register("save_model")
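Taken together, the `bf16` training argument now flows from `BaseTrainer` through the plugin's `**kwargs` into `FSDP2Engine`, replacing the old `mixed_precision` key in `dist_config`. A simplified sketch of the resulting dtype choice (the class below is a stand-in, not the repository's code):

import torch


class FSDP2PrecisionSketch:
    """Stand-in for the dtype logic shown in the FSDP2Engine hunks above."""

    def __init__(self, dist_config: dict, bf16: bool = False):
        # mixed_precision is no longer read from dist_config; it follows the
        # trainer-level bf16 flag.
        self.mixed_precision = "bf16" if bf16 else "fp32"

    def dtypes(self) -> tuple[torch.dtype, torch.dtype]:
        if self.mixed_precision == "bf16":
            return torch.bfloat16, torch.float32  # bf16 params, fp32 gradient reduce
        return torch.float32, torch.float32  # plain fp32 otherwise


assert FSDP2PrecisionSketch({}, bf16=True).dtypes() == (torch.bfloat16, torch.float32)
assert FSDP2PrecisionSketch({}, bf16=False).dtypes() == (torch.float32, torch.float32)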