[misc] fix moe models (#9230)

Yaowei Zheng 2025-10-05 02:41:02 +08:00 committed by GitHub
parent af8437095a
commit 40d3691e9e


@@ -40,7 +40,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     model_type = getattr(model.config, "model_type", None)
     text_config = getattr(model.config, "text_config", None)
-    text_architectures = getattr(text_config, "architectures", None)
+    text_model_type = getattr(text_config, "model_type", None)

     if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
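
Context for the rename in the hunk above (a reading of the diff, not text from the commit): on Hugging Face configs, `architectures` is typically a list of class names, so the old equality check against the string "Qwen3MoeForCausalLM" could not match, whereas composite checkpoints such as InternVL 3.5 report the language backbone's type via `text_config.model_type`. A minimal sketch with assumed stand-in values:

# Assumed, illustrative values standing in for a real nested text_config;
# they are not read from any checkpoint.
text_architectures = ["Qwen3MoeForCausalLM"]  # `architectures` is a list of class names
text_model_type = "qwen3_moe"                 # `model_type` is a plain string

print(text_architectures == "Qwen3MoeForCausalLM")  # False: a list never equals a str
print(text_model_type == "qwen3_moe")               # True: matches the new check (e.g. InternVL 3.5)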
@@ -105,11 +105,21 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

-    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":
+    if model_type == "qwen3_moe" or text_model_type == "qwen3_moe":  # internvl 3.5
         from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

         _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])

+    if model_type == "qwen3_vl_moe":
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
+
+    if model_type == "qwen3_omni_moe":
+        from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
+

 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef:
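
The `_set_z3_leaf_modules` helper used in both the existing and the newly added branches is defined earlier in this file and is not part of the diff. A minimal sketch of such a helper, assuming it simply forwards to DeepSpeed's `deepspeed.utils.set_z3_leaf_modules`: registering a sparse-MoE block class as a ZeRO-3 leaf makes DeepSpeed gather that module's parameters as one unit, which avoids stalls when only a subset of experts is activated in a given step.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch
    from transformers import PreTrainedModel


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[type["torch.nn.Module"]]) -> None:
    # Sketch only: assumes a DeepSpeed release that provides set_z3_leaf_modules.
    from deepspeed.utils import set_z3_leaf_modules

    set_z3_leaf_modules(model, leaf_modules)

With such a helper in place, the added branches register Qwen3VLMoeTextSparseMoeBlock and Qwen3OmniMoeThinkerTextSparseMoeBlock the same way the existing Qwen2/Qwen3 MoE blocks are registered.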