Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-10-14 15:52:49 +08:00
[model]: add ernie4_5_moe support for DeepSpeed Zero3 training (#9262)
This commit is contained in:
parent 575e4099df
commit 48974783da
@@ -55,6 +55,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         # deepseek v3 and kimi vl use custom code
         _set_z3_leaf_modules(model, ["DeepseekV3MoE"])
 
+    if model_type == "ernie4_5_moe":
+        from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])
+
     if model_type == "granitemoe":
         from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
 
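For context, the new branch mirrors the existing MoE handling: under DeepSpeed ZeRO-3, parameters are gathered per submodule via hooks, and a sparse-MoE block that activates only some experts on a given rank can leave that gather/release bookkeeping out of sync across ranks. Registering the block as a ZeRO-3 "leaf" makes DeepSpeed fetch all of its expert parameters at the block level. Below is a minimal standalone sketch of the same idea using DeepSpeed's public set_z3_leaf_modules API; the checkpoint path is a placeholder, and the patch itself goes through LLaMA-Factory's own _set_z3_leaf_modules wrapper rather than this call.

# Minimal sketch, not the repo's helper: mark the ERNIE 4.5 sparse-MoE block as a
# ZeRO-3 leaf so DeepSpeed gathers the whole block's parameters at once instead of
# hooking each expert, which can stall when ranks route tokens to different experts.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("path/to/ernie4_5_moe")  # placeholder checkpoint
set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])  # flag every matching submodule as a leaf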
@@ -130,6 +135,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
 
     if model_type in [
         "dbrx",
+        "ernie4_5_moe",
         "granitemoe",
         "jamba",
         "jetmoe",
@@ -148,7 +154,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
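The configure_moe changes give ERNIE 4.5 MoE the same router-loss handling as the other MoE families: router logits are exposed on the text config so the load-balancing auxiliary loss is computed, and the top-level config gets the user-supplied moe_aux_loss_coef as its weight. A rough sketch of the resulting config state follows; the checkpoint path is a placeholder and 0.001 is only an illustrative coefficient, not a LLaMA-Factory default.

# Sketch of the config attributes this patch now sets for ernie4_5_moe
# (for a text-only model the text config is the top-level config itself).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/ernie4_5_moe")  # placeholder checkpoint
config.output_router_logits = True   # expose router logits so the aux loss is computed
config.router_aux_loss_coef = 0.001  # weight of the load-balancing loss added to the LM loss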