[fix] Fix NPU FusedMoE and RMSNorm (#10512)

This commit is contained in:
xvxuopop
2026-05-30 21:42:54 +08:00
committed by GitHub
parent 7d719182c9
commit e016d2480e
2 changed files with 57 additions and 23 deletions

View File

@@ -228,6 +228,30 @@ class NpuMoeFused:
routed_out = self.experts(hidden_states, routing_weights, router_indices)
return routed_out
@staticmethod
def npu_moe_experts_v5_forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
Transformers v5 stores expert weights in F.linear layout:
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
down_proj: [num_experts, hidden_dim, intermediate_dim]
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
"""
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
hidden_states, top_k_index.to(torch.int32)
)
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
gate_up_proj = self.gate_up_proj.transpose(1, 2)
down_proj = self.down_proj.transpose(1, 2)
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
class Qwen3NpuMoeFused:
"""Container for Qwen3 NPU fused MoE forward functions."""
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
# moe patch config mapping
kernel_moe_mapping = {
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
if is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
}
}
if not is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
else:
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
},
}

View File

@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
This must follow the module's forward semantics. Do not infer it from trained
weight values because standard RMSNorm weights can also be close to zero.
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
residual_rmsnorm_classes = {
"Qwen3_5RMSNorm",
"Qwen3_5MoeRMSNorm",
"Qwen3NextRMSNorm",
}
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
return class_name in residual_rmsnorm_classes
def npu_rms_norm_forward(self, hidden_states):
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
if hasattr(self, "weight") and self.weight is not None:
if _should_use_residual_rmsnorm(self):
if getattr(self, "_npu_use_residual_rmsnorm", False):
effective_weight = 1.0 + self.weight.float()
else:
effective_weight = self.weight.float()
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
if "Gated" in module.__class__.__name__:
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model