Merge pull request #2283 from A-Cepheus/main

fix: ZeRO3 does not work with MoE models
Former-commit-id: 8e4b3a959a7b94f389f31e222bb0e80c1dd83cbb
hoshi-hiyouga 2024-01-22 23:28:45 +08:00 committed by GitHub
commit 8196e9f806


@@ -284,6 +284,12 @@ def patch_model(
     if is_trainable:
         _prepare_model_for_training(model, model_args)
 
+    if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
+        require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
 def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
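
Context for the fix (not part of the diff): under ZeRO-3, DeepSpeed partitions parameters and gathers them through per-module hooks. With Mixtral's sparse routing, experts that receive no tokens in a given step never trigger those hooks, which can stall ZeRO-3 training. set_z3_leaf_modules (added in DeepSpeed 0.13.0) marks MixtralSparseMoeBlock as a "leaf" so the whole block's parameters are gathered as one unit. A minimal standalone sketch of the same call outside LLaMA-Factory; the model ID is illustrative and a real run would be launched under a DeepSpeed ZeRO-3 config:

# Standalone sketch (not part of this PR): mark Mixtral's MoE block as a
# ZeRO-3 leaf module. The model ID below is an illustrative assumption.
from deepspeed.utils import set_z3_leaf_modules  # requires deepspeed>=0.13.0
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Gather each MixtralSparseMoeBlock's parameters as a single unit under ZeRO-3,
# so experts that see no tokens in a step do not break ZeRO-3's module hooks.
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])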