Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-01 03:02:51 +08:00
[model] fix moe zero3 (#7826)

This commit is contained in: parent 1dd67eb042, commit 1344416378
@@ -44,6 +44,16 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [DbrxFFN])
 
+    if model_type == "deepseek_v3":
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
+    if model_type == "granitemoe":
+        from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
+
+        _set_z3_leaf_modules(model, [GraniteMoeMoE])
+
     if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
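Editor's note: the additions in this hunk and the next all follow one pattern — each sparse MoE block class is registered as a ZeRO-3 "leaf" module so DeepSpeed gathers that block's expert parameters as a single unit instead of hooking every expert separately; with token routing, some experts may receive no tokens on a given rank, and per-expert gathering can then stall the collective communication. The `_set_z3_leaf_modules` helper itself is not shown in this diff; the sketch below assumes it is a thin wrapper around DeepSpeed's `deepspeed.utils.set_z3_leaf_modules`, and may differ from the real helper in the repo.

```python
# Sketch only: an assumed implementation of the _set_z3_leaf_modules helper used above,
# wrapping DeepSpeed's set_z3_leaf_modules. The actual helper in the repo may differ.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import PreTrainedModel


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list) -> None:
    """Mark the given MoE block classes as ZeRO-3 leaf modules.

    ZeRO-3 then fetches all of a block's expert weights together instead of
    installing per-submodule gather hooks, which avoids hangs when routing
    leaves some experts unused on a rank.
    """
    from deepspeed.utils import set_z3_leaf_modules  # available in recent DeepSpeed releases

    set_z3_leaf_modules(model, leaf_modules)
```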
@@ -54,27 +64,55 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
-    if model_type in ["kimi_vl", "deepseek_v3"]:
-        check_version("transformers>=4.51.1")
-        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
-
-        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+    if model_type == "llama4":
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
+
+        _set_z3_leaf_modules(model, [Llama4TextMoe])
 
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
+    if model_type == "olmoe":
+        from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])
+
+    if model_type == "phimoe":
+        from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [PhimoeSparseMoeBlock])
+
     if model_type == "qwen2_moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
+    if model_type == "qwen3_moe":
+        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
+
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in [
+            "dbrx",
+            "granitemoe",
+            "jamba",
+            "jetmoe",
+            "llama4",
+            "mixtral",
+            "olmoe",
+            "phimoe",
+            "qwen2_moe",
+            "qwen3_moe",
+        ]:
+            setattr(config, "output_router_logits", is_trainable)
+
+        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
         elif model_type == "deepseek":
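Editor's note: in `configure_moe`, the rewritten branch now covers the same MoE architectures handled in `add_z3_leaf_module`. Its two `setattr` calls amount to the following on a Hugging Face MoE config; `MixtralConfig` is only a stand-in for any of the listed architectures, and the coefficient value is a placeholder, not something taken from this commit.

```python
# Illustrative sketch of what the configure_moe branch above does to an MoE config.
# MixtralConfig is a stand-in; 0.001 is a placeholder for ModelArguments.moe_aux_loss_coef.
from transformers import MixtralConfig

config = MixtralConfig(num_local_experts=4, num_experts_per_tok=2)
is_trainable = True
moe_aux_loss_coef = 0.001

# Return router logits from each MoE layer so the load-balancing (auxiliary) loss
# can be computed during training...
setattr(config, "output_router_logits", is_trainable)
# ...and scale that auxiliary loss by the user-supplied coefficient.
setattr(config, "router_aux_loss_coef", moe_aux_loss_coef)
```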
@@ -82,6 +120,3 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
 
         elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
-
-    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
-        setattr(config, "output_router_logits", is_trainable)
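Editor's note: as a quick, out-of-tree sanity check of those two flags (not part of the commit), a tiny randomly initialized Mixtral-style model returns an `aux_loss` from its forward pass when `output_router_logits` is enabled and labels are supplied; that auxiliary term is already folded into the reported training loss.

```python
# Standalone check, assuming any recent transformers release with Mixtral support.
import torch
from transformers import MixtralConfig, MixtralForCausalLM

config = MixtralConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    output_router_logits=True,  # what configure_moe sets when the model is trainable
    router_aux_loss_coef=0.01,  # what configure_moe copies from moe_aux_loss_coef
)
model = MixtralForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.aux_loss)  # aux_loss is the router load-balancing term inside loss
```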