From efb13b7483249ca2cc0149d7b99d33ae0ba514b8 Mon Sep 17 00:00:00 2001
From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com>
Date: Tue, 2 Dec 2025 00:22:03 +0800
Subject: [PATCH] [V1] Refactor ascend MoE kernel patch logic & Support
 Qwen3-MoE (#9557)

---
 src/llamafactory/v1/extras/packages.py |  43 ++++
 .../kernels/mlp/npu_fused_moe.py       | 200 ++++++++++++++----
 2 files changed, 203 insertions(+), 40 deletions(-)
 create mode 100644 src/llamafactory/v1/extras/packages.py

diff --git a/src/llamafactory/v1/extras/packages.py b/src/llamafactory/v1/extras/packages.py
new file mode 100644
index 00000000..8d86e01a
--- /dev/null
+++ b/src/llamafactory/v1/extras/packages.py
@@ -0,0 +1,43 @@
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.metadata
+import importlib.util
+from functools import lru_cache
+from typing import TYPE_CHECKING
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+    from packaging.version import Version
+
+
+def _is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None
+
+
+def _get_package_version(name: str) -> "Version":
+    try:
+        return version.parse(importlib.metadata.version(name))
+    except Exception:
+        return version.parse("0.0.0")
+
+
+@lru_cache
+def is_transformers_version_greater_than(content: str):
+    return _get_package_version("transformers") >= version.parse(content)
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
index 16304995..21055760 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
 import types
 
 import torch
+import torch.nn.functional as F
 import torch_npu
 
+from .....extras.packages import is_transformers_version_greater_than
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
@@ -56,58 +57,177 @@ class GmmFunction(torch.autograd.Function):
         return grad_input, grad_weight, None
 
 
-def npu_group_gemm(x, weight, group_list):
-    output = GmmFunction.apply(x, weight, group_list)
-    return output
+class HybridGmmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, num_experts, *args):
+        x_list = list(args[:num_experts])
+        weight_list = list(args[num_experts:])
+
+        split_sizes = [x.shape[0] for x in x_list]
+        ctx.split_sizes = split_sizes
+        ctx.num_experts = num_experts
+
+        ctx.save_for_backward(*args)
+
+        outputs = torch_npu.npu_grouped_matmul(
+            x_list, weight_list, bias=None, group_list=None, split_item=0, group_type=-1
+        )
+        return tuple(outputs)
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        saved_tensors = ctx.saved_tensors
+        num_experts = ctx.num_experts
+        split_sizes = ctx.split_sizes
+
+        x_list = list(saved_tensors[:num_experts])
+        weight_list = list(saved_tensors[num_experts:])
+
+        grad_outputs_contiguous = [g.contiguous() for g in grad_outputs]
+
+        w_t_list = [w.t() for w in weight_list]
+        grad_x_list = torch_npu.npu_grouped_matmul(
+            grad_outputs_contiguous,  # List[Tensor], each of shape [M_i, N]
+            w_t_list,  # List[Tensor], each of shape [N, K] (view)
+            bias=None,
+            group_list=None,
+            split_item=0,
+            group_type=-1,
+        )
+
+        x_concat = torch.cat(x_list, dim=0)
+        dy_concat = torch.cat(grad_outputs_contiguous, dim=0)  # [Total_M, N]
+
+        group_list = torch.tensor(split_sizes, device=x_concat.device, dtype=torch.int64)
+
+        grad_w_stack = torch_npu.npu_grouped_matmul(
+            [x_concat.t()],
+            [dy_concat],
+            bias=None,
+            group_list=group_list,
+            split_item=3,
+            group_type=2,
+            group_list_type=1,
+        )[0]
+
+        if grad_w_stack.dim() == 3:
+            grad_w_list = list(torch.unbind(grad_w_stack, dim=0))
+        else:
+            raise RuntimeError(f"Unexpected grad_w_stack shape: {grad_w_stack.shape}")
+
+        return (None, *grad_x_list, *grad_w_list)
 
 
-def npu_experts_qwen3vlmoe_forward(
-    self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
-) -> torch.Tensor:
-    batch_size = hidden_states.shape[0]
-    hidden_states = hidden_states.reshape(-1, self.hidden_size)
-    permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
-        hidden_states, router_indices.to(torch.int32)
-    )
-    tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
-    intermediate_hidden_states = npu_group_gemm(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
-    intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
-    output = npu_group_gemm(intermediate_activations, self.down_proj, tokens_per_expert)
-    next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
-    next_states = next_states.view(batch_size, -1, self.hidden_size)
-    return next_states
+class NpuMoeFused:
+    @staticmethod
+    def npu_moe_experts_forward(
+        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
+            hidden_states, router_indices.to(torch.int32)
+        )
+        tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
+        intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
+        intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
+        output = GmmFunction.apply(intermediate_activations, self.down_proj, tokens_per_expert)
+        next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
+        next_states = next_states.view(batch_size, -1, self.hidden_size)
+        return next_states
+
+    @staticmethod
+    def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        router_logits = self.gate(hidden_states)
+        routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+        routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
+        routed_out = self.experts(hidden_states, routing_weights, router_indices)
+        return routed_out
 
 
-def npu_moe_block_qwen3vlmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-    batch_size = hidden_states.shape[0]
-    hidden_states = hidden_states.reshape(-1, self.hidden_size)
-    router_logits = self.gate(hidden_states)
-    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
-    routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
-    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(hidden_states.dtype)
-    hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
-    routed_out = self.experts(hidden_states, routing_weights, router_indices)
-    return routed_out
+class Qwen3NpuMoeFused:
+    @staticmethod
+    def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, selected_experts.int())
+
+        tokens_per_expert = torch.histc(
+            selected_experts.float(), bins=self.num_experts, min=0, max=self.num_experts
+        ).long()
+        split_sizes = tokens_per_expert.tolist()
+
+        input_list = list(torch.split(permuted_hidden_states, split_sizes, dim=0))
+
+        gate_weights = [e.gate_proj.weight.t() for e in self.experts]
+        up_weights = [e.up_proj.weight.t() for e in self.experts]
+        down_weights = [e.down_proj.weight.t() for e in self.experts]
+
+        gate_out_tuple = HybridGmmFunction.apply(len(input_list), *input_list, *gate_weights)
+        up_out_tuple = HybridGmmFunction.apply(len(input_list), *input_list, *up_weights)
+
+        inter_list = [F.silu(g) * u for g, u in zip(gate_out_tuple, up_out_tuple)]
+
+        down_out_tuple = HybridGmmFunction.apply(len(inter_list), *inter_list, *down_weights)
+
+        grouped_output = torch.cat(down_out_tuple, dim=0)
+
+        next_states = torch_npu.npu_moe_token_unpermute(grouped_output, row_ids_map, probs=routing_weights)
+
+        next_states = next_states.view(batch_size, sequence_length, -1)
+        return next_states, router_logits
 
 
-class NpuQwen3VLMoEFusedMoEKernel(MetaMoEKernel):
+# MoE kernel patch mapping: model architecture -> {module class name: fused forward}
+kernel_moe_mapping = {
+    "Qwen3VLMoeForConditionalGeneration": {
+        "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
+        "Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
+    }
+}
+
+if not is_transformers_version_greater_than("5.0.0"):
+    kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
+        "Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
+    }
+
+
+class NpuMoEFusedMoEKernel(MetaMoEKernel):
     type = KernelType.MOE
     device = DeviceType.NPU
-    npu_experts_kernel = npu_experts_qwen3vlmoe_forward
-    npu_moe_block_kernel = npu_moe_block_qwen3vlmoe_forward
 
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         if not is_torch_npu_available():
             return model
-        npu_experts_pattern = re.compile("Qwen3VLMoeTextExperts", re.IGNORECASE)
-        npu_moe_block_pattern = re.compile("Qwen3VLMoeTextSparseMoeBlock", re.IGNORECASE)
+        archs = getattr(model.config, "architectures", None) or []
+        target_moe_mapping = None
+        for arch in archs:
+            if arch in kernel_moe_mapping:
+                target_moe_mapping = kernel_moe_mapping[arch]
+                break
+
+        if target_moe_mapping is None:
+            return model
 
+        for module in model.modules():
+            class_name = module.__class__.__name__
+            if class_name in target_moe_mapping:
+                new_forward_func = target_moe_mapping[class_name]
+                module.forward = types.MethodType(new_forward_func, module)
-        for _, module in model.named_modules():
-            if re.search(npu_experts_pattern, module.__class__.__name__):
-                module.forward = types.MethodType(cls.npu_experts_kernel, module)
-            elif re.search(npu_moe_block_pattern, module.__class__.__name__):
-                module.forward = types.MethodType(cls.npu_moe_block_kernel, module)
 
         return model
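
Usage sketch, assuming an Ascend NPU environment with torch_npu installed; the checkpoint
name below is only illustrative, and any architecture listed in kernel_moe_mapping qualifies:

    from transformers import AutoModelForCausalLM

    from llamafactory.v1.plugins.model_plugins.kernels.mlp.npu_fused_moe import NpuMoEFusedMoEKernel

    # Load a MoE checkpoint whose architecture appears in kernel_moe_mapping,
    # e.g. Qwen3MoeForCausalLM (transformers < 5.0.0) or Qwen3VLMoeForConditionalGeneration.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

    # On NPU, matching MoE submodules get their forward rebound to the fused kernels;
    # on other devices apply() returns the model unchanged.
    model = NpuMoEFusedMoEKernel.apply(model)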