diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index a99c0b93..1a27e34b 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -40,7 +40,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     model_type = getattr(model.config, "model_type", None)
     text_config = getattr(model.config, "text_config", None)
-    text_architectures = getattr(text_config, "architectures", None)
+    text_model_type = getattr(text_config, "model_type", None)
 
     if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
 
@@ -105,11 +105,21 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
-    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":
+    if model_type == "qwen3_moe" or text_model_type == "qwen3_moe":  # internvl 3.5
         from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
 
+    if model_type == "qwen3_vl_moe":
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
+
+    if model_type == "qwen3_omni_moe":
+        from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
+
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef:
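
Note (not part of the diff): `_set_z3_leaf_modules` forwards to DeepSpeed's `set_z3_leaf_modules`, which flags a module class as a ZeRO-3 "leaf" so its parameters are gathered as one unit rather than hooked submodule-by-submodule; the usual motivation for MoE blocks is to avoid stalls when experts receive uneven (or zero) tokens across ranks. Below is a minimal sketch of what the two new branches amount to, assuming a DeepSpeed build that provides `deepspeed.utils.set_z3_leaf_modules` and a transformers build that ships the Qwen3-VL-MoE / Qwen3-Omni-MoE modeling files; the helper name `mark_qwen3_moe_leaves` is hypothetical, for illustration only.

```python
from deepspeed.utils import set_z3_leaf_modules  # available in recent DeepSpeed releases
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock


def mark_qwen3_moe_leaves(model) -> None:
    """Hypothetical helper: register the sparse MoE blocks as ZeRO-3 leaf modules."""
    model_type = getattr(model.config, "model_type", None)
    if model_type == "qwen3_vl_moe":
        set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
    elif model_type == "qwen3_omni_moe":
        set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
```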