mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[fix] Fix NPU FusedMoE and RMSNorm (#10512)
This commit is contained in:
@@ -228,6 +228,30 @@ class NpuMoeFused:
|
|||||||
routed_out = self.experts(hidden_states, routing_weights, router_indices)
|
routed_out = self.experts(hidden_states, routing_weights, router_indices)
|
||||||
return routed_out
|
return routed_out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def npu_moe_experts_v5_forward(
|
||||||
|
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
|
||||||
|
|
||||||
|
Transformers v5 stores expert weights in F.linear layout:
|
||||||
|
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
|
||||||
|
down_proj: [num_experts, hidden_dim, intermediate_dim]
|
||||||
|
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
|
||||||
|
"""
|
||||||
|
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||||
|
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
|
||||||
|
hidden_states, top_k_index.to(torch.int32)
|
||||||
|
)
|
||||||
|
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
|
||||||
|
|
||||||
|
gate_up_proj = self.gate_up_proj.transpose(1, 2)
|
||||||
|
down_proj = self.down_proj.transpose(1, 2)
|
||||||
|
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
|
||||||
|
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
|
||||||
|
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
|
||||||
|
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3NpuMoeFused:
|
class Qwen3NpuMoeFused:
|
||||||
"""Container for Qwen3 NPU fused MoE forward functions."""
|
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||||
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
|
|||||||
|
|
||||||
|
|
||||||
# moe patch config mapping
|
# moe patch config mapping
|
||||||
kernel_moe_mapping = {
|
if is_transformers_version_greater_than("5.0.0"):
|
||||||
|
kernel_moe_mapping = {
|
||||||
|
"Qwen3MoeForCausalLM": {
|
||||||
|
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||||
|
},
|
||||||
|
"Qwen3VLMoeForConditionalGeneration": {
|
||||||
|
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||||
|
},
|
||||||
|
"Qwen3_5MoeForCausalLM": {
|
||||||
|
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||||
|
},
|
||||||
|
"Qwen3_5MoeForConditionalGeneration": {
|
||||||
|
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
kernel_moe_mapping = {
|
||||||
|
"Qwen3MoeForCausalLM": {
|
||||||
|
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
|
||||||
|
},
|
||||||
"Qwen3VLMoeForConditionalGeneration": {
|
"Qwen3VLMoeForConditionalGeneration": {
|
||||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
||||||
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
||||||
}
|
},
|
||||||
}
|
|
||||||
|
|
||||||
if not is_transformers_version_greater_than("5.0.0"):
|
|
||||||
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
|
|
||||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
|
|||||||
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
|
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
|
This must follow the module's forward semantics. Do not infer it from trained
|
||||||
without hardcoding version numbers. Two methods are used: weight value inspection
|
weight values because standard RMSNorm weights can also be close to zero.
|
||||||
(most reliable) and class name pattern matching (backward compatibility).
|
|
||||||
"""
|
"""
|
||||||
if hasattr(module, "weight") and module.weight is not None:
|
residual_rmsnorm_classes = {
|
||||||
weight_mean = module.weight.data.mean().item()
|
"Qwen3_5RMSNorm",
|
||||||
if abs(weight_mean) < 0.3:
|
"Qwen3_5MoeRMSNorm",
|
||||||
return True
|
"Qwen3NextRMSNorm",
|
||||||
|
}
|
||||||
|
|
||||||
class_name = module.__class__.__name__
|
class_name = module.__class__.__name__
|
||||||
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
|
return class_name in residual_rmsnorm_classes
|
||||||
for pattern in residual_patterns:
|
|
||||||
if pattern in class_name:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def npu_rms_norm_forward(self, hidden_states):
|
def npu_rms_norm_forward(self, hidden_states):
|
||||||
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
|
|||||||
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
|
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
|
||||||
|
|
||||||
if hasattr(self, "weight") and self.weight is not None:
|
if hasattr(self, "weight") and self.weight is not None:
|
||||||
if _should_use_residual_rmsnorm(self):
|
if getattr(self, "_npu_use_residual_rmsnorm", False):
|
||||||
effective_weight = 1.0 + self.weight.float()
|
effective_weight = 1.0 + self.weight.float()
|
||||||
else:
|
else:
|
||||||
effective_weight = self.weight.float()
|
effective_weight = self.weight.float()
|
||||||
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
|
|||||||
if "Gated" in module.__class__.__name__:
|
if "Gated" in module.__class__.__name__:
|
||||||
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
||||||
else:
|
else:
|
||||||
|
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
|
||||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user