Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)
fix #4209
DeepSpeed ZeRO-3 raises an in-flight parameter error when model.eval() is called on the reference model.

Former-commit-id: cf9f2d6c42b5a37038c9eededbb767eae6a3f67d
This commit is contained in:
parent 833aa324c2
commit 81ed4d8abf
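What changed: the previous helper get_ref_context disabled the LoRA adapter and additionally toggled the policy model between eval() and train() while it served as its own reference model. Under DeepSpeed ZeRO-3, where parameters are partitioned and gathered lazily, that mode switch can hit parameters that are still in flight and raise an error. The fix removes the helper, uses the adapter-disabling context alone, and silences the gradient-checkpointing warnings on the reference model instead of avoiding them via eval(). A minimal sketch of the before/after pattern, based on the hunks below (the standalone function name in the "after" part is invented here for illustration; in the patch the expression is written inline in the trainers):

from contextlib import contextmanager, nullcontext

# Before the fix (removed in this commit): the helper also flipped the
# model's training mode, which is what trips DeepSpeed ZeRO-3.
@contextmanager
def get_ref_context(accelerator, model):
    with accelerator.unwrap_model(model).disable_adapter():
        model.eval()  # mode switch that can collide with in-flight params
        yield
        model.train()

# After the fix: only the adapter is disabled; no train()/eval() toggling.
# (make_ref_context is a hypothetical wrapper used only for this sketch.)
def make_ref_context(accelerator, model, ref_model=None):
    if ref_model is None:
        # LoRA case: the policy model doubles as the reference model
        # once its adapter is disabled.
        return accelerator.unwrap_model(model).disable_adapter()
    return nullcontext()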
DPO trainer:

@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
@@ -10,7 +11,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps


 if TYPE_CHECKING:
@@ -61,6 +62,8 @@ class CustomDPOTrainer(DPOTrainer):
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
+
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
@@ -176,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):

         if self.ref_model is None:
             ref_model = model
-            ref_context = get_ref_context(self.accelerator, model)
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
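The hunk above only changes where ref_context is constructed; the diff does not show how it is consumed. Presumably the reference log-probabilities are computed under that context together with torch.no_grad(), roughly as in the sketch below (the method name and the concatenated_forward call are assumptions, not confirmed by this diff; the same pattern applies to the KTO trainer further down):

from contextlib import nullcontext

import torch

# Hedged sketch of the surrounding method; only the ref_model/ref_context
# assignments are confirmed by the hunk above.
def compute_reference_log_probs(self, model, batch):
    if self.ref_model is None:
        ref_model = model
        ref_context = self.accelerator.unwrap_model(model).disable_adapter()
    else:
        ref_model = self.ref_model
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        # Reference forward pass; no train()/eval() toggling, so ZeRO-3
        # never sees a mode switch while partitioned params are gathered.
        reference_log_probs = self.concatenated_forward(ref_model, batch)

    return reference_log_probs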
KTO trainer:

@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
@@ -9,7 +10,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps


 if TYPE_CHECKING:
@@ -60,6 +61,8 @@ class CustomKTOTrainer(KTOTrainer):
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
+
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
@@ -143,7 +146,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = get_ref_context(self.accelerator, model)
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
PPO trainer:

@@ -1,6 +1,7 @@
 import math
 import os
 import sys
+import warnings
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         device_type = unwrapped_model.pretrained_model.device.type
         self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model

         if finetuning_args.reward_model_type == "full":
             if self.is_deepspeed_enabled:
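One design note on warnings.simplefilter("ignore"): it installs a process-wide filter, so it hides every warning from that point on, not just the gradient-checkpointing warnings emitted by the reference model. If narrower scoping were wanted, the suppression could instead be wrapped around the reference forward pass with warnings.catch_warnings(); a sketch of that alternative (not what this commit does, and the function below is hypothetical):

import warnings
from contextlib import nullcontext

import torch

def reference_forward(ref_model, batch, ref_context=None):
    # Suppress warnings only for this forward pass; the previous filter
    # state is restored when the catch_warnings() block exits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with torch.no_grad(), (ref_context or nullcontext()):
            return ref_model(**batch)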
trainer_utils:

@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import torch
@@ -19,7 +18,6 @@ if is_galore_available():


 if TYPE_CHECKING:
-    from accelerate import Accelerator
     from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead

@@ -154,17 +152,6 @@ def create_reward_model(
     return reward_model


-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
-    r"""
-    Gets adapter context for the reference model.
-    """
-    with accelerator.unwrap_model(model).disable_adapter():
-        model.eval()
-        yield
-        model.train()
-
-
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)