From 2c4fb3c97e74d42c2bd31eff004d4cf407016509 Mon Sep 17 00:00:00 2001
From: xvxuopop <127376094+xvxuopop@users.noreply.github.com>
Date: Thu, 27 Nov 2025 02:13:33 +0800
Subject: [PATCH] [v1] Support fused moe kernel for qwen3vlmoe model. (#9532)

---
 examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml |  42 ++++++++
 .../kernels/mlp/npu_fused_moe.py              | 100 ++++++++++++++++++
 2 files changed, 142 insertions(+)
 create mode 100644 examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml

diff --git a/examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml b/examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml
new file mode 100644
index 00000000..aa45f2cb
--- /dev/null
+++ b/examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml
@@ -0,0 +1,42 @@
+### model
+model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
+image_max_pixels: 262144
+video_max_pixels: 16384
+trust_remote_code: true
+use_kernels: true # replaced kernels: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_rank: 8
+lora_target: all
+disable_gradient_checkpointing: false
+flash_attn: disabled
+
+### dataset
+dataset: alpaca_zh_demo, alpaca_en_demo
+template: qwen3_vl
+cutoff_len: 1024
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/qwen3vlmoe/lora/sft
+logging_steps: 1
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: true
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 8
+gradient_accumulation_steps: 1
+learning_rate: 1.0e-4
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+seed: 1234
\ No newline at end of file
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
index ec0d6255..16304995 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py
@@ -11,3 +11,103 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import re
+import types
+
+import torch
+import torch_npu
+
+from .....extras.types import HFModel
+from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
+from ..constants import DeviceType, KernelType
+from ..registry import MetaMoEKernel
+
+
+class GmmFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, group_list):
+        ctx.save_for_backward(x, weight)
+        ctx.group_list = group_list
+
+        fwd_output = torch_npu.npu_grouped_matmul(
+            [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
+        )[0]
+        return fwd_output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_tensor, weight = ctx.saved_tensors
+        group_list = ctx.group_list
+
+        weight = torch.transpose(weight, 1, 2)
+        grad_input = torch_npu.npu_grouped_matmul(
+            [grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
+        )[0]
+        grad_weight = torch_npu.npu_grouped_matmul(
+            [input_tensor.T],
+            [grad_output],
+            bias=None,
+            group_list=group_list,
+            split_item=3,
+            group_type=2,
+            group_list_type=1,
+        )[0]
+        return grad_input, grad_weight, None
+
+
+def npu_group_gemm(x, weight, group_list):
+    output = GmmFunction.apply(x, weight, group_list)
+    return output
+
+
+def npu_experts_qwen3vlmoe_forward(
+    self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+) -> torch.Tensor:
+    batch_size = hidden_states.shape[0]
+    hidden_states = hidden_states.reshape(-1, self.hidden_size)
+    permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
+        hidden_states, router_indices.to(torch.int32)
+    )
+    tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
+    intermediate_hidden_states = npu_group_gemm(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
+    intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
+    output = npu_group_gemm(intermediate_activations, self.down_proj, tokens_per_expert)
+    next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
+    next_states = next_states.view(batch_size, -1, self.hidden_size)
+    return next_states
+
+
+def npu_moe_block_qwen3vlmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    batch_size = hidden_states.shape[0]
+    hidden_states = hidden_states.reshape(-1, self.hidden_size)
+    router_logits = self.gate(hidden_states)
+    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
+    routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.to(hidden_states.dtype)
+    hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
+    routed_out = self.experts(hidden_states, routing_weights, router_indices)
+    return routed_out
+
+
+class NpuQwen3VLMoEFusedMoEKernel(MetaMoEKernel):
+    type = KernelType.MOE
+    device = DeviceType.NPU
+    npu_experts_kernel = npu_experts_qwen3vlmoe_forward
+    npu_moe_block_kernel = npu_moe_block_qwen3vlmoe_forward
+
+    @classmethod
+    def apply(cls, model, **kwargs) -> HFModel:
+        if not is_torch_npu_available():
+            return model
+
+        npu_experts_pattern = re.compile("Qwen3VLMoeTextExperts", re.IGNORECASE)
+        npu_moe_block_pattern = re.compile("Qwen3VLMoeTextSparseMoeBlock", re.IGNORECASE)
+
+        for _, module in model.named_modules():
+            if re.search(npu_experts_pattern, module.__class__.__name__):
+                module.forward = types.MethodType(cls.npu_experts_kernel, module)
+            elif re.search(npu_moe_block_pattern, module.__class__.__name__):
+                module.forward = types.MethodType(cls.npu_moe_block_kernel, module)
+        return model
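
For readers without Ascend hardware, the sketch below spells out what GmmFunction's forward computes: one matmul per expert over a contiguous slice of the permuted tokens, with slice sizes given by the per-expert token counts (the patch builds group_list with torch.histc, so it holds counts). The name grouped_matmul_reference is illustrative only and is not part of the patch; this is a minimal eager reference under those assumptions, not the fused implementation.

    import torch

    def grouped_matmul_reference(x, weight, tokens_per_expert):
        # x:                 [num_tokens, in_dim], rows already permuted so tokens
        #                    routed to the same expert are contiguous.
        # weight:            [num_experts, in_dim, out_dim] stacked expert weights.
        # tokens_per_expert: [num_experts] token counts, the role group_list plays above.
        outputs, start = [], 0
        for expert_id, count in enumerate(tokens_per_expert.tolist()):
            end = start + int(count)
            outputs.append(x[start:end] @ weight[expert_id])  # one plain matmul per expert
            start = end
        return torch.cat(outputs, dim=0)

torch_npu.npu_grouped_matmul performs these per-expert matmuls as a single grouped GEMM, so GmmFunction only needs to stash the inputs, weights, and token counts in forward and rebuild grad_input and grad_weight with two more grouped matmuls in backward. The kernel itself patches both Qwen3VLMoeTextExperts.forward and Qwen3VLMoeTextSparseMoeBlock.forward and leaves the model untouched when torch_npu is unavailable.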