diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index f0ecdba2c..5b800a72b 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
             self.running = RunningMoments(self.accelerator)
 
     @override
-    def create_optimizer(self) -> "torch.optim.Optimizer":
+    def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-        return super().create_optimizer()
+        return super().create_optimizer(*args, **kwargs)
 
     @override
     def create_scheduler(
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 1d679821f..cb9b73b39 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
             self.add_callback(BAdamCallback)
 
     @override
-    def create_optimizer(self) -> "torch.optim.Optimizer":
+    def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-        return super().create_optimizer()
+        return super().create_optimizer(*args, **kwargs)
 
     @override
     def create_scheduler(
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 0a4bef3dd..ffb040fdd 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
             verify_fp8_status(self.accelerator, training_args)
 
     @override
-    def create_optimizer(self) -> "torch.optim.Optimizer":
+    def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-        return super().create_optimizer()
+        return super().create_optimizer(*args, **kwargs)
 
     @override
     def create_scheduler(
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index f0384681b..7ee1ab7d2 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
             self.add_callback(BAdamCallback)
 
     @override
-    def create_optimizer(self) -> "torch.optim.Optimizer":
+    def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-        return super().create_optimizer()
+        return super().create_optimizer(*args, **kwargs)
 
     @override
     def create_scheduler(
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index dfed95a5c..993cba839 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             verify_fp8_status(self.accelerator, training_args)
 
     @override
-    def create_optimizer(self) -> "torch.optim.Optimizer":
+    def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
-        return super().create_optimizer()
+        return super().create_optimizer(*args, **kwargs)
 
     @override
     def create_scheduler(
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/cuda_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/cuda_fused_moe.py
new file mode 100644
index 000000000..1a2cc4b69
--- /dev/null
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/cuda_fused_moe.py
@@ -0,0 +1,463 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pure-Triton Fused MoE Kernel for NVIDIA GPUs.
+
+Replaces the HuggingFace per-expert Python loop with a fully fused Triton pipeline:
+- Forward: scatter → grouped GEMM fc1 → SiLU·gate → apply routing → grouped GEMM fc2 → gather
+- Backward: all dX via grouped GEMM, all dW via grouped GEMM (no Python loops)
+
+Supported models: Mixtral, Qwen3-MoE, Qwen3.5-MoE.
+"""
+
+import logging
+import types
+
+import torch
+import torch.nn.functional as F
+
+from ......accelerator.helper import DeviceType
+from ......utils.types import HFModel
+from ...base import BaseKernel
+from ...registry import register_kernel
+from .triton_grouped_gemm import (
+    group_gemm_same_mn,
+    group_gemm_same_nk,
+    moe_gather,
+    moe_scatter,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Autograd Function: Full Triton MoE forward + backward
+# ---------------------------------------------------------------------------
+
+
+class TritonFusedMoeFunction(torch.autograd.Function):
+    """Fused MoE expert computation using Triton grouped GEMMs.
+
+    Forward: scatter → fc1 (group GEMM) → SiLU·gate → weight → fc2 (group GEMM) → gather
+    Backward: all gradients computed via grouped GEMMs in single kernel launches.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        num_experts,
+        gate_weights,
+        expert_index,
+        hidden_states,
+        fc1_weight,
+        fc2_weight,
+    ):
+        """Forward pass.
+
+        Args:
+            ctx: autograd context
+            num_experts: int
+            gate_weights: (num_tokens, top_k) routing weights
+            expert_index: (num_tokens, top_k) expert assignments
+            hidden_states: (num_tokens, hidden_dim)
+            fc1_weight: (E, 2*inter, hidden) merged gate+up weight
+            fc2_weight: (E, hidden, inter) down projection weight
+        """
+        # Compute scatter index: maps (token, topk) → position in sorted buffer
+        scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)
+
+        # Token counts per expert and cumulative boundaries
+        splits = torch.zeros(num_experts, dtype=torch.int32, device=hidden_states.device)
+        flat_experts = expert_index.flatten().int()
+        splits.scatter_add_(0, flat_experts.long(), torch.ones_like(flat_experts))
+        cumsum_t = torch.cumsum(splits, dim=0)
+
+        # Scatter hidden states to sorted expert buffer
+        scatter_output = moe_scatter(hidden_states, scatter_index)
+
+        # FC1: grouped GEMM (scatter_output @ fc1_weight.T)
+        max_M = int(splits.max().item())
+        fc1_output = group_gemm_same_nk(
+            a=scatter_output,
+            b=fc1_weight,
+            cumsum_M=cumsum_t,
+            max_M=max_M,
+            transpose_b=True,
+        )
+
+        # SiLU gate activation
+        fc1_1_output, fc1_2_output = fc1_output.chunk(2, dim=-1)
+        fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
+        fc1_activation = fc1_1_activation * fc1_2_output
+
+        # Apply routing weights before fc2 (mathematically equivalent to after)
+        reshaped_gate_weight = gate_weights.reshape(-1, 1)
+        scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
+        scattered_gate_weight[scatter_index.flatten().long()] = reshaped_gate_weight
+        fc1_weighted_output = fc1_activation * scattered_gate_weight
+
+        # FC2: grouped GEMM (fc1_weighted @ fc2_weight.T)
+        fc2_output = group_gemm_same_nk(
+            a=fc1_weighted_output,
+            b=fc2_weight,
+            cumsum_M=cumsum_t,
+            max_M=max_M,
+            transpose_b=True,
+        )
+
+        # Gather back to original token positions (sum over topk)
+        expert_output = moe_gather(fc2_output, scatter_index)
+
+        ctx.num_experts = num_experts
+        ctx.max_M = max_M  # plain int, reused by backward to avoid a second .item() sync
+        ctx.save_for_backward(
+            gate_weights,
+            fc1_weight,
+            fc2_weight,
+            hidden_states,
+            scatter_index,
+            scatter_output,
+            cumsum_t,
+            fc1_1_output,
+            fc1_2_output,
+            fc1_activation,
+            scattered_gate_weight,
+            fc1_weighted_output,
+        )
+
+        return expert_output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Backward pass.
+
+        Args:
+            ctx: autograd context populated by ``forward``
+            grad_output: gradient w.r.t. the forward output; flattened to (-1, hidden_dim)
+
+        Returns:
+            One entry per ``forward`` input: ``None`` for num_experts and expert_index,
+            gradients for gate_weights and hidden_states, and gradients for fc1_weight /
+            fc2_weight (``None`` when the weight does not require grad).
+        """
+        (
+            gate_weights,
+            fc1_weight,
+            fc2_weight,
+            hidden_states,
+            scatter_index,
+            scatter_output,
+            cumsum_t,
+            fc1_1_output,
+            fc1_2_output,
+            fc1_activation,
+            scattered_gate_weight,
+            fc1_weighted_output,
+        ) = ctx.saved_tensors
+        hidden_dim = grad_output.shape[-1]
+        grad_output = grad_output.reshape(-1, hidden_dim).contiguous()
+
+        # Reuse the per-expert maximum cached by forward; rebuilding it from cumsum_t would
+        # need another .item() call, i.e. an extra host/device synchronization.
+        max_M = ctx.max_M
+
+        # Step 1: Scatter grad_output to expert buffer
+        grad_fc2_output = moe_scatter(grad_output, scatter_index)
+
+        # Step 2: FC2 backward
+        # dX for fc2: grad_fc2_output @ fc2_weight (transpose_b=False since fc2 is (E, hidden, inter))
+        grad_fc1_weighted_output = group_gemm_same_nk(
+            a=grad_fc2_output,
+            b=fc2_weight,
+            cumsum_M=cumsum_t,
+            max_M=max_M,
+            transpose_b=False,
+        )
+
+        # dW for fc2: grad_fc2_output.T @ fc1_weighted_output
+        grad_fc2_weight = None
+        if fc2_weight.requires_grad:
+            # The kernel writes c as a dense row-major (E, M, N) buffer, so force a
+            # contiguous layout instead of inheriting fc2_weight's strides.
+            grad_fc2_weight = torch.empty_like(fc2_weight, memory_format=torch.contiguous_format)
+            group_gemm_same_mn(
+                a=grad_fc2_output,
+                b=fc1_weighted_output,
+                c=grad_fc2_weight,
+                cumsum_K=cumsum_t,
+            )
+
+        # Step 3: Routing weight backward
+        grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
+        grad_scattered_gate_weight = torch.sum(fc1_activation * grad_fc1_weighted_output, dim=-1)
+        grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten().long()]
+        grad_gate_weight = grad_gate_weight.reshape(gate_weights.shape)
+
+        # Recompute silu activation for backward
+        fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
+
+        # Step 4: SiLU gate backward
+        grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
+        grad_fc1_2_output = fc1_1_activation * grad_fc1_activation
+
+        # SiLU backward: d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
+        grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)
+
+        # Merge fc1 gradients back to (total_M, 2*inter)
+        grad_fc1_output = torch.cat([grad_fc1_1_output, grad_fc1_2_output], dim=-1)
+
+        # Step 5: FC1 backward
+        # dX for fc1: grad_fc1_output @ fc1_weight (transpose_b=False)
+        grad_scatter_output = group_gemm_same_nk(
+            a=grad_fc1_output,
+            b=fc1_weight,
+            cumsum_M=cumsum_t,
+            max_M=max_M,
+            transpose_b=False,
+        )
+
+        # dW for fc1: grad_fc1_output.T @ scatter_output
+        grad_fc1_weight = None
+        if fc1_weight.requires_grad:
+            # Contiguous for the same reason as grad_fc2_weight above.
+            grad_fc1_weight = torch.empty_like(fc1_weight, memory_format=torch.contiguous_format)
+            group_gemm_same_mn(
+                a=grad_fc1_output,
+                b=scatter_output,
+                c=grad_fc1_weight,
+                cumsum_K=cumsum_t,
+            )
+
+        # Step 6: Gather gradients back to original positions
+        grad_hidden_states = moe_gather(grad_scatter_output, scatter_index)
+        grad_hidden_states = grad_hidden_states.reshape(hidden_states.shape)
+
+        return (
+            None,  # num_experts
+            grad_gate_weight,  # gate_weights
+            None,  # expert_index
+            grad_hidden_states,  # hidden_states
+            grad_fc1_weight,  # fc1_weight
+            grad_fc2_weight,  # fc2_weight
+        )
+
+
+# ---------------------------------------------------------------------------
+# Patched forward functions
+# ---------------------------------------------------------------------------
+
+
+def _triton_moe_experts_forward(
+    self,
+    hidden_states: torch.Tensor,
+    top_k_index: torch.Tensor,
+    top_k_weights: torch.Tensor,
+) -> torch.Tensor:
+    """Replacement forward for v5+ MoE expert modules with stacked 3D weights."""
+    return TritonFusedMoeFunction.apply(
+        self.num_experts,
+        top_k_weights.to(hidden_states.dtype),
+        top_k_index,
+        hidden_states,
+        self.gate_up_proj,
+        self.down_proj,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Legacy (transformers < 5.0) support: weight stacking + SparseMoeBlock patch
+# ---------------------------------------------------------------------------
+
+
+class _StackedExpertWeights(torch.nn.Module):
+    """Lightweight container holding stacked 3D expert weight tensors."""
+
+    def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, num_experts: int):
+        super().__init__()
+        self.gate_up_proj = torch.nn.Parameter(gate_up_proj)
+        self.down_proj = torch.nn.Parameter(down_proj)
+        self.num_experts = num_experts
+
+
+def _stack_expert_weights(module: torch.nn.Module) -> None:
+    """Replace nn.ModuleList of individual experts with stacked 3D parameter tensors.
+
+    ``module.experts`` is swapped in place for a ``_StackedExpertWeights`` container, so
+    the per-expert ``gate_proj`` / ``up_proj`` / ``down_proj`` submodules are no longer
+    reachable afterwards. The stacked parameters stay trainable only if at least one of
+    the source weights was trainable, so a frozen model is not silently unfrozen.
+    """
+    experts = module.experts
+    num_experts = len(experts)
+
+    gate_up_list = []
+    for expert in experts:
+        gate_w = expert.gate_proj.weight.data  # (inter, hidden)
+        up_w = expert.up_proj.weight.data  # (inter, hidden)
+        gate_up_list.append(torch.cat([gate_w, up_w], dim=0))  # (2*inter, hidden)
+    gate_up_proj = torch.stack(gate_up_list, dim=0)  # (E, 2*inter, hidden)
+
+    down_proj = torch.stack([e.down_proj.weight.data for e in experts], dim=0)  # (E, hidden, inter)
+
+    # nn.Parameter defaults to requires_grad=True; carry over the original setting instead.
+    gate_up_trainable = any(e.gate_proj.weight.requires_grad or e.up_proj.weight.requires_grad for e in experts)
+    down_trainable = any(e.down_proj.weight.requires_grad for e in experts)
+
+    stacked = _StackedExpertWeights(gate_up_proj, down_proj, num_experts)
+    stacked.gate_up_proj.requires_grad_(gate_up_trainable)
+    stacked.down_proj.requires_grad_(down_trainable)
+    module.experts = stacked
+    logger.info(
+        f"cuda_fused_moe: Stacked {num_experts} expert weights into "
+        f"gate_up_proj {tuple(gate_up_proj.shape)}, down_proj {tuple(down_proj.shape)}"
+    )
+
+
+def _triton_moe_sparse_block_forward(self, hidden_states: torch.Tensor):
+    """Replacement forward for legacy SparseMoeBlock with inline routing."""
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+
+    router_logits = self.gate(hidden_states)
+    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+    if self.norm_topk_prob:
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.to(hidden_states.dtype)
+
+    final_hidden_states = TritonFusedMoeFunction.apply(
+        self.num_experts,
+        routing_weights,
+        selected_experts,
+        hidden_states,
+        self.experts.gate_up_proj,
+        self.experts.down_proj,
+    )
+
+    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+    return final_hidden_states, router_logits
+
+
+# ---------------------------------------------------------------------------
+# Module mapping
+# ---------------------------------------------------------------------------
+
+_TRITON_MOE_MAPPING: dict[str, dict[str, object]] = {
+    "MixtralForCausalLM": {
+        "MixtralExperts": _triton_moe_experts_forward,
+    },
+    "Qwen3MoeForCausalLM": {
+        "Qwen3MoeExperts": _triton_moe_experts_forward,
+        "Qwen3MoeSparseMoeBlock": _triton_moe_sparse_block_forward,
+    },
+    "Qwen3_5MoeForCausalLM": {
+        "Qwen3_5MoeExperts": _triton_moe_experts_forward,
+    },
+    "Qwen3_5MoeForConditionalGeneration": {
+        "Qwen3_5MoeExperts": _triton_moe_experts_forward,
+    },
+}
+
+
+# ---------------------------------------------------------------------------
+# Kernel registration
+# ---------------------------------------------------------------------------
+
+
+@register_kernel
+class CudaFusedMoEKernel(BaseKernel):
+    """Pure-Triton fused MoE kernel for NVIDIA CUDA GPUs.
+
+    Replaces HuggingFace per-expert Python loops with a fully fused Triton pipeline:
+    - Forward: scatter + grouped GEMMs + gather (single kernel per GEMM)
+    - Backward: all dX and dW via grouped GEMMs (no Python loops)
+
+    Requires: CUDA GPU + Triton
+    """
+
+    _kernel_id = "cuda_fused_moe"
+    _device = DeviceType.CUDA
+
+    @classmethod
+    def check_deps(cls) -> bool:
+        """Return True only when the base-class checks pass and Triton can be imported."""
+        if not super().check_deps():
+            return False
+        try:
+            import triton  # noqa: F401
+
+            return True
+        except ImportError:
+            logger.info("cuda_fused_moe: Triton not available, kernel disabled.")
+            return False
+
+    @classmethod
+    def apply(cls, **kwargs) -> HFModel:
+        """Patch supported MoE modules of ``model`` to run through the Triton pipeline.
+
+        The model is returned unchanged when dependencies are missing or none of
+        ``model.config.architectures`` appears in ``_TRITON_MOE_MAPPING``.
+
+        Raises:
+            ValueError: if the ``model`` keyword argument is missing or ``None``.
+        """
+        model = kwargs.get("model")
+        if model is None:
+            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+
+        if not cls.check_deps():
+            logger.warning("cuda_fused_moe: Dependencies not met. Skipping kernel application.")
+            return model
+
+        archs = getattr(model.config, "architectures", None) or []
+        target_mapping = None
+        for arch in archs:
+            if arch in _TRITON_MOE_MAPPING:
+                target_mapping = _TRITON_MOE_MAPPING[arch]
+                break
+
+        if target_mapping is None:
+            logger.info(
+                f"cuda_fused_moe: Model architecture {archs} not supported. "
+                f"Supported: {list(_TRITON_MOE_MAPPING.keys())}"
+            )
+            return model
+
+        patched_count = 0
+        for module in model.modules():
+            class_name = module.__class__.__name__
+            if class_name not in target_mapping:
+                continue
+
+            target_fn = target_mapping[class_name]
+
+            if hasattr(module, "gate_up_proj") and hasattr(module, "down_proj"):
+                module.forward = types.MethodType(target_fn, module)
+                patched_count += 1
+            elif (
+                hasattr(module, "experts")
+                and isinstance(module.experts, torch.nn.ModuleList)
+                and hasattr(module, "gate")
+            ):
+                _stack_expert_weights(module)
+                module.forward = types.MethodType(target_fn, module)
+                patched_count += 1
+
+        if patched_count > 0:
+            logger.info(f"cuda_fused_moe: Patched {patched_count} MoE expert modules with pure Triton pipeline.")
+        else:
+            logger.warning("cuda_fused_moe: No MoE expert modules found to patch.")
+
+        return model
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/triton_grouped_gemm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/triton_grouped_gemm.py
new file mode 100644
index 000000000..c75b8f5ce
--- /dev/null
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/triton_grouped_gemm.py
@@ -0,0 +1,440 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Pure-Triton grouped GEMM and MoE scatter/gather kernels.
+# Design adapted from VeOmni (ByteDance-Seed/VeOmni) group_gemm kernels.
+
+"""Pure-Triton MoE kernels: grouped GEMM, scatter, and gather.
+
+Provides four kernel types for fused MoE forward+backward without Python loops:
+- group_gemm_same_nk: Variable-M grouped GEMM (forward & backward dX)
+- group_gemm_same_mn: Variable-K grouped GEMM (backward dW)
+- moe_scatter: Token dispatch to sorted expert buffers
+- moe_gather: Token reduction from expert buffers
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+
+# ---------------------------------------------------------------------------
+# Triton helper: grouped tile indexing with L2 cache-friendly swizzle
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
+    """Map a linear program id to (pid_m, pid_n) tile coordinates, grouping GROUP_SIZE tile rows together."""
+    num_pid_m = tl.cdiv(M, BLOCK_M)
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    num_pid_in_group = GROUP_SIZE * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+    return pid_m, pid_n
+
+
+# ---------------------------------------------------------------------------
+# group_gemm_same_nk: All experts share same N, K; variable M per expert
+# Used for: forward (x @ W.T) and backward dX (grad @ W)
+# ---------------------------------------------------------------------------
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
+    ],
+    key=["N", "K"],
+)
+@triton.jit
+def _group_gemm_same_nk_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    cumsum_M,
+    max_M,
+    G: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    TRANSPOSE_B: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    GROUP: tl.constexpr,
+):
+    """Compute one (BLOCK_M, BLOCK_N) output tile of a @ b (or a @ b.T) for the expert at program_id(1)."""
+    pid_m, pid_n = _get_pid_mn(tl.program_id(0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
+    gid = tl.program_id(1).to(tl.int64)
+
+    gtid_start = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0).to(tl.int64)
+    gtid_end = tl.load(cumsum_M + gid).to(tl.int64)
+    m_size = gtid_end - gtid_start
+
+    if pid_m * BLOCK_M >= m_size:
+        return
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, BLOCK_K)
+
+    # a is (total_M, K) row-major, offset by expert start
+    a_base = a_ptr + gtid_start * K
+    # b is (G, N, K) if TRANSPOSE_B else (G, K, N)
+    b_base = b_ptr + gid * K * N
+    # c is (total_M, N) row-major
+    c_base = c_ptr + gtid_start * N
+
+    if TRANSPOSE_B:
+        # b layout: (G, N, K), we compute a @ b.T = a(M,K) @ b(N,K).T -> (M,N)
+        a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
+        b_ptrs = b_base + offs_n[:, None] * K + offs_k[None, :]
+    else:
+        # b layout: (G, K, N), we compute a @ b = a(M,K) @ b(K,N) -> (M,N)
+        a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
+        b_ptrs = b_base + offs_k[:, None] * N + offs_n[None, :]
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k_start in range(0, K, BLOCK_K):
+        k_offs = k_start + offs_k
+        k_mask = k_offs < K
+
+        a_block = tl.load(a_ptrs, mask=(offs_m[:, None] < m_size) & k_mask[None, :], other=0.0)
+
+        if TRANSPOSE_B:
+            b_block = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
+            acc += tl.dot(a_block, tl.trans(b_block))
+        else:
+            b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
+            acc += tl.dot(a_block, b_block)
+
+        if TRANSPOSE_B:
+            a_ptrs += BLOCK_K
+            b_ptrs += BLOCK_K
+        else:
+            a_ptrs += BLOCK_K
+            b_ptrs += BLOCK_K * N
+
+    c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
+    c_mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
+
+
+def group_gemm_same_nk(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    cumsum_M: torch.Tensor,
+    max_M: int,
+    transpose_b: bool = False,
+) -> torch.Tensor:
+    """Grouped GEMM where all groups share same N, K dimensions but variable M.
+
+    Args:
+        a: (total_M, K) input tensor, rows grouped by expert
+        b: (G, N, K) if transpose_b else (G, K, N) weight tensor
+        cumsum_M: (G,) cumulative token counts per expert
+        max_M: maximum tokens any single expert has
+        transpose_b: if True, compute a @ b.T; else compute a @ b
+
+    Returns:
+        c: (total_M, N) output tensor
+    """
+    # The kernel addresses a and b as dense row-major buffers, so normalize the layout
+    # first; .contiguous() returns the tensor itself when it is already contiguous.
+    a = a.contiguous()
+    b = b.contiguous()
+
+    if transpose_b:
+        G, N, K = b.shape
+    else:
+        G, K, N = b.shape
+
+    c = torch.empty((a.shape[0], N), dtype=a.dtype, device=a.device)
+
+    _group_gemm_same_nk_kernel[
+        (lambda meta: (triton.cdiv(max_M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))
+    ](
+        a_ptr=a,
+        b_ptr=b,
+        c_ptr=c,
+        cumsum_M=cumsum_M,
+        max_M=max_M,
+        G=G,
+        N=N,
+        K=K,
+        TRANSPOSE_B=transpose_b,
+    )
+    return c
+
+
+# ---------------------------------------------------------------------------
+# group_gemm_same_mn: All experts share same M, N (weight dims); variable K
+# Used for: backward dW (grad.T @ input)
+# ---------------------------------------------------------------------------
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
+    ],
+    key=["M", "N"],
+)
+@triton.jit
+def _group_gemm_same_mn_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    cumsum_K,
+    G: tl.constexpr,
+    M: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    GROUP: tl.constexpr,
+):
+    """Compute one (BLOCK_M, BLOCK_N) tile of c[g] = a[s:e].T @ b[s:e] for the expert g = program_id(1)."""
+    pid_m, pid_n = _get_pid_mn(tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP)
+    gid = tl.program_id(1).to(tl.int64)
+
+    gtid_start = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0).to(tl.int64)
+    gtid_end = tl.load(cumsum_K + gid).to(tl.int64)
+    k_size = gtid_end - gtid_start
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # c is (G, M, N)
+    c_base = c_ptr + gid * M * N
+
+    if k_size == 0:
+        c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
+        c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+        tl.store(c_ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.dtype.element_ty), mask=c_mask)
+        return
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    offs_k = tl.arange(0, BLOCK_K)
+
+    # a is (total_K, M), compute a.T @ b -> (M, N)
+    # b is (total_K, N)
+    a_base = a_ptr + gtid_start * M
+    b_base = b_ptr + gtid_start * N
+
+    for k_start in range(0, k_size, BLOCK_K):
+        k_offs = k_start + offs_k
+        k_mask = k_offs < k_size
+
+        a_ptrs = a_base + k_offs[:, None] * M + offs_m[None, :]
+        a_block_t = tl.trans(tl.load(a_ptrs, mask=k_mask[:, None] & (offs_m[None, :] < M), other=0.0))
+
+        # Load b block: (BLOCK_K, BLOCK_N)
+        b_ptrs = b_base + k_offs[:, None] * N + offs_n[None, :]
+        b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
+
+        acc += tl.dot(a_block_t, b_block)
+
+    c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
+    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
+
+
+def group_gemm_same_mn(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    c: torch.Tensor,
+    cumsum_K: torch.Tensor,
+) -> None:
+    """Grouped GEMM where all groups produce same (M, N) output; variable K reduction.
+
+    Computes: c[g] = a[s:e].T @ b[s:e] for each group g,
+    where s, e are defined by cumsum_K boundaries.
+
+    Args:
+        a: (total_K, M) input tensor grouped by expert
+        b: (total_K, N) input tensor grouped by expert
+        c: (G, M, N) output tensor (pre-allocated)
+        cumsum_K: (G,) cumulative token counts per expert
+
+    Raises:
+        ValueError: if c is not contiguous (the kernel writes it as a dense buffer).
+    """
+    # Results are written into c in place, so a copy cannot fix up a strided layout.
+    if not c.is_contiguous():
+        raise ValueError("group_gemm_same_mn requires a contiguous output tensor `c`.")
+    a = a.contiguous()
+    b = b.contiguous()
+
+    G, M, N = c.shape
+
+    _group_gemm_same_mn_kernel[(lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))](
+        a_ptr=a,
+        b_ptr=b,
+        c_ptr=c,
+        cumsum_K=cumsum_K,
+        G=G,
+        M=M,
+        N=N,
+    )
+
+
+# ---------------------------------------------------------------------------
+# moe_scatter: Dispatch tokens to sorted expert buffer positions
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _moe_scatter_kernel(
+    x_ptr,
+    out_ptr,
+    index_ptr,
+    M,
+    N: tl.constexpr,
+    TOPK: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Scatter: for each token i, copy x[i] to out[index[i, k]] for k in 0..topk-1."""
+    pid_m = tl.program_id(0).to(tl.int64)
+    pid_n = tl.program_id(1)
+
+    if pid_m >= M:
+        return
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    n_mask = offs_n < N
+
+    # Load input row
+    x_ptrs = x_ptr + pid_m * N + offs_n
+    x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0)
+
+    # Store to each topk destination
+    for k in tl.static_range(TOPK):
+        dst_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
+        out_ptrs = out_ptr + dst_idx * N + offs_n
+        tl.store(out_ptrs, x_vals, mask=n_mask)
+
+
+def moe_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+    """Scatter tokens to sorted expert buffer.
+
+    For each token i and topk slot k, copies x[i] to output[index[i, k]].
+
+    Args:
+        x: (M, N) input hidden states
+        index: (M, topk) scatter indices
+
+    Returns:
+        out: (M * topk, N) scattered output
+    """
+    # The kernel computes row offsets as row * N, which assumes contiguous storage.
+    x = x.contiguous()
+    index = index.contiguous()
+    M, N = x.shape
+    topk = index.shape[1]
+    out = torch.empty(M * topk, N, dtype=x.dtype, device=x.device)
+
+    BLOCK_N = min(triton.next_power_of_2(N), 1024)
+    grid = (M, triton.cdiv(N, BLOCK_N))
+
+    _moe_scatter_kernel[grid](
+        x_ptr=x,
+        out_ptr=out,
+        index_ptr=index,
+        M=M,
+        N=N,
+        TOPK=topk,
+        BLOCK_N=BLOCK_N,
+    )
+    return out
+
+
+# ---------------------------------------------------------------------------
+# moe_gather: Reduce expert outputs back to token positions (sum over topk)
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _moe_gather_kernel(
+    x_ptr,
+    out_ptr,
+    index_ptr,
+    M,
+    N: tl.constexpr,
+    TOPK: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Gather: for each token i, out[i] = sum_k(x[index[i, k]]) over topk."""
+    pid_m = tl.program_id(0).to(tl.int64)
+    pid_n = tl.program_id(1)
+
+    if pid_m >= M:
+        return
+
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    n_mask = offs_n < N
+
+    acc = tl.zeros([BLOCK_N], dtype=tl.float32)
+
+    for k in tl.static_range(TOPK):
+        src_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
+        x_ptrs = x_ptr + src_idx * N + offs_n
+        x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0).to(tl.float32)
+        acc += x_vals
+
+    out_ptrs = out_ptr + pid_m * N + offs_n
+    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=n_mask)
+
+
+def moe_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+    """Gather and reduce expert outputs back to original token positions.
+
+    For each token i, sums x[index[i, k]] over all topk slots.
+
+    Args:
+        x: (M * topk, N) expert outputs in sorted buffer
+        index: (M, topk) scatter indices (same as used in moe_scatter)
+
+    Returns:
+        out: (M, N) gathered output
+    """
+    # The kernel computes row offsets as row * N, which assumes contiguous storage.
+    x = x.contiguous()
+    index = index.contiguous()
+    M, topk = index.shape
+    N = x.shape[1]
+    out = torch.empty(M, N, dtype=x.dtype, device=x.device)
+
+    BLOCK_N = min(triton.next_power_of_2(N), 1024)
+    grid = (M, triton.cdiv(N, BLOCK_N))
+
+    _moe_gather_kernel[grid](
+        x_ptr=x,
+        out_ptr=out,
+        index_ptr=index,
+        M=M,
+        N=N,
+        TOPK=topk,
+        BLOCK_N=BLOCK_N,
+    )
+    return out