mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

[model] fix moe zero3 (#7826)

This commit is contained in:
parent fa0eb91f1f
commit c1a7f2ebb2
@@ -44,6 +44,16 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [DbrxFFN])
 
+    if model_type == "deepseek_v3":
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
+    if model_type == "granitemoe":
+        from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
+
+        _set_z3_leaf_modules(model, [GraniteMoeMoE])
+
     if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
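Background on the pattern above: under ZeRO-3, parameters are partitioned and gathered per-module via hooks, and sparse expert routing can leave some experts' weights un-gathered on a given forward pass; registering the whole MoE block as a ZeRO-3 leaf module makes DeepSpeed fetch its parameters as one unit. The `_set_z3_leaf_modules` helper itself sits outside this hunk; below is a minimal sketch of such a helper, assuming it is a thin wrapper around DeepSpeed's `set_z3_leaf_modules` utility.

# Sketch of a helper like _set_z3_leaf_modules, which is defined outside this hunk.
# Assumption: it is a thin wrapper around DeepSpeed's set_z3_leaf_modules utility.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import PreTrainedModel


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list) -> None:
    # Mark these module classes as ZeRO-3 "leaves": DeepSpeed gathers their
    # parameters as a single unit instead of hooking every submodule, so
    # sparse expert routing cannot stall on never-gathered expert weights.
    from deepspeed.utils import set_z3_leaf_modules

    set_z3_leaf_modules(model, leaf_modules)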
@@ -54,27 +64,55 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
-    if model_type in ["kimi_vl", "deepseek_v3"]:
-        check_version("transformers>=4.51.1")
-        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
-
-        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+    if model_type == "llama4":
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
+
+        _set_z3_leaf_modules(model, [Llama4TextMoe])
 
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
+    if model_type == "olmoe":
+        from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])
+
+    if model_type == "phimoe":
+        from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [PhimoeSparseMoeBlock])
+
     if model_type == "qwen2_moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
+    if model_type == "qwen3_moe":
+        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
+
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in [
+            "dbrx",
+            "granitemoe",
+            "jamba",
+            "jetmoe",
+            "llama4",
+            "mixtral",
+            "olmoe",
+            "phimoe",
+            "qwen2_moe",
+            "qwen3_moe",
+        ]:
+            setattr(config, "output_router_logits", is_trainable)
+
+        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
         elif model_type == "deepseek":
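The reworked `configure_moe` only touches two standard config fields. The toy snippet below shows their effect on a stock `transformers` config; `MixtralConfig` is just an assumed example (any model type in the list above exposes the same two fields), and the coefficient value is made up.

# Effect of the two setattr calls on a stock transformers config.
# MixtralConfig is an assumed example; the coefficient value is made up.
from transformers import MixtralConfig

config = MixtralConfig()
is_trainable = True
moe_aux_loss_coef = 0.01  # stand-in for model_args.moe_aux_loss_coef

if moe_aux_loss_coef is not None:
    # return router logits during training so the load-balancing auxiliary
    # loss is computed and added to the language-modeling loss
    setattr(config, "output_router_logits", is_trainable)
    # weight of that auxiliary loss
    setattr(config, "router_aux_loss_coef", moe_aux_loss_coef)

print(config.output_router_logits, config.router_aux_loss_coef)  # True 0.01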
@@ -82,6 +120,3 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
 
         elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
-
-    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
-        setattr(config, "output_router_logits", is_trainable)
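For orientation, a hypothetical call site tying the two helpers together when preparing an MoE model for ZeRO-3 training. The import path, loader structure, and example checkpoint below are assumptions for illustration, not the project's actual loading code.

# Hypothetical call-site sketch (not LLaMA-Factory's actual loader).
from dataclasses import dataclass
from typing import Optional

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

# Assumed import path for the two functions patched in this commit.
from llamafactory.model.model_utils.moe import add_z3_leaf_module, configure_moe


@dataclass
class ModelArguments:  # minimal stand-in for the real ModelArguments
    model_name_or_path: str = "Qwen/Qwen1.5-MoE-A2.7B"  # example MoE checkpoint
    moe_aux_loss_coef: Optional[float] = 0.001


def load_moe_model(model_args: ModelArguments, is_trainable: bool = True):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    configure_moe(config, model_args, is_trainable)  # patch router-loss fields before init
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
    if is_deepspeed_zero3_enabled():
        add_z3_leaf_module(model)  # register MoE blocks as ZeRO-3 leaf modules
    return model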