From 1344416378ba3bbc19006bd849bcfdbb508e86fd Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Wed, 23 Apr 2025 15:30:49 +0800
Subject: [PATCH] [model] fix moe zero3 (#7826)

---
 src/llamafactory/model/model_utils/moe.py | 51 +++++++++++++++++++----
 1 file changed, 43 insertions(+), 8 deletions(-)

diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index b3fca4f7..ec6e1e38 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -44,6 +44,16 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [DbrxFFN])
 
+    if model_type == "deepseek_v3":
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
+    if model_type == "granitemoe":
+        from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
+
+        _set_z3_leaf_modules(model, [GraniteMoeMoE])
+
     if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
@@ -54,27 +64,55 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
-    if model_type in ["kimi_vl", "deepseek_v3"]:
-        check_version("transformers>=4.51.1")
-        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+    if model_type == "llama4":
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
 
-        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+        _set_z3_leaf_modules(model, [Llama4TextMoe])
 
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
+    if model_type == "olmoe":
+        from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])
+
+    if model_type == "phimoe":
+        from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [PhimoeSparseMoeBlock])
+
     if model_type == "qwen2_moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
+    if model_type == "qwen3_moe":
+        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
+
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in [
+            "dbrx",
+            "granitemoe",
+            "jamba",
+            "jetmoe",
+            "llama4",
+            "mixtral",
+            "olmoe",
+            "phimoe",
+            "qwen2_moe",
+            "qwen3_moe",
+        ]:
+            setattr(config, "output_router_logits", is_trainable)
+
+        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
         elif model_type == "deepseek":
@@ -82,6 +120,3 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
 
         elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
-
-        if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
-            setattr(config, "output_router_logits", is_trainable)
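
Note (not part of the diff): every hunk above registers MoE blocks through the _set_z3_leaf_modules helper defined earlier in moe.py, outside these hunks. The sketch below shows what that helper is assumed to do, for context only; the real wrapper in the repository may add a DeepSpeed version check, and its exact signature could differ.

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from transformers import PreTrainedModel


    def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list) -> None:
        # Sketch only: registering an MoE block as a ZeRO-3 leaf tells DeepSpeed to gather
        # its parameters as one unit instead of hooking every expert submodule, which avoids
        # hangs when token routing activates different experts on different ranks.
        from deepspeed.utils import set_z3_leaf_modules  # DeepSpeed's ZeRO-3 leaf-module API

        set_z3_leaf_modules(model, leaf_modules)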