[model] fix moe zero3 (#7826)

Authored by hoshi-hiyouga on 2025-04-23 15:30:49 +08:00, committed by GitHub
parent 1dd67eb042
commit 1344416378


@@ -44,6 +44,16 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [DbrxFFN])
 
+    if model_type == "deepseek_v3":
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
+    if model_type == "granitemoe":
+        from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
+
+        _set_z3_leaf_modules(model, [GraniteMoeMoE])
+
     if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
@@ -54,27 +64,55 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
-    if model_type in ["kimi_vl", "deepseek_v3"]:
-        check_version("transformers>=4.51.1")
-        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+    if model_type == "llama4":
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
 
-        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+        _set_z3_leaf_modules(model, [Llama4TextMoe])
 
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
+    if model_type == "olmoe":
+        from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])
+
+    if model_type == "phimoe":
+        from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [PhimoeSparseMoeBlock])
+
     if model_type == "qwen2_moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
+    if model_type == "qwen3_moe":
+        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
+
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in [
+            "dbrx",
+            "granitemoe",
+            "jamba",
+            "jetmoe",
+            "llama4",
+            "mixtral",
+            "olmoe",
+            "phimoe",
+            "qwen2_moe",
+            "qwen3_moe",
+        ]:
+            setattr(config, "output_router_logits", is_trainable)
+
+        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
         elif model_type == "deepseek":
@@ -82,6 +120,3 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
 
         elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
-
-    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
-        setattr(config, "output_router_logits", is_trainable)
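
For context, the _set_z3_leaf_modules helper patched above is a thin wrapper around DeepSpeed's leaf-module API. A minimal standalone sketch of the underlying call (checkpoint name chosen only for illustration; assumes deepspeed>=0.13 and a ZeRO-3 training run):

# Mark the sparse MoE block as a ZeRO-3 "leaf" so its expert parameters are
# gathered as one unit instead of being partitioned expert by expert, which
# otherwise stalls or errors when only a subset of experts fires in a step.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])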
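
Similarly, the configure_moe change amounts to flipping two attributes on the Hugging Face config whenever moe_aux_loss_coef is given. A rough sketch of the equivalent manual configuration (model name and coefficient value are placeholders; attribute names follow the Qwen2-MoE config):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
config.output_router_logits = True    # expose router logits so the auxiliary loss is computed during training
config.router_aux_loss_coef = 0.001   # weight of the load-balancing (auxiliary) loss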