fix deepspeed version

Former-commit-id: cca6f351081903ca3b5f79f10accc1bbbae0ee61
hiyouga committed 2024-06-11 16:52:36 +08:00
parent 6c9cc199ef
commit 8c574eb3cb

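For context before the diff: the commit moves the deepspeed>=0.13.0 check out of add_z3_leaf_module and into a private helper, so the version requirement is only enforced when a known MoE architecture actually registers a leaf module, rather than for every model trained under ZeRO-3. A minimal sketch of the underlying DeepSpeed API being wrapped (the helper name mark_moe_blocks_as_leaves is hypothetical; set_z3_leaf_modules and the Mixtral class are real, per the diff below):

import torch
from deepspeed.utils import set_z3_leaf_modules  # requires deepspeed>=0.13.0
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

def mark_moe_blocks_as_leaves(model: torch.nn.Module) -> None:
    # ZeRO-3 partitions parameters per submodule and gathers them via forward
    # hooks; a sparse MoE block runs only the routed experts, so marking the
    # whole block as a "leaf" fetches its parameters in one piece instead of
    # recursing into expert submodules that may never execute.
    set_z3_leaf_modules(model, [MixtralSparseMoeBlock])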

@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence
 
+import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 
@@ -10,6 +11,13 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
+
+    set_z3_leaf_modules(model, leaf_modules)
+
+
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     r"""
     Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +25,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return
 
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
-    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
-
     if getattr(model.config, "model_type", None) == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
 
-        set_z3_leaf_modules(model, [DbrxFFN])
+        _set_z3_leaf_modules(model, [DbrxFFN])
 
     if getattr(model.config, "model_type", None) == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
-        set_z3_leaf_modules(model, [JambaSparseMoeBlock])
+        _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
 
     if getattr(model.config, "model_type", None) == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
 
-        set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+        _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
     if getattr(model.config, "model_type", None) == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
     if getattr(model.config, "model_type", None) == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
-        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
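
Nothing changes for callers: add_z3_leaf_module is still invoked once on the loaded model and returns early when ZeRO-3 is not active, per the guard in the diff. A hedged usage sketch (the checkpoint name is illustrative, and an active DeepSpeed ZeRO-3 session is assumed):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
add_z3_leaf_module(model)  # matches model_type "mixtral", registers MixtralSparseMoeBlock as a ZeRO-3 leaf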