Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)
[misc] fix moe models (#9230)
This commit is contained in: parent af8437095a, commit 40d3691e9e
@@ -40,7 +40,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:

     model_type = getattr(model.config, "model_type", None)
     text_config = getattr(model.config, "text_config", None)
-    text_architectures = getattr(text_config, "architectures", None)
+    text_model_type = getattr(text_config, "model_type", None)

     if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
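For context on the hunk above: wrapped multimodal models are now detected by the nested text config's model_type string rather than its architectures field, so a wrapper such as InternVL 3.5 whose text backbone reports model_type "qwen3_moe" is also matched. Below is a minimal, self-contained sketch of that dispatch, using plain SimpleNamespace objects as stand-ins for transformers configs; the "internvl" model_type value is illustrative and not taken from the repository.

from types import SimpleNamespace


def is_qwen3_moe_backbone(config) -> bool:
    # Same dispatch as the hunk above: match the top-level model_type or,
    # for multimodal wrappers, the nested text_config.model_type.
    model_type = getattr(config, "model_type", None)
    text_config = getattr(config, "text_config", None)
    text_model_type = getattr(text_config, "model_type", None)
    return model_type == "qwen3_moe" or text_model_type == "qwen3_moe"


# Plain Qwen3 MoE config: caught by the top-level model_type.
plain = SimpleNamespace(model_type="qwen3_moe", text_config=None)

# Hypothetical InternVL-3.5-style wrapper: the top-level model_type differs,
# but the text backbone reports model_type="qwen3_moe". The previous check
# compared text_config.architectures against a string and could miss this case.
wrapped = SimpleNamespace(
    model_type="internvl",  # illustrative value, not from the repository
    text_config=SimpleNamespace(model_type="qwen3_moe", architectures=None),
)

assert is_qwen3_moe_backbone(plain)
assert is_qwen3_moe_backbone(wrapped)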
@@ -105,11 +105,21 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:

         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

-    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":
+    if model_type == "qwen3_moe" or text_model_type == "qwen3_moe":  # internvl 3.5
         from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

         _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])

+    if model_type == "qwen3_vl_moe":
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
+
+    if model_type == "qwen3_omni_moe":
+        from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
+


 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef:
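The _set_z3_leaf_modules calls above register each sparse MoE block as a ZeRO-3 leaf module, so DeepSpeed gathers the block's expert parameters as one unit rather than hooking each expert submodule separately, which matters because only a subset of experts runs per token. The sketch below shows the same idea with DeepSpeed's public helper, deepspeed.utils.set_z3_leaf_modules; whether LLaMA-Factory's internal _set_z3_leaf_modules delegates to it is an assumption here, and the checkpoint name is illustrative only.

from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

# Illustrative checkpoint; any Qwen3 MoE model built from this block type would do.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

# Mark every Qwen3MoeSparseMoeBlock as a single ZeRO-3 leaf so its expert
# parameters are fetched together instead of per-expert via forward hooks.
set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])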