mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-28 02:39:03 +08:00
[v1] add deepspeed zero3 trigger for low memory usage weight loading (#10300)
This commit is contained in:
@@ -178,6 +178,7 @@ class BaseTrainer:
|
||||
self.model = DistributedPlugin(self.args.dist_config.name)(
|
||||
self.model,
|
||||
self.args.dist_config,
|
||||
bf16=self.args.bf16,
|
||||
)
|
||||
|
||||
def _init_optimizer(self) -> None:
|
||||
|
||||
@@ -52,7 +52,11 @@ class ModelEngine:
|
||||
is_train: Whether to train the model.
|
||||
"""
|
||||
|
||||
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model_args: ModelArguments,
|
||||
is_train: bool = False,
|
||||
) -> None:
|
||||
self.args = model_args
|
||||
"""Model arguments."""
|
||||
self.is_train = is_train
|
||||
@@ -63,8 +67,26 @@ class ModelEngine:
|
||||
"""Renderer."""
|
||||
self.model_config = self._init_model_config()
|
||||
"""Model configuration."""
|
||||
self.model = self._init_model()
|
||||
"""HF model."""
|
||||
self._dist_config = DistributedInterface().dist_config
|
||||
self._deepspeed_zero3_plugin = None
|
||||
self._deepspeed_zero3_enabled = False
|
||||
|
||||
if self.is_train and self._dist_config is not None and self._dist_config.get("name") == "deepspeed":
|
||||
from ..plugins.model_plugins.deepspeed_utils import (
|
||||
setup_deepspeed_zero3_model_loading,
|
||||
teardown_deepspeed_zero3_model_loading,
|
||||
)
|
||||
|
||||
try:
|
||||
self._deepspeed_zero3_plugin = setup_deepspeed_zero3_model_loading(self.is_train, self._dist_config)
|
||||
self._deepspeed_zero3_enabled = self._deepspeed_zero3_plugin is not None
|
||||
self.model = self._init_model()
|
||||
finally:
|
||||
teardown_deepspeed_zero3_model_loading(self._deepspeed_zero3_plugin)
|
||||
self._deepspeed_zero3_plugin = None
|
||||
self._deepspeed_zero3_enabled = False
|
||||
else:
|
||||
self.model = self._init_model()
|
||||
|
||||
def _init_processor(self) -> Processor:
|
||||
"""Init processor.
|
||||
@@ -97,7 +119,7 @@ class ModelEngine:
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {"device_map": init_device}
|
||||
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
|
||||
Reference in New Issue
Block a user