Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 07:42:49 +08:00)
[model]: add ernie4_5_moe support for DeepSpeed Zero3 training (#9262)
commit 48974783da
parent 575e4099df
@@ -55,6 +55,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         # deepseek v3 and kimi vl use custom code
         _set_z3_leaf_modules(model, ["DeepseekV3MoE"])
 
+    if model_type == "ernie4_5_moe":
+        from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])
+
     if model_type == "granitemoe":
         from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
 
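Why this hunk matters: under DeepSpeed ZeRO-3, parameters are partitioned across ranks and gathered on demand. In a sparse MoE block only a subset of experts runs on a given step, so ranks can disagree about which parameters need gathering and training can stall. Registering the MoE block as a ZeRO-3 "leaf" module tells DeepSpeed to treat the whole block as one unit and gather all of its parameters together. Below is a minimal sketch of the same idea using DeepSpeed's public API rather than LLaMA-Factory's internal helper; the checkpoint name is illustrative.

# Minimal sketch, not LLaMA-Factory's internal helper: mark Ernie 4.5's MoE block
# as a ZeRO-3 leaf module so DeepSpeed gathers its expert weights as one unit.
# Assumes a transformers version that ships the ernie4_5_moe architecture; the
# checkpoint name below is illustrative.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT")
# Without this call, sparsely activated experts can leave some ranks waiting on
# parameter gathers that never happen, stalling ZeRO-3 training.
set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])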
@@ -130,6 +135,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     if model_type in [
         "dbrx",
+        "ernie4_5_moe",
         "granitemoe",
         "jamba",
         "jetmoe",
@@ -148,7 +154,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
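The second and third hunks wire ernie4_5_moe into the MoE loss configuration: when the user supplies moe_aux_loss_coef, the config gains output_router_logits=True, so the forward pass returns router logits, and router_aux_loss_coef, so the load-balancing auxiliary loss is added to the language-modeling loss with that weight. A minimal sketch of the resulting config state follows; the checkpoint name and coefficient value are illustrative, not values from this commit.

# Minimal sketch of the config attributes configure_moe ends up setting for
# ernie4_5_moe (checkpoint name and coefficient value are illustrative).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-21B-A3B-PT")
setattr(config, "output_router_logits", True)   # emit router logits so the aux loss can be computed
setattr(config, "router_aux_loss_coef", 0.001)  # weight of the load-balancing loss added to the LM loss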