From e016d2480e842e9c7205231fca3eeb0f73dc80f2 Mon Sep 17 00:00:00 2001 From: xvxuopop <127376094+xvxuopop@users.noreply.github.com> Date: Sat, 30 May 2026 21:42:54 +0800 Subject: [PATCH] [fix] Fix NPU FusedMoE and RMSNorm (#10512) --- .../kernels/ops/mlp/npu_fused_moe.py | 56 ++++++++++++++++--- .../kernels/ops/rms_norm/npu_rms_norm.py | 24 ++++---- 2 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py index 7b4e29269..41cab6020 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py @@ -228,6 +228,30 @@ class NpuMoeFused: routed_out = self.experts(hidden_states, routing_weights, router_indices) return routed_out + @staticmethod + def npu_moe_experts_v5_forward( + self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + ) -> torch.Tensor: + """Forward pass for Transformers v5+ MoE experts using NPU fused operations. + + Transformers v5 stores expert weights in F.linear layout: + gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim] + down_proj: [num_experts, hidden_dim, intermediate_dim] + The NPU grouped matmul path expects matmul layout, so both weights are transposed. + """ + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute( + hidden_states, top_k_index.to(torch.int32) + ) + tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long() + + gate_up_proj = self.gate_up_proj.transpose(1, 2) + down_proj = self.down_proj.transpose(1, 2) + intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert) + intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1) + output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert) + return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights) + class Qwen3NpuMoeFused: """Container for Qwen3 NPU fused MoE forward functions.""" @@ -283,16 +307,30 @@ class Qwen3NpuMoeFused: # moe patch config mapping -kernel_moe_mapping = { - "Qwen3VLMoeForConditionalGeneration": { - "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward, - "Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward, +if is_transformers_version_greater_than("5.0.0"): + kernel_moe_mapping = { + "Qwen3MoeForCausalLM": { + "Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward, + }, + "Qwen3VLMoeForConditionalGeneration": { + "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward, + }, + "Qwen3_5MoeForCausalLM": { + "Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward, + }, + "Qwen3_5MoeForConditionalGeneration": { + "Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward, + }, } -} - -if not is_transformers_version_greater_than("5.0.0"): - kernel_moe_mapping["Qwen3MoeForCausalLM"] = { - "Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward +else: + kernel_moe_mapping = { + "Qwen3MoeForCausalLM": { + "Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward, + }, + "Qwen3VLMoeForConditionalGeneration": { + "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward, + "Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward, + }, } diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py index 3b1c39f88..04a4ad16f 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py @@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module): bool: ``True`` if the module uses residual parameterization, ``False`` otherwise. .. note:: - This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0) - without hardcoding version numbers. Two methods are used: weight value inspection - (most reliable) and class name pattern matching (backward compatibility). + This must follow the module's forward semantics. Do not infer it from trained + weight values because standard RMSNorm weights can also be close to zero. """ - if hasattr(module, "weight") and module.weight is not None: - weight_mean = module.weight.data.mean().item() - if abs(weight_mean) < 0.3: - return True + residual_rmsnorm_classes = { + "Qwen3_5RMSNorm", + "Qwen3_5MoeRMSNorm", + "Qwen3NextRMSNorm", + } class_name = module.__class__.__name__ - residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"] - for pattern in residual_patterns: - if pattern in class_name: - return True - - return False + return class_name in residual_rmsnorm_classes def npu_rms_norm_forward(self, hidden_states): @@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states): _eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6) if hasattr(self, "weight") and self.weight is not None: - if _should_use_residual_rmsnorm(self): + if getattr(self, "_npu_use_residual_rmsnorm", False): effective_weight = 1.0 + self.weight.float() else: effective_weight = self.weight.float() @@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel): if "Gated" in module.__class__.__name__: module.forward = types.MethodType(npu_gated_rms_norm_forward, module) else: + module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module) module.forward = types.MethodType(npu_rms_norm_forward, module) return model