mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[fix] Fix NPU FusedMoE and RMSNorm (#10512)
This commit is contained in:
@@ -228,6 +228,30 @@ class NpuMoeFused:
|
||||
routed_out = self.experts(hidden_states, routing_weights, router_indices)
|
||||
return routed_out
|
||||
|
||||
@staticmethod
def npu_moe_experts_v5_forward(
    self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
    """Run the Transformers v5+ MoE expert computation through NPU fused kernels.

    Transformers v5 keeps the expert weights in ``F.linear`` layout:

        gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
        down_proj: [num_experts, hidden_dim, intermediate_dim]

    The NPU grouped matmul path expects matmul layout instead, so the last two
    dimensions of both weights are swapped before they are passed to it.

    Args:
        self: The experts module being patched; ``hidden_dim``, ``num_experts``,
            ``gate_up_proj`` and ``down_proj`` are read from it.
        hidden_states: Input activations; flattened to ``(-1, hidden_dim)``.
        top_k_index: Selected expert ids (cast to int32 for the permute kernel).
        top_k_weights: Routing weights, forwarded as ``probs`` to the unpermute kernel.

    Returns:
        The tensor produced by ``torch_npu.npu_moe_token_unpermute``.
    """
    # Collapse all leading dimensions into a single row axis.
    flat_states = hidden_states.reshape(-1, self.hidden_dim)

    # `num_experts` unit-width bins over [0, num_experts]: one bin per expert id.
    expert_counts = torch.histc(
        top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts
    ).long()

    sorted_states, row_ids_map = torch_npu.npu_moe_token_permute(flat_states, top_k_index.to(torch.int32))

    # Swap the last two weight dims: F.linear layout -> matmul layout.
    gate_up_weight = self.gate_up_proj.transpose(1, 2)
    down_weight = self.down_proj.transpose(1, 2)

    gate_up_out = GmmFunction.apply(sorted_states, gate_up_weight, expert_counts)
    activated = torch_npu.npu_swiglu(gate_up_out, dim=-1)
    down_out = GmmFunction.apply(activated, down_weight, expert_counts)

    return torch_npu.npu_moe_token_unpermute(down_out, row_ids_map, probs=top_k_weights)
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
|
||||
|
||||
|
||||
# moe patch config mapping
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
||||
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3MoeForCausalLM": {
|
||||
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3_5MoeForCausalLM": {
|
||||
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3_5MoeForConditionalGeneration": {
|
||||
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
|
||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
|
||||
else:
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3MoeForCausalLM": {
|
||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
|
||||
},
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
||||
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
|
||||
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
|
||||
|
||||
.. note::
|
||||
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
|
||||
without hardcoding version numbers. Two methods are used: weight value inspection
|
||||
(most reliable) and class name pattern matching (backward compatibility).
|
||||
This must follow the module's forward semantics. Do not infer it from trained
|
||||
weight values because standard RMSNorm weights can also be close to zero.
|
||||
"""
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
weight_mean = module.weight.data.mean().item()
|
||||
if abs(weight_mean) < 0.3:
|
||||
return True
|
||||
residual_rmsnorm_classes = {
|
||||
"Qwen3_5RMSNorm",
|
||||
"Qwen3_5MoeRMSNorm",
|
||||
"Qwen3NextRMSNorm",
|
||||
}
|
||||
|
||||
class_name = module.__class__.__name__
|
||||
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
|
||||
for pattern in residual_patterns:
|
||||
if pattern in class_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
return class_name in residual_rmsnorm_classes
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
|
||||
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
|
||||
|
||||
if hasattr(self, "weight") and self.weight is not None:
|
||||
if _should_use_residual_rmsnorm(self):
|
||||
if getattr(self, "_npu_use_residual_rmsnorm", False):
|
||||
effective_weight = 1.0 + self.weight.float()
|
||||
else:
|
||||
effective_weight = self.weight.float()
|
||||
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
if "Gated" in module.__class__.__name__:
|
||||
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
||||
else:
|
||||
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
|
||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user