[v1] Support fused moe kernel for qwen3vlmoe model. (#9532)

Author: xvxuopop, 2025-11-27 02:13:33 +08:00 (committed by GitHub)
parent 2b6f16f261
commit 2c4fb3c97e
2 changed files with 142 additions and 0 deletions


@@ -0,0 +1,42 @@
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_kernels: true  # kernels to apply: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
disable_gradient_checkpointing: false
flash_attn: disabled

### dataset
dataset: alpaca_zh_demo, alpaca_en_demo
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen3vlmoe/lora/sft
logging_steps: 1
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
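
This example config can presumably be launched with `llamafactory-cli train <path-to-this-config>.yaml`; with `use_kernels: true`, the NPU kernel replacements listed in the comment above are applied when the model is loaded.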


@@ -11,3 +11,103 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import types

import torch
import torch_npu

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaMoEKernel


class GmmFunction(torch.autograd.Function):
    """Grouped matrix multiplication (GMM) with an explicit backward pass for NPU."""

    @staticmethod
    def forward(ctx, x, weight, group_list):
        ctx.save_for_backward(x, weight)
        ctx.group_list = group_list
        # One matmul per expert: rows of `x` form contiguous groups whose sizes are
        # given by `group_list` (group_list_type=1 means per-group counts).
        fwd_output = torch_npu.npu_grouped_matmul(
            [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        return fwd_output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, weight = ctx.saved_tensors
        group_list = ctx.group_list
        # grad_input = grad_output @ weight^T, computed group by group.
        weight = torch.transpose(weight, 1, 2)
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        # grad_weight = x^T @ grad_output, grouped along the token dimension.
        grad_weight = torch_npu.npu_grouped_matmul(
            [input_tensor.T],
            [grad_output],
            bias=None,
            group_list=group_list,
            split_item=3,
            group_type=2,
            group_list_type=1,
        )[0]
        return grad_input, grad_weight, None


def npu_group_gemm(x, weight, group_list):
    """Run a grouped GEMM on NPU with autograd support."""
    return GmmFunction.apply(x, weight, group_list)
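

# For intuition only: a plain-PyTorch sketch of the grouped-GEMM contract assumed
# above (not called by the kernel). With `group_list` holding per-expert row
# counts, the grouped GEMM is one dense matmul per expert over a contiguous slice.
def _group_gemm_reference(x, weight, group_list):
    outputs, start = [], 0
    for expert_id, rows in enumerate(group_list.long().tolist()):
        outputs.append(x[start : start + rows] @ weight[expert_id])
        start += rows
    return torch.cat(outputs, dim=0)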


def npu_experts_qwen3vlmoe_forward(
    self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
    """Expert computation fused with NPU token permutation and grouped GEMM."""
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)
    # Sort token copies by expert id so each expert sees a contiguous block of rows.
    permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
        hidden_states, router_indices.to(torch.int32)
    )
    tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
    intermediate_hidden_states = npu_group_gemm(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
    intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
    output = npu_group_gemm(intermediate_activations, self.down_proj, tokens_per_expert)
    # Restore the original token order and combine the top-k expert outputs.
    next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
    next_states = next_states.view(batch_size, -1, self.hidden_size)
    return next_states
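

# Assumed semantics of the permute step, sketched in plain PyTorch for readability
# (not called by the kernel): each token is replicated once per selected expert and
# the copies are sorted by expert id; unpermutation reverses the sort and sums the
# top-k copies of every token, scaled by its routing weights.
def _token_permute_reference(hidden_states, router_indices):
    top_k = router_indices.shape[-1]
    sort_map = torch.argsort(router_indices.reshape(-1), stable=True)
    permuted = hidden_states.repeat_interleave(top_k, dim=0)[sort_map]
    return permuted, sort_map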


def npu_moe_block_qwen3vlmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Sparse MoE block forward: route tokens, then dispatch to the fused experts."""
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)
    router_logits = self.gate(hidden_states)
    # Softmax over all experts, keep the top-k, then renormalize so the kept weights
    # sum to one (e.g. probs [0.5, 0.3, 0.2] with top_k=2 -> weights 0.625, 0.375).
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
    routed_out = self.experts(hidden_states, routing_weights, router_indices)
    return routed_out


class NpuQwen3VLMoEFusedMoEKernel(MetaMoEKernel):
    type = KernelType.MOE
    device = DeviceType.NPU

    npu_experts_kernel = npu_experts_qwen3vlmoe_forward
    npu_moe_block_kernel = npu_moe_block_qwen3vlmoe_forward

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        if not is_torch_npu_available():
            return model

        npu_experts_pattern = re.compile("Qwen3VLMoeTextExperts", re.IGNORECASE)
        npu_moe_block_pattern = re.compile("Qwen3VLMoeTextSparseMoeBlock", re.IGNORECASE)
        for _, module in model.named_modules():
            # Rebind the forward of matching modules to the fused NPU implementations.
            if re.search(npu_experts_pattern, module.__class__.__name__):
                module.forward = types.MethodType(cls.npu_experts_kernel, module)
            elif re.search(npu_moe_block_pattern, module.__class__.__name__):
                module.forward = types.MethodType(cls.npu_moe_block_kernel, module)
        return model
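
For illustration, a minimal usage sketch (assumptions: the model id is taken from the config above, and `AutoModelForImageTextToText` is the loading path; whether `use_kernels: true` routes through exactly this call is not shown in the diff):

    from transformers import AutoModelForImageTextToText

    model = AutoModelForImageTextToText.from_pretrained(
        "Qwen/Qwen3-VL-30B-A3B-Instruct", trust_remote_code=True
    )
    # Patches every Qwen3VLMoeTextExperts / Qwen3VLMoeTextSparseMoeBlock module in
    # place; on hosts without torch_npu the model is returned unchanged.
    model = NpuQwen3VLMoEFusedMoEKernel.apply(model)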