[v1] add cuda fused moe kernel, implementing with triton (#10481)

This commit is contained in:
浮梦
2026-05-20 20:49:42 +08:00
committed by GitHub
parent 368c48968f
commit 2322bf1cc2
7 changed files with 856 additions and 10 deletions

View File

@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
self.running = RunningMoments(self.accelerator) self.running = RunningMoments(self.accelerator)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -0,0 +1,429 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pure-Triton Fused MoE Kernel for NVIDIA GPUs.
Replaces the HuggingFace per-expert Python loop with a fully fused Triton pipeline:
- Forward: scatter → grouped GEMM fc1 → SiLU·gate → apply routing → grouped GEMM fc2 → gather
- Backward: all dX via grouped GEMM, all dW via grouped GEMM (no Python loops)
Supported models: Mixtral, Qwen3-MoE, Qwen3.5-MoE.
"""
import logging
import types
import torch
import torch.nn.functional as F
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
from .triton_grouped_gemm import (
group_gemm_same_mn,
group_gemm_same_nk,
moe_gather,
moe_scatter,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Autograd Function: Full Triton MoE forward + backward
# ---------------------------------------------------------------------------
class TritonFusedMoeFunction(torch.autograd.Function):
"""Fused MoE expert computation using Triton grouped GEMMs.
Forward: scatter → fc1 (group GEMM) → SiLU·gate → weight → fc2 (group GEMM) → gather
Backward: all gradients computed via grouped GEMMs in single kernel launches.
"""
@staticmethod
def forward(
ctx,
num_experts,
gate_weights,
expert_index,
hidden_states,
fc1_weight,
fc2_weight,
):
"""Forward pass.
Args:
ctx: autograd context
num_experts: int
gate_weights: (num_tokens, top_k) routing weights
expert_index: (num_tokens, top_k) expert assignments
hidden_states: (num_tokens, hidden_dim)
fc1_weight: (E, 2*inter, hidden) merged gate+up weight
fc2_weight: (E, hidden, inter) down projection weight
"""
# Compute scatter index: maps (token, topk) → position in sorted buffer
scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)
# Token counts per expert and cumulative boundaries
splits = torch.zeros(num_experts, dtype=torch.int32, device=hidden_states.device)
flat_experts = expert_index.flatten().int()
splits.scatter_add_(0, flat_experts.long(), torch.ones_like(flat_experts))
cumsum_t = torch.cumsum(splits, dim=0)
# Scatter hidden states to sorted expert buffer
scatter_output = moe_scatter(hidden_states, scatter_index)
# FC1: grouped GEMM (scatter_output @ fc1_weight.T)
max_M = int(splits.max().item())
fc1_output = group_gemm_same_nk(
a=scatter_output,
b=fc1_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=True,
)
# SiLU gate activation
fc1_1_output, fc1_2_output = fc1_output.chunk(2, dim=-1)
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
fc1_activation = fc1_1_activation * fc1_2_output
# Apply routing weights before fc2 (mathematically equivalent to after)
reshaped_gate_weight = gate_weights.reshape(-1, 1)
scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
scattered_gate_weight[scatter_index.flatten().long()] = reshaped_gate_weight
fc1_weighted_output = fc1_activation * scattered_gate_weight
# FC2: grouped GEMM (fc1_weighted @ fc2_weight.T)
fc2_output = group_gemm_same_nk(
a=fc1_weighted_output,
b=fc2_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=True,
)
# Gather back to original token positions (sum over topk)
expert_output = moe_gather(fc2_output, scatter_index)
ctx.num_experts = num_experts
ctx.save_for_backward(
gate_weights,
fc1_weight,
fc2_weight,
hidden_states,
scatter_index,
scatter_output,
cumsum_t,
fc1_1_output,
fc1_2_output,
fc1_activation,
scattered_gate_weight,
fc1_weighted_output,
)
return expert_output
@staticmethod
def backward(ctx, grad_output):
(
gate_weights,
fc1_weight,
fc2_weight,
hidden_states,
scatter_index,
scatter_output,
cumsum_t,
fc1_1_output,
fc1_2_output,
fc1_activation,
scattered_gate_weight,
fc1_weighted_output,
) = ctx.saved_tensors
num_experts = ctx.num_experts
hidden_dim = grad_output.shape[-1]
grad_output = grad_output.reshape(-1, hidden_dim).contiguous()
# Recompute max_M from cumsum
splits = torch.zeros(num_experts, dtype=cumsum_t.dtype, device=cumsum_t.device)
splits[0] = cumsum_t[0]
splits[1:] = cumsum_t[1:] - cumsum_t[:-1]
max_M = int(splits.max().item())
# Step 1: Scatter grad_output to expert buffer
grad_fc2_output = moe_scatter(grad_output, scatter_index)
# Step 2: FC2 backward
# dX for fc2: grad_fc2_output @ fc2_weight (transpose_b=False since fc2 is (E, hidden, inter))
grad_fc1_weighted_output = group_gemm_same_nk(
a=grad_fc2_output,
b=fc2_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=False,
)
# dW for fc2: grad_fc2_output.T @ fc1_weighted_output
grad_fc2_weight = None
if fc2_weight.requires_grad:
grad_fc2_weight = torch.empty_like(fc2_weight)
group_gemm_same_mn(
a=grad_fc2_output,
b=fc1_weighted_output,
c=grad_fc2_weight,
cumsum_K=cumsum_t,
)
# Step 3: Routing weight backward
grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
grad_scattered_gate_weight = torch.sum(fc1_activation * grad_fc1_weighted_output, dim=-1)
grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten().long()]
grad_gate_weight = grad_gate_weight.reshape(gate_weights.shape)
# Recompute silu activation for backward
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
# Step 4: SiLU gate backward
grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
grad_fc1_2_output = fc1_1_activation * grad_fc1_activation
# SiLU backward: d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)
# Merge fc1 gradients back to (total_M, 2*inter)
grad_fc1_output = torch.cat([grad_fc1_1_output, grad_fc1_2_output], dim=-1)
# Step 5: FC1 backward
# dX for fc1: grad_fc1_output @ fc1_weight (transpose_b=False)
grad_scatter_output = group_gemm_same_nk(
a=grad_fc1_output,
b=fc1_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=False,
)
# dW for fc1: grad_fc1_output.T @ scatter_output
grad_fc1_weight = None
if fc1_weight.requires_grad:
grad_fc1_weight = torch.empty_like(fc1_weight)
group_gemm_same_mn(
a=grad_fc1_output,
b=scatter_output,
c=grad_fc1_weight,
cumsum_K=cumsum_t,
)
# Step 6: Gather gradients back to original positions
grad_hidden_states = moe_gather(grad_scatter_output, scatter_index)
grad_hidden_states = grad_hidden_states.reshape(hidden_states.shape)
return (
None, # num_experts
grad_gate_weight, # gate_weights
None, # expert_index
grad_hidden_states, # hidden_states
grad_fc1_weight, # fc1_weight
grad_fc2_weight, # fc2_weight
)
# ---------------------------------------------------------------------------
# Patched forward functions
# ---------------------------------------------------------------------------
def _triton_moe_experts_forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""Replacement forward for v5+ MoE expert modules with stacked 3D weights."""
return TritonFusedMoeFunction.apply(
self.num_experts,
top_k_weights.to(hidden_states.dtype),
top_k_index,
hidden_states,
self.gate_up_proj,
self.down_proj,
)
# ---------------------------------------------------------------------------
# Legacy (transformers < 5.0) support: weight stacking + SparseMoeBlock patch
# ---------------------------------------------------------------------------
class _StackedExpertWeights(torch.nn.Module):
"""Lightweight container holding stacked 3D expert weight tensors."""
def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, num_experts: int):
super().__init__()
self.gate_up_proj = torch.nn.Parameter(gate_up_proj)
self.down_proj = torch.nn.Parameter(down_proj)
self.num_experts = num_experts
def _stack_expert_weights(module: torch.nn.Module) -> None:
"""Replace nn.ModuleList of individual experts with stacked 3D parameter tensors."""
experts = module.experts
num_experts = len(experts)
gate_up_list = []
for expert in experts:
gate_w = expert.gate_proj.weight.data # (inter, hidden)
up_w = expert.up_proj.weight.data # (inter, hidden)
gate_up_list.append(torch.cat([gate_w, up_w], dim=0)) # (2*inter, hidden)
gate_up_proj = torch.stack(gate_up_list, dim=0) # (E, 2*inter, hidden)
down_proj = torch.stack([e.down_proj.weight.data for e in experts], dim=0) # (E, hidden, inter)
module.experts = _StackedExpertWeights(gate_up_proj, down_proj, num_experts)
logger.info(
f"cuda_fused_moe: Stacked {num_experts} expert weights into "
f"gate_up_proj {tuple(gate_up_proj.shape)}, down_proj {tuple(down_proj.shape)}"
)
def _triton_moe_sparse_block_forward(self, hidden_states: torch.Tensor):
"""Replacement forward for legacy SparseMoeBlock with inline routing."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = TritonFusedMoeFunction.apply(
self.num_experts,
routing_weights,
selected_experts,
hidden_states,
self.experts.gate_up_proj,
self.experts.down_proj,
)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
# ---------------------------------------------------------------------------
# Module mapping
# ---------------------------------------------------------------------------
_TRITON_MOE_MAPPING: dict[str, dict[str, object]] = {
"MixtralForCausalLM": {
"MixtralExperts": _triton_moe_experts_forward,
},
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": _triton_moe_experts_forward,
"Qwen3MoeSparseMoeBlock": _triton_moe_sparse_block_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": _triton_moe_experts_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": _triton_moe_experts_forward,
},
}
# ---------------------------------------------------------------------------
# Kernel registration
# ---------------------------------------------------------------------------
@register_kernel
class CudaFusedMoEKernel(BaseKernel):
"""Pure-Triton fused MoE kernel for NVIDIA CUDA GPUs.
Replaces HuggingFace per-expert Python loops with a fully fused Triton pipeline:
- Forward: scatter + grouped GEMMs + gather (single kernel per GEMM)
- Backward: all dX and dW via grouped GEMMs (no Python loops)
Requires: CUDA GPU + Triton
"""
_kernel_id = "cuda_fused_moe"
_device = DeviceType.CUDA
@classmethod
def check_deps(cls) -> bool:
if not super().check_deps():
return False
try:
import triton # noqa: F401
return True
except ImportError:
logger.info("cuda_fused_moe: Triton not available, kernel disabled.")
return False
@classmethod
def apply(cls, **kwargs) -> HFModel:
model = kwargs.get("model")
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
logger.warning("cuda_fused_moe: Dependencies not met. Skipping kernel application.")
return model
archs = getattr(model.config, "architectures", None) or []
target_mapping = None
for arch in archs:
if arch in _TRITON_MOE_MAPPING:
target_mapping = _TRITON_MOE_MAPPING[arch]
break
if target_mapping is None:
logger.info(
f"cuda_fused_moe: Model architecture {archs} not supported. "
f"Supported: {list(_TRITON_MOE_MAPPING.keys())}"
)
return model
patched_count = 0
for module in model.modules():
class_name = module.__class__.__name__
if class_name not in target_mapping:
continue
target_fn = target_mapping[class_name]
if hasattr(module, "gate_up_proj") and hasattr(module, "down_proj"):
module.forward = types.MethodType(target_fn, module)
patched_count += 1
elif (
hasattr(module, "experts")
and isinstance(module.experts, torch.nn.ModuleList)
and hasattr(module, "gate")
):
_stack_expert_weights(module)
module.forward = types.MethodType(target_fn, module)
patched_count += 1
if patched_count > 0:
logger.info(f"cuda_fused_moe: Patched {patched_count} MoE expert modules with pure Triton pipeline.")
else:
logger.warning("cuda_fused_moe: No MoE expert modules found to patch.")
return model

View File

@@ -0,0 +1,417 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Pure-Triton grouped GEMM and MoE scatter/gather kernels.
# Design adapted from VeOmni (ByteDance-Seed/VeOmni) group_gemm kernels.
"""Pure-Triton MoE kernels: grouped GEMM, scatter, and gather.
Provides four kernel types for fused MoE forward+backward without Python loops:
- group_gemm_same_nk: Variable-M grouped GEMM (forward & backward dX)
- group_gemm_same_mn: Variable-K grouped GEMM (backward dW)
- moe_scatter: Token dispatch to sorted expert buffers
- moe_gather: Token reduction from expert buffers
"""
import torch
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Triton helper: grouped tile indexing with L2 cache-friendly swizzle
# ---------------------------------------------------------------------------
@triton.jit
def _get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
return pid_m, pid_n
# ---------------------------------------------------------------------------
# group_gemm_same_nk: All experts share same N, K; variable M per expert
# Used for: forward (x @ W.T) and backward dX (grad @ W)
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
],
key=["N", "K"],
)
@triton.jit
def _group_gemm_same_nk_kernel(
a_ptr,
b_ptr,
c_ptr,
cumsum_M,
max_M,
G: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
TRANSPOSE_B: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP: tl.constexpr,
):
pid_m, pid_n = _get_pid_mn(tl.program_id(0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
gid = tl.program_id(1).to(tl.int64)
gtid_start = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0).to(tl.int64)
gtid_end = tl.load(cumsum_M + gid).to(tl.int64)
m_size = gtid_end - gtid_start
if pid_m * BLOCK_M >= m_size:
return
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# a is (total_M, K) row-major, offset by expert start
a_base = a_ptr + gtid_start * K
# b is (G, N, K) if TRANSPOSE_B else (G, K, N)
b_base = b_ptr + gid * K * N
# c is (total_M, N) row-major
c_base = c_ptr + gtid_start * N
if TRANSPOSE_B:
# b layout: (G, N, K), we compute a @ b.T = a(M,K) @ b(N,K).T -> (M,N)
a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_base + offs_n[:, None] * K + offs_k[None, :]
else:
# b layout: (G, K, N), we compute a @ b = a(M,K) @ b(K,N) -> (M,N)
a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_base + offs_k[:, None] * N + offs_n[None, :]
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
k_offs = k_start + offs_k
k_mask = k_offs < K
a_block = tl.load(a_ptrs, mask=(offs_m[:, None] < m_size) & k_mask[None, :], other=0.0)
if TRANSPOSE_B:
b_block = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
acc += tl.dot(a_block, tl.trans(b_block))
else:
b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a_block, b_block)
if TRANSPOSE_B:
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K
else:
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K * N
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
c_mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def group_gemm_same_nk(
a: torch.Tensor,
b: torch.Tensor,
cumsum_M: torch.Tensor,
max_M: int,
transpose_b: bool = False,
) -> torch.Tensor:
"""Grouped GEMM where all groups share same N, K dimensions but variable M.
Args:
a: (total_M, K) input tensor, rows grouped by expert
b: (G, N, K) if transpose_b else (G, K, N) weight tensor
cumsum_M: (G,) cumulative token counts per expert
max_M: maximum tokens any single expert has
transpose_b: if True, compute a @ b.T; else compute a @ b
Returns:
c: (total_M, N) output tensor
"""
if transpose_b:
G, N, K = b.shape
else:
G, K, N = b.shape
c = torch.empty((a.shape[0], N), dtype=a.dtype, device=a.device)
_group_gemm_same_nk_kernel[
(lambda meta: (triton.cdiv(max_M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))
](
a_ptr=a,
b_ptr=b,
c_ptr=c,
cumsum_M=cumsum_M,
max_M=max_M,
G=G,
N=N,
K=K,
TRANSPOSE_B=transpose_b,
)
return c
# ---------------------------------------------------------------------------
# group_gemm_same_mn: All experts share same M, N (weight dims); variable K
# Used for: backward dW (grad.T @ input)
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
],
key=["M", "N"],
)
@triton.jit
def _group_gemm_same_mn_kernel(
a_ptr,
b_ptr,
c_ptr,
cumsum_K,
G: tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP: tl.constexpr,
):
pid_m, pid_n = _get_pid_mn(tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP)
gid = tl.program_id(1).to(tl.int64)
gtid_start = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0).to(tl.int64)
gtid_end = tl.load(cumsum_K + gid).to(tl.int64)
k_size = gtid_end - gtid_start
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# c is (G, M, N)
c_base = c_ptr + gid * M * N
if k_size == 0:
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.dtype.element_ty), mask=c_mask)
return
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
offs_k = tl.arange(0, BLOCK_K)
# a is (total_K, M), compute a.T @ b -> (M, N)
# b is (total_K, N)
a_base = a_ptr + gtid_start * M
b_base = b_ptr + gtid_start * N
for k_start in range(0, k_size, BLOCK_K):
k_offs = k_start + offs_k
k_mask = k_offs < k_size
a_ptrs = a_base + k_offs[:, None] * M + offs_m[None, :]
a_block_t = tl.trans(tl.load(a_ptrs, mask=k_mask[:, None] & (offs_m[None, :] < M), other=0.0))
# Load b block: (BLOCK_K, BLOCK_N)
b_ptrs = b_base + k_offs[:, None] * N + offs_n[None, :]
b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a_block_t, b_block)
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def group_gemm_same_mn(
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
cumsum_K: torch.Tensor,
) -> None:
"""Grouped GEMM where all groups produce same (M, N) output; variable K reduction.
Computes: c[g] = a[s:e].T @ b[s:e] for each group g,
where s, e are defined by cumsum_K boundaries.
Args:
a: (total_K, M) input tensor grouped by expert
b: (total_K, N) input tensor grouped by expert
c: (G, M, N) output tensor (pre-allocated)
cumsum_K: (G,) cumulative token counts per expert
"""
G, M, N = c.shape
_group_gemm_same_mn_kernel[(lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))](
a_ptr=a,
b_ptr=b,
c_ptr=c,
cumsum_K=cumsum_K,
G=G,
M=M,
N=N,
)
# ---------------------------------------------------------------------------
# moe_scatter: Dispatch tokens to sorted expert buffer positions
# ---------------------------------------------------------------------------
@triton.jit
def _moe_scatter_kernel(
x_ptr,
out_ptr,
index_ptr,
M,
N: tl.constexpr,
TOPK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Scatter: for each token i, copy x[i] to out[index[i, k]] for k in 0..topk-1."""
pid_m = tl.program_id(0).to(tl.int64)
pid_n = tl.program_id(1)
if pid_m >= M:
return
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
# Load input row
x_ptrs = x_ptr + pid_m * N + offs_n
x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0)
# Store to each topk destination
for k in tl.static_range(TOPK):
dst_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
out_ptrs = out_ptr + dst_idx * N + offs_n
tl.store(out_ptrs, x_vals, mask=n_mask)
def moe_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
"""Scatter tokens to sorted expert buffer.
For each token i and topk slot k, copies x[i] to output[index[i, k]].
Args:
x: (M, N) input hidden states
index: (M, topk) scatter indices
Returns:
out: (M * topk, N) scattered output
"""
M, N = x.shape
topk = index.shape[1]
out = torch.empty(M * topk, N, dtype=x.dtype, device=x.device)
BLOCK_N = min(triton.next_power_of_2(N), 1024)
grid = (M, triton.cdiv(N, BLOCK_N))
_moe_scatter_kernel[grid](
x_ptr=x,
out_ptr=out,
index_ptr=index,
M=M,
N=N,
TOPK=topk,
BLOCK_N=BLOCK_N,
)
return out
# ---------------------------------------------------------------------------
# moe_gather: Reduce expert outputs back to token positions (sum over topk)
# ---------------------------------------------------------------------------
@triton.jit
def _moe_gather_kernel(
x_ptr,
out_ptr,
index_ptr,
M,
N: tl.constexpr,
TOPK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Gather: for each token i, out[i] = sum_k(x[index[i, k]]) over topk."""
pid_m = tl.program_id(0).to(tl.int64)
pid_n = tl.program_id(1)
if pid_m >= M:
return
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
acc = tl.zeros([BLOCK_N], dtype=tl.float32)
for k in tl.static_range(TOPK):
src_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
x_ptrs = x_ptr + src_idx * N + offs_n
x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0).to(tl.float32)
acc += x_vals
out_ptrs = out_ptr + pid_m * N + offs_n
tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=n_mask)
def moe_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
"""Gather and reduce expert outputs back to original token positions.
For each token i, sums x[index[i, k]] over all topk slots.
Args:
x: (M * topk, N) expert outputs in sorted buffer
index: (M, topk) scatter indices (same as used in moe_scatter)
Returns:
out: (M, N) gathered output
"""
M, topk = index.shape
N = x.shape[1]
out = torch.empty(M, N, dtype=x.dtype, device=x.device)
BLOCK_N = min(triton.next_power_of_2(N), 1024)
grid = (M, triton.cdiv(N, BLOCK_N))
_moe_gather_kernel[grid](
x_ptr=x,
out_ptr=out,
index_ptr=index,
M=M,
N=N,
TOPK=topk,
BLOCK_N=BLOCK_N,
)
return out