diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
index c2b0b581..8a9f8dd6 100644
--- a/src/llmtuner/tuner/dpo/trainer.py
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -1,4 +1,6 @@
 import torch
+import deepspeed  # type: ignore
+from copy import deepcopy
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
@@ -9,6 +11,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
+    from trl import PreTrainedModelWrapper
 
 
 class CustomDPOTrainer(DPOTrainer):
@@ -47,6 +50,36 @@ class CustomDPOTrainer(DPOTrainer):
         else:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+    def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
+        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+        if model is not None:
+            if hasattr(model, "config"):
+                hidden_size = (
+                    max(model.config.hidden_sizes)
+                    if getattr(model.config, "hidden_sizes", None)
+                    else getattr(model.config, "hidden_size", None)
+                )
+                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+                    config_kwargs["zero_optimization"].update(
+                        {
+                            "reduce_bucket_size": hidden_size * hidden_size,
+                            "stage3_param_persistence_threshold": 10 * hidden_size,
+                            "stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+                        }
+                    )
+
+        # If ZeRO-3 is used, we shard both the active and reference model.
+        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+        if config_kwargs["zero_optimization"]["stage"] != 3:
+            config_kwargs["zero_optimization"]["stage"] = 0
+        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+        model.eval()
+        return model
+
     def concatenated_forward(
         self,
         model: Optional[torch.nn.Module] = None,
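
Note on the call-site (a sketch, not part of the patch): the hunk above only shows the non-DeepSpeed branch as context, so the place where `_prepare_deepspeed` is invoked presumably lives elsewhere in `__init__`. Assuming the `is_deepspeed_enabled` flag that transformers' `Trainer` derives from the Accelerate state, the wiring would look roughly like:

    # Hypothetical call-site sketch; everything except _prepare_deepspeed and
    # prepare_model is an assumption, not taken from this patch.
    if self.is_deepspeed_enabled:
        # Give the frozen reference model its own inference-only DeepSpeed
        # engine so ZeRO-3 can shard it alongside the trained model.
        self.ref_model = self._prepare_deepspeed(self.ref_model)
    else:
        # Plain Accelerate path: place the reference model and keep it in eval mode.
        self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)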
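For intuition about the ZeRO-3 auto-sizing in `_prepare_deepspeed`, here is the arithmetic on an illustrative model width (4096 is an arbitrary example of a 7B LLaMA-style `hidden_size`; it is not taken from this patch):

    hidden_size = 4096                                              # illustrative only
    reduce_bucket_size = hidden_size * hidden_size                  # 16_777_216
    stage3_param_persistence_threshold = 10 * hidden_size           # 40_960
    stage3_prefetch_bucket_size = 0.9 * hidden_size * hidden_size   # 15_099_494.4

Parameters with fewer elements than `stage3_param_persistence_threshold` stay resident on each device instead of being sharded, and the two bucket sizes bound how many elements DeepSpeed batches per gradient reduction or parameter prefetch, which is why all three values scale with the model's hidden dimension.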