[model]: add ernie4_5_moe support for DeepSpeed Zero3 training (#9262)

Jiayi Mao authored 2025-10-13 13:13:31 +08:00, committed by GitHub
parent 575e4099df
commit 48974783da

@@ -55,6 +55,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         # deepseek v3 and kimi vl use custom code
         _set_z3_leaf_modules(model, ["DeepseekV3MoE"])
 
+    if model_type == "ernie4_5_moe":
+        from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])
+
     if model_type == "granitemoe":
         from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
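Context for the hunk above: under ZeRO-3, DeepSpeed partitions parameters and fetches them per submodule, and a sparse MoE block activates only a subset of experts per token, which can desynchronize that fetching across ranks. Registering the whole block as a ZeRO-3 "leaf" makes it gather its parameters as one unit. Below is a minimal standalone sketch of the same registration, assuming DeepSpeed's public set_z3_leaf_modules helper; the checkpoint id is illustrative and this is not the LLaMA-Factory call path itself.

# Sketch only: assumes deepspeed exposes set_z3_leaf_modules and a transformers
# version that ships Ernie4_5_MoeSparseMoeBlock; checkpoint id is a placeholder.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT")  # illustrative checkpoint

# Treat each sparse MoE block as a single leaf so ZeRO-3 gathers its parameters
# as a whole during forward instead of fetching individual expert weights.
set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])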
@@ -130,6 +135,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     if model_type in [
         "dbrx",
+        "ernie4_5_moe",
         "granitemoe",
         "jamba",
         "jetmoe",
@@ -148,7 +154,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
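For training, the two configure_moe hunks mean an ernie4_5_moe config now exposes its router logits and takes its auxiliary (load-balancing) loss coefficient from the moe_aux_loss_coef model argument. A rough sketch of the resulting config state is below; the checkpoint id and coefficient value are placeholders, not project defaults.

# Sketch only: checkpoint id and coefficient value are placeholders.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT")

# Emit per-layer router logits so the load-balancing auxiliary loss can be
# computed and added to the language-modeling loss during training.
config.output_router_logits = True

# Scale of that auxiliary loss; LLaMA-Factory fills this from moe_aux_loss_coef.
config.router_aux_loss_coef = 0.001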