mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
parent
034b658348
commit
ecdea0036c
@ -1,4 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import deepspeed # type: ignore
|
||||||
|
from copy import deepcopy
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||||
from transformers import BatchEncoding, Trainer
|
from transformers import BatchEncoding, Trainer
|
||||||
@ -9,6 +11,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
from trl import PreTrainedModelWrapper
|
||||||
|
|
||||||
|
|
||||||
class CustomDPOTrainer(DPOTrainer):
|
class CustomDPOTrainer(DPOTrainer):
|
||||||
@ -47,6 +50,36 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
else:
|
else:
|
||||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
|
|
||||||
|
def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
|
||||||
|
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||||
|
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||||
|
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||||
|
if model is not None:
|
||||||
|
if hasattr(model, "config"):
|
||||||
|
hidden_size = (
|
||||||
|
max(model.config.hidden_sizes)
|
||||||
|
if getattr(model.config, "hidden_sizes", None)
|
||||||
|
else getattr(model.config, "hidden_size", None)
|
||||||
|
)
|
||||||
|
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||||
|
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||||
|
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||||
|
config_kwargs.update(
|
||||||
|
{
|
||||||
|
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||||
|
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||||
|
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||||
|
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||||
|
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||||
|
config_kwargs["zero_optimization"]["stage"] = 0
|
||||||
|
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
def concatenated_forward(
|
def concatenated_forward(
|
||||||
self,
|
self,
|
||||||
model: Optional[torch.nn.Module] = None,
|
model: Optional[torch.nn.Module] = None,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user