[v1] Add DeepSpeed ZeRO-3 trigger for low-memory weight loading (#10300)

This commit is contained in:
jiaqiw09
2026-04-21 14:09:52 +08:00
committed by GitHub
parent f5d739b132
commit 28a6ea1cdc
13 changed files with 160 additions and 18 deletions

View File

@@ -178,6 +178,7 @@ class BaseTrainer:
self.model = DistributedPlugin(self.args.dist_config.name)(
self.model,
self.args.dist_config,
bf16=self.args.bf16,
)
def _init_optimizer(self) -> None:

View File

@@ -52,7 +52,11 @@ class ModelEngine:
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
def __init__(
self,
model_args: ModelArguments,
is_train: bool = False,
) -> None:
self.args = model_args
"""Model arguments."""
self.is_train = is_train
@@ -63,8 +67,26 @@ class ModelEngine:
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
self._dist_config = DistributedInterface().dist_config
self._deepspeed_zero3_plugin = None
self._deepspeed_zero3_enabled = False
if self.is_train and self._dist_config is not None and self._dist_config.get("name") == "deepspeed":
from ..plugins.model_plugins.deepspeed_utils import (
setup_deepspeed_zero3_model_loading,
teardown_deepspeed_zero3_model_loading,
)
try:
self._deepspeed_zero3_plugin = setup_deepspeed_zero3_model_loading(self.is_train, self._dist_config)
self._deepspeed_zero3_enabled = self._deepspeed_zero3_plugin is not None
self.model = self._init_model()
finally:
teardown_deepspeed_zero3_model_loading(self._deepspeed_zero3_plugin)
self._deepspeed_zero3_plugin = None
self._deepspeed_zero3_enabled = False
else:
self.model = self._init_model()
def _init_processor(self) -> Processor:
"""Init processor.
@@ -97,7 +119,7 @@ class ModelEngine:
else:
init_device = DistributedInterface().current_device
init_kwargs = {"device_map": init_device}
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin