[model] Patch GDN for NPU (#10504)

Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
This commit is contained in:
A1waysBeenHere
2026-06-04 16:39:02 +08:00
committed by GitHub
parent 053d43c0ac
commit 409e8a477f
14 changed files with 3219 additions and 3 deletions

View File

@@ -36,6 +36,7 @@ COPY . /app
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
RUN pip uninstall -y torch torchvision torchaudio
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/triton_ascend.txt
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

View File

@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,51 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \
# src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count
### model
model_name_or_path: Qwen/Qwen3.5-35B-A3B
trust_remote_code: true
use_v1_kernels: false
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
packing: false
### output
output_dir: saves/Qwen3.5-35B/lora/sft
logging_steps: 1
save_steps: 2000
max_steps: 2000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
disable_gradient_checkpointing: true

View File

@@ -0,0 +1,2 @@
--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple
triton-ascend==3.2.1

View File

@@ -81,6 +81,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
elif model_type == "qwen3_5":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
elif model_type == "qwen3_5_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel

View File

@@ -20,6 +20,7 @@ from peft import PeftModel
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
from ..extras import logging
from ..extras.misc import infer_optim_dtype
@@ -84,7 +85,60 @@ def _check_fla_dependencies() -> None:
) from exc
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None:
"""Patch for Qwen3.5 models on NPU by importing torch_npu to enable torch.cuda compatibility.
On NPU, torch.cuda operations will fail unless torch_npu is imported.
torch_npu provides compatibility layer that maps torch.cuda calls to NPU operations.
Also replaces chunk_gated_delta_rule with NPU-compatible implementation.
"""
import importlib.metadata
if "Ascend910" not in torch.npu.get_device_name(0):
logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.")
return
try:
importlib.metadata.version("triton_ascend")
except importlib.metadata.PackageNotFoundError:
logger.warning_rank0(
"triton_ascend not installed, skipping NPU GDN patch. "
"To enable it on NPU, reinstall Triton with the Ascend build: "
"`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. "
"Note: triton and triton_ascend cannot coexist — triton must be uninstalled first."
)
return
logger.info_rank0("triton_ascend detected for NPU compatibility.")
from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule
if model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
try:
# Qwen3.5-MoE structure: model.model.language_model.layers
for layer in model.model.language_model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0(
"Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5-MoE model."
)
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
elif model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
try:
# Qwen3.5 structure: model.model.layers
for layer in model.model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0("Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5 model.")
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None:
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training.
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
@@ -421,8 +475,12 @@ def patch_model(
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
patch_qwen3_5_forward(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]:
if is_torch_npu_available():
patch_qwen3_5_forward_npu(model)
elif is_torch_cuda_available() and model_args.flash_attn == "fa2":
# this is the patch for packing/neat_packing for GPU GDN. And when setting packing, flash_attn must be fa2.
patch_qwen3_5_forward_gpu(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)

View File

@@ -0,0 +1,594 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets
CUBE_CORE_NUM = get_npu_properties()["num_aicore"]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(False,)),
key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
w,
v_new,
g,
gk,
h,
h0,
ht,
cu_seqlens,
chunk_offsets,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
NT_all = NT
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
# Initialize hidden states
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
if IS_VARLEN:
v = v + (i_h * T_all + bos) * V
k = k + (i_h * T_all + bos) * K
w = w + (i_h * T_all + bos) * K
g = g + i_h * T_all + bos
h = h + (i_h * NT_all + boh) * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_h * T_all + bos) * V
else:
v = v + (i_n * H + i_h) * T * V
k = k + (i_n * H + i_h) * T * K
w = w + (i_n * H + i_h) * T * K
g = g + (i_n * H + i_h) * T
h = h + (i_n * H + i_h) * NT * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_n * H + i_h) * T * V
if USE_INITIAL_STATE:
h0_ptr = h0 + i_nh * K * V
if STORE_FINAL_STATE:
ht_ptr = ht + i_nh * K * V
# Load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# Main recurrence over chunks
for i_t in range(NT):
# Store current hidden state h_t
p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
# Compute v_residual = v - w @ h
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
# Apply output gate g
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_g_last = tl.load(g + last_idx)
p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None]
b_g_last_exp = tl.exp(b_g_last)
b_h1 *= b_g_last_exp
if K > 64:
b_h2 *= b_g_last_exp
if K > 128:
b_h3 *= b_g_last_exp
if K > 192:
b_h4 *= b_g_last_exp
# Apply key gate gk
if USE_GK:
o_k1 = tl.arange(0, 64).to(tl.float32)
gk_base_ptr = gk + (i_n * H + i_h) * T * K
b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0)
b_h1 *= tl.exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0)
b_h2 *= tl.exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0)
b_h3 *= tl.exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0)
b_h4 *= tl.exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
# Update hidden state: h += k @ v
p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h4 += tl.dot(b_k, b_v)
# Store final state
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # default:64
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, u.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V).permute(0, 2, 1, 3, 4).contiguous()
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
BV = 128
v_new = torch.empty_like(u).permute(0, 2, 1, 3).contiguous() if save_new_value else None
k = k.permute(0, 2, 1, 3).contiguous()
w = w.permute(0, 2, 1, 3).contiguous()
u = u.permute(0, 2, 1, 3).contiguous()
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
NT=NT,
)
h = h.permute(0, 2, 1, 3, 4).contiguous()
v_new = v_new.permute(0, 2, 1, 3).contiguous()
return h, v_new, final_state
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
"USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(True, False)),
key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
q,
k,
w,
g,
gk,
dht,
dh0,
do,
dh,
dv,
dv2,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_FINAL_STATE_GRADIENT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
w += (bos * H + i_h) * K
do += (bos * H + i_h) * V
dv += (bos * H + i_h) * V
dv2 += (bos * H + i_h) * V
dh += (boh * H + i_h) * K * V
if USE_GK:
gk += (bos * H + i_h) * K
if USE_INITIAL_STATE:
dh0 += i_nh * K * V
if USE_FINAL_STATE_GRADIENT:
dht += i_nh * K * V
stride_v = H * V
stride_h = H * K * V
stride_k = H * K
if USE_FINAL_STATE_GRADIENT:
p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
if K > 64:
p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
if K > 128:
p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
if K > 192:
p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))
for i_t in range(NT - 1, -1, -1):
p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
if IS_VARLEN:
bos_g = i_h * T_all + bos
else:
bos_g = (i_n * H + i_h) * T_all
bg_last = tl.load(g + bos_g + last_idx)
bg_last_exp = tl.exp(bg_last)
p_g = tl.make_block_ptr(
base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
# Update dv
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0)
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
if K > 64:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
if K > 128:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
if K > 192:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None]
b_dv += tl.load(p_dv, boundary_check=(0, 1))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# Update dh
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
if USE_G:
b_dh1 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh1 *= tl.exp(b_gk_last1[:, None])
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 64:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh2 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh2 *= tl.exp(b_gk_last2[:, None])
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 128:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh3 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh3 *= tl.exp(b_gk_last3[:, None])
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 192:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh4 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh4 *= tl.exp(b_gk_last4[:, None])
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_bwd_dhu(
q: torch.Tensor,
k: torch.Tensor,
w: torch.Tensor,
do: torch.Tensor,
dv: torch.Tensor,
g: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
h0: torch.Tensor | None = None,
dht: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *q.shape, do.shape[-1]
# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
dh = q.new_empty(B, NT, H, K, V)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
dv2 = torch.empty_like(dv)
BV = 128
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)](
q=q,
k=k,
w=w,
g=g,
gk=gk,
dht=dht,
dh0=dh0,
do=do,
dh=dh,
dv=dv,
dv2=dv2,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
)
return dh, dh0, dv2

View File

@@ -0,0 +1,347 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional
import torch
from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
return g, o, A, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=False,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dv = chunk_bwd_dv_local(
q=q,
k=k,
g=g,
do=do,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
q=q,
k=k,
w=w,
g=g,
h0=initial_state,
dht=dht,
do=do,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dq, dk, dw, dg = chunk_bwd_dqkwg(
q=q,
k=k,
v=v_new,
w=w,
g=g,
h=h,
dv=dv,
do=do,
dh=dh,
chunk_size=chunk_size,
scale=scale,
cu_seqlens=cu_seqlens,
)
dk2, dv, db, dg2 = prepare_wy_repr_bwd(
k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size
)
dk.add_(dk2)
dg.add_(dg2)
if dg.dtype != torch.float32:
raise ValueError(f"dg current type is {dg.dtype} , should be float32")
dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False)
return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
chunk_size: int = 64,
):
q_rstd, k_rstd = None, None
g, o, A, final_state = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
ctx.chunk_size = chunk_size
return o.to(q.dtype), final_state
@staticmethod
@input_guard
@autocast_custom_bwd
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
A=A,
scale=ctx.scale,
initial_state=initial_state,
do=do,
dht=dht,
cu_seqlens=cu_seqlens,
chunk_size=ctx.chunk_size,
)
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
head_first: bool = False,
):
r"""Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
betas of shape `[B, T, H]`.
scale (Optional[float]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
use_qk_l2norm_in_kernel (bool):
Whether to apply L2norm to the q/k tensor internally. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
This argument has been deprecated.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
""" # noqa: D205
if q.dtype != k.dtype or k.dtype != v.dtype:
raise ValueError(
f"q current type is {q.dtype} , k current type is {k.dtype} ,v current type is {v.dtype} , they should are equal"
)
if q.dtype == torch.float32:
raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.")
if len(beta.shape) != 3:
raise ValueError(
f"beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if head_first:
warnings.warn(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
original_dtype = x.dtype
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
# Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype
return (x * inv_norm).to(original_dtype)
if use_qk_l2norm_in_kernel:
q = l2norm(q, dim=-1, eps=1e-6)
k = l2norm(k, dim=-1, eps=1e-6)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size
)
return o, final_state

View File

@@ -0,0 +1,617 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"USE_DW": lambda args: args["dw"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dqkwg(
q,
k,
v,
h,
g,
g_gamma,
do,
dh,
dq,
dk,
dg,
w,
dv,
dw,
cu_seqlens,
chunk_indices,
scale,
B: tl.constexpr,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
USE_DW: tl.constexpr,
IS_VARLEN: tl.constexpr,
gdiff,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
total = B * T_max
T = eos - bos
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
total = B * T_max
NK = tl.cdiv(K, BK)
for i_k in range(NK):
if USE_G:
dg_k = dg + i_k * total * H
for i_h in range(H):
v_h = v + (bos * H + i_h) * V
do_h = do + (bos * H + i_h) * V
h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V
dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V
q_h = q + (bos * H + i_h) * K
k_h = k + (bos * H + i_h) * K
dq_h = dq + (bos * H + i_h) * K
dk_h = dk + (bos * H + i_h) * K
if USE_DW:
w_h = w + (bos * H + i_h) * K # noqa: F841
dw_h = dw + (bos * H + i_h) * K
dv_h = dv + (bos * H + i_h) * V
if USE_G:
if IS_VARLEN:
dg_h = dg_k + i_h * T_max + bos
g_h = g + i_h * T_max + bos
else:
dg_h = dg_k + (i_b * H + i_h) * T_max
g_h = g + (i_b * H + i_h) * T_max
b_dg_last = tl.zeros(
[
1,
],
dtype=tl.float32,
)
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
b_g_last = b_gamma * min(BT, T - i_t * BT)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
if USE_G:
b_dg_last += tl.sum(b_h * b_dh)
b_ds += tl.dot(b_do, tl.trans(b_v))
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
if USE_DW:
p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
if USE_DW:
p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_dg = tl.zeros(
[
BT,
],
dtype=tl.float32,
)
p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1)
b_dg_last *= tl.exp(b_g_last)
b_dq = b_dq * tl.exp(b_g)[:, None] * scale
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None]
b_dg -= tl.sum(b_k * b_dk, axis=1)
b_dg_last += tl.sum(b_dk * b_k)
if IS_VARLEN:
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
else:
p_gdiff = tl.make_block_ptr(
gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT,
(BT, BT),
(BT, 1),
(0, 0),
(BT, BT),
(1, 0),
)
gdiff_ = tl.load(p_gdiff)
b_ds = b_ds * gdiff_ * scale
b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
b_dg += tl.sum(b_ds2, axis=1)
b_dg -= tl.sum(b_ds2, axis=0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
last_index_local = min(BT, T - i_t * BT) - 1
if last_index_local >= 0:
is_last_mask = tl.arange(0, BT) == last_index_local
b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg)
else:
pass
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
elif USE_G_GAMMA:
b_dq = b_dq * exp(b_g)[:, None] * scale
b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
else:
b_ds = tl.where(m_A, b_ds, 0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dv_local(
q,
k,
g,
g_gamma,
do,
dv,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
for i_h in range(H):
offset_kh = (bos * H + i_h) * K
offset_vh = (bos * H + i_h) * V
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, b_q)
if USE_G:
if IS_VARLEN:
offset_g = i_h * T_max + bos
else:
offset_g = i_b * H * T_max + i_h * T_max
p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
else:
b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
N: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_max = T
for i_v in range(tl.cdiv(V, BV)):
for i_n in range(N):
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
base_chunks_per_pid = NT // total_cores
remainder = NT % total_cores
if core_id < remainder:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * base_chunks_per_pid + remainder
# offset calculation
for i_h in range(0, H):
q_offset = (bos * Hg + i_h // (H // Hg)) * K
k_offset = (bos * Hg + i_h // (H // Hg)) * K
v_offset = (bos * H + i_h) * V
o_offset = (bos * H + i_h) * V
for i_t in range(start_idx, start_idx + chunks_this_pid):
i_tg = boh + i_t
h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, b_h)
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k)
if USE_G:
if IS_VARLEN:
p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
else:
p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_bwd_dqkwg(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
do: torch.Tensor,
h: torch.Tensor,
dh: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
w: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128 if cu_seqlens is None else 64
BV = 64
NK = triton.cdiv(K, BK)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
g = g.transpose(1, 2).contiguous()
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
dw = torch.empty_like(w) if w is not None else None
grid = (NT, B)
if cu_seqlens is None:
if NT * BT == T:
g_ = g.reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
else:
diff = NT * BT - T
g_ = torch.cat((g, torch.zeros(B, H, diff).to(g.device)), dim=-1).reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
bias = torch.arange(0, BT).to(g.device)
o_t = (NT - 1) * BT + bias
m_t = o_t < T
m_A = m_t[:, None] & m_t
g_diff[:, :, -1] *= m_A
else:
g_diff = None
chunk_bwd_kernel_dqkwg[grid](
q=q,
k=k,
v=v,
h=h,
g=g,
g_gamma=g_gamma,
do=do,
dh=dh,
dv=dv,
w=w,
dw=dw,
dq=dq,
dk=dk,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
B=B,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
gdiff=g_diff,
)
if dg is not None:
dg = dg.sum(0)
dg = dg.transpose(1, 2).contiguous()
return dq, dk, dw, dg
def chunk_bwd_dv_local(
q: torch.Tensor,
k: torch.Tensor,
do: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: float = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, H, K, V = *k.shape, do.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
BK = 128
BV = 128
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous()
dv = torch.empty_like(do)
grid = (NT, B)
chunk_bwd_kernel_dv_local[grid](
q=q,
k=k,
g=g,
g_gamma=g_gamma,
do=do,
dv=dv,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)
return dv
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # noqa: F841
if scale is None:
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
if cu_seqlens is None:
N, chunk_offsets = B, None
else:
N, chunk_offsets = (
len(cu_seqlens) - 1,
prepare_chunk_offsets(cu_seqlens, BT),
)
def grid(meta):
return (triton.cdiv(V, meta["BV"]), N * H)
g = g.transpose(1, 2).contiguous()
h = h.contiguous()
CV_kernel_num = 24
chunk_fwd_kernel_o[(CV_kernel_num,)](
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T=T,
H=H,
N=N,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=128,
BV=128,
)
return o
bwd_chunk_dqkwg = chunk_bwd_dqkwg
bwd_chunk_dv_local = chunk_bwd_dv_local

View File

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
USE_G: tl.constexpr,
NT,
B,
TOTAL_TASKS,
):
core_id = tl.program_id(0)
num_blocks = tl.num_programs(0)
T_max = T
base_tasks_per_block = TOTAL_TASKS // num_blocks
remainder_tasks = TOTAL_TASKS % num_blocks
if core_id < remainder_tasks:
tasks_this_core = base_tasks_per_block + 1
start_idx = core_id * tasks_this_core
else:
tasks_this_core = base_tasks_per_block
start_idx = core_id * base_tasks_per_block + remainder_tasks
for idx in range(start_idx, start_idx + tasks_this_core):
i_b = idx // NT
local_idx = idx % NT
if IS_VARLEN:
i_n = tl.load(chunk_indices + local_idx * 2).to(tl.int32)
i_t = tl.load(chunk_indices + local_idx * 2 + 1).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_local = eos - bos
else:
bos, eos = 0, T
i_t = local_idx
T_local = T
for i_h in range(H):
k_batch_off = i_b * T_max * H * K
beta_batch_off = i_b * H * T_max
g_batch_off = i_b * H * T_max
A_batch_off = i_b * T_max * H * BT
p_beta = tl.make_block_ptr(
beta + beta_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + k_batch_off + (bos * H + i_h) * K,
(T_local, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
dot_product = tl.dot(b_k, tl.trans(b_k))
o_t = i_t * BT + tl.arange(0, BT)
o_t = o_t.to(tl.float32)
T_mask = (o_t < T_local).to(tl.float32)
row_indices = tl.arange(0, BT)[:, None]
col_indices = tl.arange(0, BT)[None, :]
tril_mask = (row_indices > col_indices).to(tl.float32)
tril_mask = tril_mask * T_mask[:, None]
masked_dot = dot_product * tril_mask
b_A += masked_dot
if USE_G:
p_g = tl.make_block_ptr(
g + g_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0)
b_A *= tl.exp(b_g_diff)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(
A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"])
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_i, i_j = i_c // NC, i_c % NC
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j)
if should_compute:
k_ptr = k + (bos * H + i_h) * K
g_ptr = g + (bos * H + i_h) * K
A_ptr = A + (bos * H + i_h) * BT
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
b_kt = tl.make_block_ptr(
k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
p_gk = tl.make_block_ptr(
g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
b_g = tl.load(p_g, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :])
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kt = tl.load(b_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk)
b_A += tl.dot(b_k, b_kt)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = i_t * BT + i_i * BC < T_val
if should_compute:
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + o_i) < T_val
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None]
b_g = tl.load(p_g, boundary_check=(0, 1))
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)):
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1)
# 转化成f32
o_i_tmp = o_i.to(tl.float32)
b_A = tl.where(o_i_tmp > j, b_A, 0.0)
tl.store(A + o_A + j, b_A, mask=m_A)
p_kt += H * K
p_gk += H * K
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
beta: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
BK = 128
kernel_num = 24
if gk is None:
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)](
k=k,
g=g,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BK=BK,
NT=NT,
B=B,
TOTAL_TASKS=B * NT,
)
return A
BC = min(16, BT)
NC = triton.cdiv(BT, BC)
BK = max(triton.next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A

View File

@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
scale,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BLOCK_T: tl.constexpr,
REVERSE: tl.constexpr,
HAS_SCALE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr,
CHUNK_SIZE: tl.constexpr = 64,
):
i_block, i_b = tl.program_id(0), tl.program_id(1)
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
if IS_VARLEN:
i_s, i_block = (
tl.load(chunk_indices + i_block * 2).to(tl.int32),
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
b_s = tl.trans(b_s, (1, 0, 2))
b_o = tl.cumsum(b_s, axis=0)
if REVERSE:
b_z = tl.sum(b_s, axis=0)
b_o = -b_o + b_z[None] + b_s
if HAS_SCALE:
b_o *= scale
b_o = tl.trans(b_o, (1, 0, 2))
b_o = tl.reshape(b_o, (BLOCK_T, H))
tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,))
return
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
B, T, H = g.shape
if chunk_size != 2 ** (chunk_size.bit_length() - 1):
raise ValueError(f"chunk_size must be a power of 2, chunk_size is{chunk_size}")
# We adjust the tiling strategy to prevent overflow in in backward passes and context parallel scenarios
# while maximizing UB utilization where possible.
# The tiling strategy is as follows:
# 1. BT must be greater than or equal to chunk_size.
# 2. UB estimation varies directly with H.
# 3. BT in reverse mode is smaller than in forward mode.
BT = max(chunk_size, triton.next_power_of_2((1 << 11 if reverse else 1 << 12) // H))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B)
chunk_local_cumsum_scalar_kernel[grid](
s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
BLOCK_T=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
CHUNK_SIZE=chunk_size,
)
return g
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is not None:
if g.shape[0] != 1:
raise ValueError(
f"Only batch size 1 is supported when cu_seqlens are provided, current size is{g.shape[0]}"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype,
)
else:
raise ValueError(
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise"
)

View File

@@ -0,0 +1,272 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = pid_bh // H, pid_bh % H
base_t = pid_t * TPP
if IS_VARLEN:
i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_eff = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_eff = T
o_i = tl.arange(0, 16) # noqa: F841
o_i_fp32 = tl.arange(0, 16).to(tl.float32)
m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
m_I = o_i_fp32[:, None] == o_i_fp32[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * BT
for tpp in tl.static_range(0, TPP):
tile_t = base_t + tpp
tile_row = tile_t * 16
offset = (tile_t * 16) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0))
b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
b_A_raw = desc.load([tile_row, offset]).to(tl.float32)
b_A_neg = -b_A_raw
b_A = b_A_neg * m_A
for i in range(2, min(16, T_eff - tile_row)):
slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
b_a_update = b_a_val + dot_prod
b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
b_Ai_11 += m_I
b_Ai_22 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_64x64_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 64)
m_I = o_i[:, None] == o_i[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * 64
offset = (i_t * 64) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0))
b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64])
desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64])
b_A = -desc.load([i_t * 64, offset]).to(tl.float32)
for i in range(2, min(64, T - i_t * 64)):
b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
b_A = tl.where((o_i == i)[:, None], b_a, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@input_guard
def solve_tril(
A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
"""Compute the inverse of the matrix I + A
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Returns:
(I + A)^-1 with the same shape as A
""" # noqa: D205
if A.shape[-1] not in [16, 32, 64]:
raise ValueError(f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}")
output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype)
if BT == 16:
merge_fn = solve_tril_16x16_kernel
elif BT == 32:
merge_fn = merge_16x16_to_32x32_inverse_kernel
elif BT == 64:
merge_fn = solve_tril_64x64_kernel
merge_fn[NT, B * H](
A=A,
Ai=Ai,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
USE_TMA=False,
DOT_PRECISION=FLA_TRIL_PRECISION,
)
return Ai

View File

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import itertools
import logging
import os
import warnings
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional
import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
import triton.runtime.driver as driver
from packaging import version
logger = logging.getLogger(__name__)
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any:
"""A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the outputs of the decorated function for the most recent
set of input tensors, up to `maxsize` entries. If the function is called again with
the same input tensors, it will return the cached result.
When maxsize=1 (default), the behavior is identical to caching only the most recent result.
Can be used as @tensor_cache or @tensor_cache(maxsize=n).
Args:
fn (Callable[..., torch.Tensor], optional):
The function to be decorated when used without parentheses.
maxsize (int):
Maximum number of input combinations to cache. Default is 1.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with caching.
"""
if maxsize < 1:
raise ValueError("maxsize must be at least 1")
def _is_match(a: Any, b: Any) -> bool:
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return a is b
try:
return a == b
except Exception:
return a is b
def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
cache: list = []
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache):
if len(args) == len(cached_args) and len(kwargs) == len(cached_kwargs):
if all(_is_match(a, b) for a, b in zip(args, cached_args)) and all(
k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items()
):
if i != 0:
cache.insert(0, cache.pop(i))
return cached_result
result = fn(*args, **kwargs)
cache.insert(0, (args, kwargs, result))
if len(cache) > maxsize:
cache.pop()
return result
return wrapper
if fn is not None:
return _make_wrapper(fn)
return _make_wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache(maxsize=3)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
def get_abs_err(x, y):
return (x.detach() - y.detach()).flatten().abs().max().item()
def get_err_ratio(x, y):
err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
base = (x.detach()).flatten().square().mean().sqrt().item()
return err / (base + 1e-8)
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
abs_atol = get_abs_err(ref, tri)
msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
logger.info(msg)
error_rate = get_err_ratio(ref, tri)
if abs_atol <= err_atol:
return
if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
if error_rate > ratio:
warnings.warn(msg)
else:
assert error_rate < ratio, msg
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
# For Triton 3.3.x
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
# For Triton 3.4.x and later
make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
Just make triton compiler happy.
"""
@triton.jit
def make_tensor_descriptor(
base,
shape,
strides,
block_shape,
_builder=None,
):
return None
@functools.cache
def get_available_device() -> str:
try:
return triton.runtime.driver.active.get_current_target().backend
except BaseException:
_cpu_device_warning()
return "cpu"
def map_triton_backend_to_torch_device() -> str:
backend = get_available_device() # 'cuda' | 'hip' | 'xpu' | 'cpu' | ...
return {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(backend, backend)
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = get_available_device()
is_amd = device_platform == "hip"
is_nvidia = device_platform == "cuda"
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
)
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_tma_supported = (
(is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9)
and os.environ.get("FLA_NO_USE_TMA", "0") != "1"
and (
hasattr(triton.language, "_experimental_make_tensor_descriptor")
or hasattr(triton.language, "make_tensor_descriptor")
)
)
if is_nvidia and not is_tf32_supported:
# Make old card happy, since triton will use tf32 by default.
# This is a workaround for old nvidia card.
os.environ["TRITON_F32_DEFAULT"] = "ieee"
@functools.cache
def check_pytorch_version(version_s: str = "2.4") -> bool:
return version.parse(torch.__version__) >= version.parse(version_s)
if check_pytorch_version("2.4"):
device = "cuda" if device == "cpu" else device
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
def custom_device_ctx(index: int):
return device_torch_lib.device(index)
else:
assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0."
autocast_custom_fwd = device_torch_lib.amp.custom_fwd
autocast_custom_bwd = device_torch_lib.amp.custom_bwd
def custom_device_ctx(index: int):
return torch.cuda.device(index)
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""A decorator to make sure all input tensors are contiguous and set the device based on input tensors."""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = custom_device_ctx(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
def _cpu_device_warning():
warnings.warn(("Triton is not supported on current platform, roll back to CPU."), stacklevel=1)
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
exp = tldevice.fast_expf
exp2 = tldevice.exp2
log = tldevice.fast_logf
log2 = tldevice.fast_log2f
else:
exp = tl.exp
exp2 = tl.math.exp2
log = tl.log
log2 = tl.log2
def get_all_max_shared_mem():
try:
return [
triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
for i in range(device_torch_lib.device_count())
]
except BaseException:
_cpu_device_warning()
return [-1]
class Backend(Enum):
ADA = 101376 # RTX 4090
AMPERE = 166912 # A100
HOPPER = 232448 # H100
DEFAULT = 102400 # Default
@classmethod
def get_shared_memory(cls, arch: str) -> int:
try:
return cls[arch.upper()].value
except KeyError:
return cls.DEFAULT.value
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
try:
device_shared_mem_list = get_all_max_shared_mem()
max_shared_memory = device_shared_mem_list[tensor_idx]
return max_shared_memory >= Backend.get_shared_memory(arch)
except Exception:
return False
def get_autotune_config(
multibuffer_list: tuple = (False,),
unit_flag_list: tuple = (False,),
limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,),
limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",),
set_workspace_multibuffer_list: tuple = (2, 4),
enable_hivm_auto_cv_balance_list: tuple = (True,),
tile_mix_vector_loop_num_list: tuple = (2, 4),
tile_mix_cube_loop_num_list: tuple = (2, 4),
):
configs = []
for (
multibuffer,
unit_flag,
limit_auto_multi_buffer_only_for_local_buffer,
limit_auto_multi_buffer_of_local_buffer,
) in itertools.product(
list(multibuffer_list),
list(unit_flag_list),
list(limit_auto_multi_buffer_only_for_local_buffer_list),
list(limit_auto_multi_buffer_of_local_buffer_list),
):
base_config_dict = {
"multibuffer": multibuffer,
"unit_flag": unit_flag,
"limit_auto_multi_buffer_only_for_local_buffer": limit_auto_multi_buffer_only_for_local_buffer,
"limit_auto_multi_buffer_of_local_buffer": limit_auto_multi_buffer_of_local_buffer,
}
if limit_auto_multi_buffer_only_for_local_buffer:
configs.append(triton.Config(base_config_dict))
else:
for (
set_workspace_multibuffer,
enable_hivm_auto_cv_balance,
tile_mix_vector_loop,
tile_mix_cube_loop,
) in itertools.product(
list(set_workspace_multibuffer_list),
list(enable_hivm_auto_cv_balance_list),
list(tile_mix_vector_loop_num_list),
list(tile_mix_cube_loop_num_list),
):
full_config_dict = base_config_dict.copy()
full_config_dict.update(
{
"set_workspace_multibuffer": set_workspace_multibuffer,
"enable_hivm_auto_cv_balance": enable_hivm_auto_cv_balance,
"tile_mix_vector_loop": tile_mix_vector_loop,
"tile_mix_cube_loop": tile_mix_cube_loop,
}
)
configs.append(triton.Config(full_config_dict))
return configs
def get_npu_properties():
return driver.active.utils.get_device_properties(torch.npu.current_device())

View File

@@ -0,0 +1,387 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
k,
v,
beta,
g,
A,
dw,
du,
dk,
dv,
dbeta,
dg,
cu_seqlens,
chunk_indices,
T,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
i_t = idx
bos, eos = i_b * T, i_b * T + T
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
for i_h in range(0, H):
if IS_VARLEN:
offset = bos + i_h * T_max
else:
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dg = tl.zeros([BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dw = tl.make_block_ptr(
dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
b_dk_beta_g = tl.dot(b_A, b_dw)
b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None]
b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1)
b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_dv = tl.make_block_ptr(
dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_du = tl.make_block_ptr(
du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta))
b_dv_beta = tl.dot(b_A, b_du)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
b_dA = b_dA.to(k.dtype.element_ty)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_k_beta, tl.trans(b_k))
b_dk_beta = tl.dot(b_dA, b_k)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_dA_A = b_dA * b_A
b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0)
p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
gk,
cu_seqlens,
chunk_indices,
T_tmp,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T_tmp
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
for i_h in range(0, H):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
offset = bos + i_h * T_max
T = eos - bos
else:
T = T_tmp
i_t = idx
bos, eos = i_b * T, i_b * T + T
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
if USE_G:
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
if USE_G:
b_kb *= b_g[:, None]
if USE_GK:
p_gk = tl.make_block_ptr(
gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1)))
b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 128
BV = 128
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous() if g is not None else None
beta = beta.transpose(1, 2).contiguous()
w = torch.empty_like(k)
u = torch.empty_like(v)
cv_kernel_num = 24
recompute_w_u_fwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T_tmp=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
return w, u
def prepare_wy_repr_bwd(
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
dw: torch.Tensor,
du: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
chunk_size: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128
BV = 128
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.empty_like(beta)
dg = torch.empty_like(g)
cv_kernel_num = 24
prepare_wy_repr_bwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
g=g,
A=A,
dw=dw,
du=du,
dk=dk,
dv=dv,
dbeta=dbeta,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
dbeta = dbeta.transpose(1, 2).contiguous()
dg = dg.transpose(1, 2).contiguous()
return dk, dv, dbeta, dg
bwd_prepare_wy_repr = prepare_wy_repr_bwd
fwd_recompute_w_u = recompute_w_u_fwd