diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile index 61da0a60b..dfbb2d674 100644 --- a/docker/docker-npu/Dockerfile +++ b/docker/docker-npu/Dockerfile @@ -36,6 +36,7 @@ COPY . /app RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh RUN pip uninstall -y torch torchvision torchaudio RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}" +RUN pip install --no-cache-dir -r requirements/triton_ascend.txt RUN pip install --no-cache-dir -r requirements/deepspeed.txt RUN pip install --no-cache-dir -e . --no-build-isolation && \ pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation diff --git a/examples/accelerate/fsdp2_config_qwen35_moe.yaml b/examples/accelerate/fsdp2_config_qwen35_moe.yaml new file mode 100644 index 000000000..ae335f11a --- /dev/null +++ b/examples/accelerate/fsdp2_config_qwen35_moe.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_version: 2 + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3) +rdzv_backend: static +same_network: true +use_cpu: false diff --git a/examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml b/examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml new file mode 100644 index 000000000..738ee06b2 --- /dev/null +++ b/examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml @@ -0,0 +1,51 @@ +# Start FSDP2 full fine-tuning on Ascend NPU +# Usage: +# accelerate launch \ +# --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \ +# src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml +# +# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count + +### model +model_name_or_path: Qwen/Qwen3.5-35B-A3B +trust_remote_code: true +use_v1_kernels: false +flash_attn: fa2 + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_rank: 8 +lora_target: all + +### dataset +dataset: alpaca_en_demo +template: qwen3_5_nothink +cutoff_len: 2048 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 +dataloader_num_workers: 4 +packing: false + +### output +output_dir: saves/Qwen3.5-35B/lora/sft +logging_steps: 1 +save_steps: 2000 +max_steps: 2000 +plot_loss: true +overwrite_output_dir: true +save_only_model: false +report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow] + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 1 +learning_rate: 1.0e-5 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 1800 +resume_from_checkpoint: null +disable_gradient_checkpointing: true diff --git a/requirements/triton_ascend.txt b/requirements/triton_ascend.txt new file mode 100644 index 000000000..40f052aa9 --- /dev/null +++ b/requirements/triton_ascend.txt @@ -0,0 +1,2 @@ +--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple +triton-ascend==3.2.1 diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index b1e7dc762..56b64b03a 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -81,6 +81,8 @@ def apply_liger_kernel( from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel elif model_type == "qwen3_5": from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel + elif model_type == "qwen3_5_moe": + from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel elif model_type == "gpt_oss": try: from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 87e0f5791..8bec9526a 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -20,6 +20,7 @@ from peft import PeftModel from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled +from transformers.utils import is_torch_cuda_available, is_torch_npu_available from ..extras import logging from ..extras.misc import infer_optim_dtype @@ -84,7 +85,60 @@ def _check_fla_dependencies() -> None: ) from exc -def patch_qwen3_5_forward(model: "PreTrainedModel") -> None: +def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None: + """Patch for Qwen3.5 models on NPU by importing torch_npu to enable torch.cuda compatibility. + + On NPU, torch.cuda operations will fail unless torch_npu is imported. + torch_npu provides compatibility layer that maps torch.cuda calls to NPU operations. + + Also replaces chunk_gated_delta_rule with NPU-compatible implementation. + """ + import importlib.metadata + + if "Ascend910" not in torch.npu.get_device_name(0): + logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.") + return + + try: + importlib.metadata.version("triton_ascend") + except importlib.metadata.PackageNotFoundError: + logger.warning_rank0( + "triton_ascend not installed, skipping NPU GDN patch. " + "To enable it on NPU, reinstall Triton with the Ascend build: " + "`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. " + "Note: triton and triton_ascend cannot coexist — triton must be uninstalled first." + ) + return + + logger.info_rank0("triton_ascend detected for NPU compatibility.") + + from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule + + if model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration": + try: + # Qwen3.5-MoE structure: model.model.language_model.layers + for layer in model.model.language_model.layers: + if hasattr(layer, "linear_attn"): + layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule + + logger.info_rank0( + "Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5-MoE model." + ) + except Exception as e: + logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}") + elif model.config.architectures[0] == "Qwen3_5ForConditionalGeneration": + try: + # Qwen3.5 structure: model.model.layers + for layer in model.model.layers: + if hasattr(layer, "linear_attn"): + layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule + + logger.info_rank0("Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5 model.") + except Exception as e: + logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}") + + +def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None: """Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training. Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py. @@ -421,8 +475,12 @@ def patch_model( autocast_projector_dtype(model, model_args) add_z3_leaf_module(model) - if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2": - patch_qwen3_5_forward(model) + if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]: + if is_torch_npu_available(): + patch_qwen3_5_forward_npu(model) + elif is_torch_cuda_available() and model_args.flash_attn == "fa2": + # this is the patch for packing/neat_packing for GPU GDN. And when setting packing, flash_attn must be fa2. + patch_qwen3_5_forward_gpu(model) if not model_args.use_unsloth: print_attn_implementation(model.config) diff --git a/src/llamafactory/third_party/triton/chunk_delta_h.py b/src/llamafactory/third_party/triton/chunk_delta_h.py new file mode 100644 index 000000000..ddbf1c6a6 --- /dev/null +++ b/src/llamafactory/third_party/triton/chunk_delta_h.py @@ -0,0 +1,594 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets + + +CUBE_CORE_NUM = get_npu_properties()["num_aicore"] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=get_autotune_config(multibuffer_list=(False,)), + key=["H", "K", "V", "BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + T_all = T + NT_all = NT + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # Initialize hidden states + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + if IS_VARLEN: + v = v + (i_h * T_all + bos) * V + k = k + (i_h * T_all + bos) * K + w = w + (i_h * T_all + bos) * K + g = g + i_h * T_all + bos + h = h + (i_h * NT_all + boh) * K * V + if SAVE_NEW_VALUE: + v_new_base = v_new + (i_h * T_all + bos) * V + else: + v = v + (i_n * H + i_h) * T * V + k = k + (i_n * H + i_h) * T * K + w = w + (i_n * H + i_h) * T * K + g = g + (i_n * H + i_h) * T + h = h + (i_n * H + i_h) * NT * K * V + if SAVE_NEW_VALUE: + v_new_base = v_new + (i_n * H + i_h) * T * V + + if USE_INITIAL_STATE: + h0_ptr = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht_ptr = ht + i_nh * K * V + + # Load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # Main recurrence over chunks + for i_t in range(NT): + # Store current hidden state h_t + p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + # Compute v_residual = v - w @ h + p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + + p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + + # Apply output gate g + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T + b_g_last = tl.load(g + last_idx) + p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None] + b_g_last_exp = tl.exp(b_g_last) + b_h1 *= b_g_last_exp + if K > 64: + b_h2 *= b_g_last_exp + if K > 128: + b_h3 *= b_g_last_exp + if K > 192: + b_h4 *= b_g_last_exp + + # Apply key gate gk + if USE_GK: + o_k1 = tl.arange(0, 64).to(tl.float32) + gk_base_ptr = gk + (i_n * H + i_h) * T * K + b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0) + b_h1 *= tl.exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0) + b_h2 *= tl.exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0) + b_h3 *= tl.exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0) + b_h4 *= tl.exp(b_gk_last4)[:, None] + + b_v = b_v.to(k.dtype.element_ty) + + # Update hidden state: h += k @ v + p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1)) + b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype) + b_h1 += tl.dot(b_k, b_v) + + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1)) + b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype) + b_h2 += tl.dot(b_k, b_v) + + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1)) + b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype) + b_h3 += tl.dot(b_k, b_v) + + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1)) + b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype) + b_h4 += tl.dot(b_k, b_v) + + # Store final state + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # default:64 + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V).permute(0, 2, 1, 3, 4).contiguous() + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + BV = 128 + + v_new = torch.empty_like(u).permute(0, 2, 1, 3).contiguous() if save_new_value else None + k = k.permute(0, 2, 1, 3).contiguous() + w = w.permute(0, 2, 1, 3).contiguous() + u = u.permute(0, 2, 1, 3).contiguous() + g = g.permute(0, 2, 1).contiguous() + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BV=BV, + NT=NT, + ) + h = h.permute(0, 2, 1, 3, 4).contiguous() + v_new = v_new.permute(0, 2, 1, 3).contiguous() + return h, v_new, final_state + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["dh0"] is not None, + "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=get_autotune_config(multibuffer_list=(True, False)), + key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( + q, + k, + w, + g, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + T_all = T + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + b_dh1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_dh3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([64, BV], dtype=tl.float32) + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + w += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + dv2 += (bos * H + i_h) * V + dh += (boh * H + i_h) * K * V + if USE_GK: + gk += (bos * H + i_h) * K + + if USE_INITIAL_STATE: + dh0 += i_nh * K * V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K * V + + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + + if USE_FINAL_STATE_GRADIENT: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if K > 64: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if K > 128: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if K > 192: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + if IS_VARLEN: + bos_g = i_h * T_all + bos + else: + bos_g = (i_n * H + i_h) * T_all + bg_last = tl.load(g + bos_g + last_idx) + bg_last_exp = tl.exp(bg_last) + p_g = tl.make_block_ptr( + base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_exp = tl.exp(b_g) + + p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # Update dv + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0) + b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) + + if K > 64: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) + + if K > 128: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) + + if K > 192: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T + b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None] + b_dv += tl.load(p_dv, boundary_check=(0, 1)) + + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # Update dh + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + if USE_G: + b_dh1 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + b_dh1 *= tl.exp(b_gk_last1[:, None]) + b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 64: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh2 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + b_dh2 *= tl.exp(b_gk_last2[:, None]) + b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 128: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh3 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + b_dh3 *= tl.exp(b_gk_last3[:, None]) + b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 192: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh4 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + b_dh4 *= tl.exp(b_gk_last4[:, None]) + b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_bwd_dhu( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + do: torch.Tensor, + dv: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + h0: torch.Tensor | None = None, + dht: torch.Tensor | None = None, + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + chunk_indices: torch.LongTensor | None = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *q.shape, do.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + BV = 128 + + g = g.permute(0, 2, 1).contiguous() + + chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)]( + q=q, + k=k, + w=w, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BV=BV, + ) + return dh, dh0, dv2 diff --git a/src/llamafactory/third_party/triton/chunk_gated_delta_rule.py b/src/llamafactory/third_party/triton/chunk_gated_delta_rule.py new file mode 100644 index 000000000..b1d5e83c0 --- /dev/null +++ b/src/llamafactory/third_party/triton/chunk_gated_delta_rule.py @@ -0,0 +1,347 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Optional + +import torch + +from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .solve_tril import solve_tril +from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd( + k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32 + ) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + return g, o, A, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + chunk_size=chunk_size, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dk2, dv, db, dg2 = prepare_wy_repr_bwd( + k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size + ) + dk.add_(dk2) + dg.add_(dg2) + if dg.dtype != torch.float32: + raise ValueError(f"dg current type is {dg.dtype} , should be float32") + dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + chunk_size: int = 64, + ): + q_rstd, k_rstd = None, None + g, o, A, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.chunk_size = chunk_size + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do: torch.Tensor, dht: torch.Tensor): + q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_size=ctx.chunk_size, + ) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + head_first: bool = False, +): + r"""Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q/k tensor internally. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + This argument has been deprecated. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ # noqa: D205 + if q.dtype != k.dtype or k.dtype != v.dtype: + raise ValueError( + f"q current type is {q.dtype} , k current type is {k.dtype} ,v current type is {v.dtype} , they should are equal" + ) + if q.dtype == torch.float32: + raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.") + if len(beta.shape) != 3: + raise ValueError( + f"beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + ) + + if head_first: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + + def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + """This function is intended to align with the l2norm implementation in the FLA library.""" + original_dtype = x.dtype + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + # Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype + return (x * inv_norm).to(original_dtype) + + if use_qk_l2norm_in_kernel: + q = l2norm(q, dim=-1, eps=1e-6) + k = l2norm(k, dim=-1, eps=1e-6) + + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size + ) + return o, final_state diff --git a/src/llamafactory/third_party/triton/chunk_o.py b/src/llamafactory/third_party/triton/chunk_o.py new file mode 100644 index 000000000..96cd84138 --- /dev/null +++ b/src/llamafactory/third_party/triton/chunk_o.py @@ -0,0 +1,617 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_G_GAMMA": lambda args: args["g_gamma"] is not None, + "USE_DW": lambda args: args["dw"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_bwd_kernel_dqkwg( + q, + k, + v, + h, + g, + g_gamma, + do, + dh, + dq, + dk, + dg, + w, + dv, + dw, + cu_seqlens, + chunk_indices, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_DW: tl.constexpr, + IS_VARLEN: tl.constexpr, + gdiff, +): + i_t, i_b = tl.program_id(0), tl.program_id(1) + T_max = T + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + total = B * T_max + T = eos - bos + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + total = B * T_max + + NK = tl.cdiv(K, BK) + for i_k in range(NK): + if USE_G: + dg_k = dg + i_k * total * H + + for i_h in range(H): + v_h = v + (bos * H + i_h) * V + do_h = do + (bos * H + i_h) * V + h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V + dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V + q_h = q + (bos * H + i_h) * K + k_h = k + (bos * H + i_h) * K + dq_h = dq + (bos * H + i_h) * K + dk_h = dk + (bos * H + i_h) * K + + if USE_DW: + w_h = w + (bos * H + i_h) * K # noqa: F841 + dw_h = dw + (bos * H + i_h) * K + dv_h = dv + (bos * H + i_h) * V + + if USE_G: + if IS_VARLEN: + dg_h = dg_k + i_h * T_max + bos + g_h = g + i_h * T_max + bos + else: + dg_h = dg_k + (i_b * H + i_h) * T_max + g_h = g + (i_b * H + i_h) * T_max + b_dg_last = tl.zeros( + [ + 1, + ], + dtype=tl.float32, + ) + + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * (tl.arange(0, BT) + 1) + b_g_last = b_gamma * min(BT, T - i_t * BT) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + if USE_G: + b_dg_last += tl.sum(b_h * b_dh) + + b_ds += tl.dot(b_do, tl.trans(b_v)) + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + if USE_DW: + p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + if USE_DW: + p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + + p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + + if USE_G: + b_dg = tl.zeros( + [ + BT, + ], + dtype=tl.float32, + ) + p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1) + b_dg_last *= tl.exp(b_g_last) + + b_dq = b_dq * tl.exp(b_g)[:, None] * scale + b_dg += tl.sum(b_dq * b_q, axis=1) + + b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None] + b_dg -= tl.sum(b_k * b_dk, axis=1) + b_dg_last += tl.sum(b_dk * b_k) + + if IS_VARLEN: + b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale + else: + p_gdiff = tl.make_block_ptr( + gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT, + (BT, BT), + (BT, 1), + (0, 0), + (BT, BT), + (1, 0), + ) + gdiff_ = tl.load(p_gdiff) + b_ds = b_ds * gdiff_ * scale + + b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k)) + b_dg += tl.sum(b_ds2, axis=1) + b_dg -= tl.sum(b_ds2, axis=0) + + b_ds = b_ds.to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) + p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,)) + + last_index_local = min(BT, T - i_t * BT) - 1 + if last_index_local >= 0: + is_last_mask = tl.arange(0, BT) == last_index_local + b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg) + else: + pass + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + elif USE_G_GAMMA: + b_dq = b_dq * exp(b_g)[:, None] * scale + b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None] + b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale + b_ds = b_ds.to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + else: + b_ds = tl.where(m_A, b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) * scale + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_G_GAMMA": lambda args: args["g_gamma"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + g_gamma, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_b = tl.program_id(0), tl.program_id(1) + T_max = T + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + for i_h in range(H): + offset_kh = (bos * H + i_h) * K + offset_vh = (bos * H + i_h) * V + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + + if USE_G: + if IS_VARLEN: + offset_g = i_h * T_max + bos + else: + offset_g = i_b * H * T_max + i_h * T_max + + p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * (tl.arange(0, BT) + 1) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) + + if USE_G: + b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + else: + b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_G_GAMMA": lambda args: args["g_gamma"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + g_gamma, + o, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + N: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + T_max = T + for i_v in range(tl.cdiv(V, BV)): + for i_n in range(N): + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int64) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + core_id = tl.program_id(0) + total_cores = tl.num_programs(0) + base_chunks_per_pid = NT // total_cores + remainder = NT % total_cores + + if core_id < remainder: + chunks_this_pid = base_chunks_per_pid + 1 + start_idx = core_id * chunks_this_pid + else: + chunks_this_pid = base_chunks_per_pid + start_idx = core_id * base_chunks_per_pid + remainder + + # offset calculation + for i_h in range(0, H): + q_offset = (bos * Hg + i_h // (H // Hg)) * K + k_offset = (bos * Hg + i_h // (H // Hg)) * K + v_offset = (bos * H + i_h) * V + o_offset = (bos * H + i_h) * V + + for i_t in range(start_idx, start_idx + chunks_this_pid): + i_tg = boh + i_t + h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr( + q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1) + ) + p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + if IS_VARLEN: + p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * (tl.arange(0, BT) + 1) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_bwd_dqkwg( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + g: Optional[torch.Tensor] = None, + g_gamma: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + w: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BK = 128 if cu_seqlens is None else 64 + BV = 64 + NK = triton.cdiv(K, BK) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + g = g.transpose(1, 2).contiguous() + dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None + dw = torch.empty_like(w) if w is not None else None + grid = (NT, B) + + if cu_seqlens is None: + if NT * BT == T: + g_ = g.reshape(B, H, NT, BT) + g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :] + g_diff = g_diff.clamp(-60, 60).exp() + g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device) + else: + diff = NT * BT - T + g_ = torch.cat((g, torch.zeros(B, H, diff).to(g.device)), dim=-1).reshape(B, H, NT, BT) + g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :] + g_diff = g_diff.clamp(-60, 60).exp() + g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device) + bias = torch.arange(0, BT).to(g.device) + o_t = (NT - 1) * BT + bias + m_t = o_t < T + m_A = m_t[:, None] & m_t + g_diff[:, :, -1] *= m_A + else: + g_diff = None + + chunk_bwd_kernel_dqkwg[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + g_gamma=g_gamma, + do=do, + dh=dh, + dv=dv, + w=w, + dw=dw, + dq=dq, + dk=dk, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + gdiff=g_diff, + ) + + if dg is not None: + dg = dg.sum(0) + dg = dg.transpose(1, 2).contiguous() + return dq, dk, dw, dg + + +def chunk_bwd_dv_local( + q: torch.Tensor, + k: torch.Tensor, + do: torch.Tensor, + g: Optional[torch.Tensor] = None, + g_gamma: Optional[torch.Tensor] = None, + scale: float = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + + BK = 128 + BV = 128 + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + g = g.transpose(1, 2).contiguous() + dv = torch.empty_like(do) + grid = (NT, B) + chunk_bwd_kernel_dv_local[grid]( + q=q, + k=k, + g=g, + g_gamma=g_gamma, + do=do, + dv=dv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dv + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, + g_gamma: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # noqa: F841 + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + if cu_seqlens is None: + N, chunk_offsets = B, None + else: + N, chunk_offsets = ( + len(cu_seqlens) - 1, + prepare_chunk_offsets(cu_seqlens, BT), + ) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + g = g.transpose(1, 2).contiguous() + h = h.contiguous() + CV_kernel_num = 24 + chunk_fwd_kernel_o[(CV_kernel_num,)]( + q, + k, + v, + h, + g, + g_gamma, + o, + cu_seqlens, + chunk_offsets, + scale, + T=T, + H=H, + N=N, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=128, + BV=128, + ) + return o + + +bwd_chunk_dqkwg = chunk_bwd_dqkwg +bwd_chunk_dv_local = chunk_bwd_dv_local diff --git a/src/llamafactory/third_party/triton/chunk_scaled_dot_kkt.py b/src/llamafactory/third_party/triton/chunk_scaled_dot_kkt.py new file mode 100644 index 000000000..0b3dec4bd --- /dev/null +++ b/src/llamafactory/third_party/triton/chunk_scaled_dot_kkt.py @@ -0,0 +1,359 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import prepare_chunk_indices + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, + NT, + B, + TOTAL_TASKS, +): + core_id = tl.program_id(0) + num_blocks = tl.num_programs(0) + T_max = T + + base_tasks_per_block = TOTAL_TASKS // num_blocks + remainder_tasks = TOTAL_TASKS % num_blocks + + if core_id < remainder_tasks: + tasks_this_core = base_tasks_per_block + 1 + start_idx = core_id * tasks_this_core + else: + tasks_this_core = base_tasks_per_block + start_idx = core_id * base_tasks_per_block + remainder_tasks + + for idx in range(start_idx, start_idx + tasks_this_core): + i_b = idx // NT + local_idx = idx % NT + + if IS_VARLEN: + i_n = tl.load(chunk_indices + local_idx * 2).to(tl.int32) + i_t = tl.load(chunk_indices + local_idx * 2 + 1).to(tl.int32) + bos = tl.load(cu_seqlens + i_n).to(tl.int32) + eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T_local = eos - bos + else: + bos, eos = 0, T + i_t = local_idx + T_local = T + + for i_h in range(H): + k_batch_off = i_b * T_max * H * K + beta_batch_off = i_b * H * T_max + g_batch_off = i_b * H * T_max + A_batch_off = i_b * T_max * H * BT + + p_beta = tl.make_block_ptr( + beta + beta_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,) + ) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + k_batch_off + (bos * H + i_h) * K, + (T_local, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + dot_product = tl.dot(b_k, tl.trans(b_k)) + + o_t = i_t * BT + tl.arange(0, BT) + o_t = o_t.to(tl.float32) + T_mask = (o_t < T_local).to(tl.float32) + + row_indices = tl.arange(0, BT)[:, None] + col_indices = tl.arange(0, BT)[None, :] + tril_mask = (row_indices > col_indices).to(tl.float32) + tril_mask = tril_mask * T_mask[:, None] + masked_dot = dot_product * tril_mask + b_A += masked_dot + + if USE_G: + p_g = tl.make_block_ptr( + g + g_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0) + b_A *= tl.exp(b_g_diff) + b_A *= b_beta[:, None] + + p_A = tl.make_block_ptr( + A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"]) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_i, i_j = i_c // NC, i_c % NC + + for i_h in range(H): + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T_val = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + T_val = T + + should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j) + + if should_compute: + k_ptr = k + (bos * H + i_h) * K + g_ptr = g + (bos * H + i_h) * K + A_ptr = A + (bos * H + i_h) * BT + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0) + ) + p_g = tl.make_block_ptr( + g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0) + ) + b_kt = tl.make_block_ptr( + k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1) + ) + p_gk = tl.make_block_ptr( + g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1) + ) + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :]) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kt = tl.load(b_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk) + b_A += tl.dot(b_k, b_kt) + b_A *= b_beta[:, None] + + p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + for i_h in range(H): + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T_val = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + T_val = T + + should_compute = i_t * BT + i_i * BC < T_val + + if should_compute: + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + o_i) < T_val + o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC + + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0) + ) + p_g = tl.make_block_ptr( + g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0) + ) + p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h + + b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None] + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k + + for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)): + b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1) + # 转化成f32 + o_i_tmp = o_i.to(tl.float32) + b_A = tl.where(o_i_tmp > j, b_A, 0.0) + + tl.store(A + o_A + j, b_A, mask=m_A) + p_kt += H * K + p_gk += H * K + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + r"""Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. + gk (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + B, T, H, K = k.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + beta = beta.transpose(1, 2).contiguous() + g = g.transpose(1, 2).contiguous() + BK = 128 + kernel_num = 24 + + if gk is None: + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BK=BK, + NT=NT, + B=B, + TOTAL_TASKS=B * NT, + ) + return A + + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + BK = max(triton.next_power_of_2(K), 16) + A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype) + grid = (NT, NC * NC, B) + chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid]( + k=k, + g=gk, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + ) + + grid = (NT, NC, B) + chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid]( + k=k, + g=gk, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + return A diff --git a/src/llamafactory/third_party/triton/cumsum.py b/src/llamafactory/third_party/triton/cumsum.py new file mode 100644 index 000000000..5c5068d32 --- /dev/null +++ b/src/llamafactory/third_party/triton/cumsum.py @@ -0,0 +1,147 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import prepare_chunk_indices + + +@triton.heuristics( + {"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None} +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BLOCK_T: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, + CHUNK_SIZE: tl.constexpr = 64, +): + i_block, i_b = tl.program_id(0), tl.program_id(1) + N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE + + if IS_VARLEN: + i_s, i_block = ( + tl.load(chunk_indices + i_block * 2).to(tl.int32), + tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32), + ) + + bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) + ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) + b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32) + b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H)) + b_s = tl.trans(b_s, (1, 0, 2)) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + if HAS_SCALE: + b_o *= scale + b_o = tl.trans(b_o, (1, 0, 2)) + b_o = tl.reshape(b_o, (BLOCK_T, H)) + + tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,)) + return + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + B, T, H = g.shape + if chunk_size != 2 ** (chunk_size.bit_length() - 1): + raise ValueError(f"chunk_size must be a power of 2, chunk_size is{chunk_size}") + # We adjust the tiling strategy to prevent overflow in in backward passes and context parallel scenarios + # while maximizing UB utilization where possible. + # The tiling strategy is as follows: + # 1. BT must be greater than or equal to chunk_size. + # 2. UB estimation varies directly with H. + # 3. BT in reverse mode is smaller than in forward mode. + BT = max(chunk_size, triton.next_power_of_2((1 << 11 if reverse else 1 << 12) // H)) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B) + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BLOCK_T=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + CHUNK_SIZE=chunk_size, + ) + return g + + +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + if g.shape[0] != 1: + raise ValueError( + f"Only batch size 1 is supported when cu_seqlens are provided, current size is{g.shape[0]}" + ) + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + ) + else: + raise ValueError( + f"Unsupported input shape {g.shape}, " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/src/llamafactory/third_party/triton/solve_tril.py b/src/llamafactory/third_party/triton/solve_tril.py new file mode 100644 index 000000000..4ffa3bdc3 --- /dev/null +++ b/src/llamafactory/third_party/triton/solve_tril.py @@ -0,0 +1,272 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices + + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T", "TPP"]) +def solve_tril_16x16_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + TPP: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + pid_t, pid_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = pid_bh // H, pid_bh % H + + base_t = pid_t * TPP + + if IS_VARLEN: + i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32) + bos = tl.load(cu_seqlens + i_n).to(tl.int32) + eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T_eff = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + T_eff = T + + o_i = tl.arange(0, 16) # noqa: F841 + o_i_fp32 = tl.arange(0, 16).to(tl.float32) + m_A = o_i_fp32[:, None] > o_i_fp32[None, :] + m_I = o_i_fp32[:, None] == o_i_fp32[None, :] + + A = A + (bos * H + i_h) * BT + Ai = Ai + (bos * H + i_h) * BT + + for tpp in tl.static_range(0, TPP): + tile_t = base_t + tpp + tile_row = tile_t * 16 + + offset = (tile_t * 16) % BT + + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0)) + b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16]) + b_A_raw = desc.load([tile_row, offset]).to(tl.float32) + + b_A_neg = -b_A_raw + b_A = b_A_neg * m_A + for i in range(2, min(16, T_eff - tile_row)): + slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1]) + b_a_val = tl.reshape(slice_res, (16,), can_reorder=True) + dot_prod = tl.sum(b_a_val[:, None] * b_A, 0) + b_a_update = b_a_val + dot_prod + b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A) + b_A += m_I + + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0)) + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + else: + desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T", "TPP"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + TPP: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_64x64_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 64) + m_I = o_i[:, None] == o_i[None, :] + + A = A + (bos * H + i_h) * BT + Ai = Ai + (bos * H + i_h) * 64 + + offset = (i_t * 64) % BT + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0)) + b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64]) + desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64]) + b_A = -desc.load([i_t * 64, offset]).to(tl.float32) + + for i in range(2, min(64, T - i_t * 64)): + b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0)) + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + else: + desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@input_guard +def solve_tril( + A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float +) -> torch.Tensor: + """Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. + + Returns: + (I + A)^-1 with the same shape as A + """ # noqa: D205 + if A.shape[-1] not in [16, 32, 64]: + raise ValueError(f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}") + output_dtype = A.dtype if output_dtype is None else output_dtype + + B, T, H, BT = A.shape + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + Ai = torch.zeros_like(A, dtype=output_dtype) + + if BT == 16: + merge_fn = solve_tril_16x16_kernel + elif BT == 32: + merge_fn = merge_16x16_to_32x32_inverse_kernel + elif BT == 64: + merge_fn = solve_tril_64x64_kernel + + merge_fn[NT, B * H]( + A=A, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + USE_TMA=False, + DOT_PRECISION=FLA_TRIL_PRECISION, + ) + return Ai diff --git a/src/llamafactory/third_party/triton/utils.py b/src/llamafactory/third_party/triton/utils.py new file mode 100644 index 000000000..e817e2eb6 --- /dev/null +++ b/src/llamafactory/third_party/triton/utils.py @@ -0,0 +1,359 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import itertools +import logging +import os +import warnings +from collections.abc import Callable +from enum import Enum +from typing import Any, Optional + +import torch +import triton +import triton.language as tl +import triton.language.extra.libdevice as tldevice +import triton.runtime.driver as driver +from packaging import version + + +logger = logging.getLogger(__name__) + +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" + + +def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any: + """A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the outputs of the decorated function for the most recent + set of input tensors, up to `maxsize` entries. If the function is called again with + the same input tensors, it will return the cached result. + + When maxsize=1 (default), the behavior is identical to caching only the most recent result. + Can be used as @tensor_cache or @tensor_cache(maxsize=n). + + Args: + fn (Callable[..., torch.Tensor], optional): + The function to be decorated when used without parentheses. + maxsize (int): + Maximum number of input combinations to cache. Default is 1. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with caching. + """ + if maxsize < 1: + raise ValueError("maxsize must be at least 1") + + def _is_match(a: Any, b: Any) -> bool: + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return a is b + try: + return a == b + except Exception: + return a is b + + def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + cache: list = [] + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache): + if len(args) == len(cached_args) and len(kwargs) == len(cached_kwargs): + if all(_is_match(a, b) for a, b in zip(args, cached_args)) and all( + k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items() + ): + if i != 0: + cache.insert(0, cache.pop(i)) + return cached_result + + result = fn(*args, **kwargs) + cache.insert(0, (args, kwargs, result)) + if len(cache) > maxsize: + cache.pop() + return result + + return wrapper + + if fn is not None: + return _make_wrapper(fn) + return _make_wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache(maxsize=3) +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item() + base = (x.detach()).flatten().square().mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" + logger.info(msg) + error_rate = get_err_ratio(ref, tri) + if abs_atol <= err_atol: + return + if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)): + if error_rate > ratio: + warnings.warn(msg) + else: + assert error_rate < ratio, msg + + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + _cpu_device_warning() + return "cpu" + + +def map_triton_backend_to_torch_device() -> str: + backend = get_available_device() # 'cuda' | 'hip' | 'xpu' | 'cpu' | ... + return {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(backend, backend) + + +device = get_available_device() if get_available_device() != "hip" else "cuda" +device_torch_lib = getattr(torch, device) +device_platform = get_available_device() +is_amd = device_platform == "hip" +is_nvidia = device_platform == "cuda" +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) + +is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8 +is_tma_supported = ( + (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) + and os.environ.get("FLA_NO_USE_TMA", "0") != "1" + and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") + ) +) + +if is_nvidia and not is_tf32_supported: + # Make old card happy, since triton will use tf32 by default. + # This is a workaround for old nvidia card. + os.environ["TRITON_F32_DEFAULT"] = "ieee" + + +@functools.cache +def check_pytorch_version(version_s: str = "2.4") -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +if check_pytorch_version("2.4"): + device = "cuda" if device == "cpu" else device + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) +else: + assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0." + autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index) + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """A decorator to make sure all input tensors are contiguous and set the device based on input tensors.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +def _cpu_device_warning(): + warnings.warn(("Triton is not supported on current platform, roll back to CPU."), stacklevel=1) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) + + +if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": + exp = tldevice.fast_expf + exp2 = tldevice.exp2 + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + exp = tl.exp + exp2 = tl.math.exp2 + log = tl.log + log2 = tl.log2 + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + _cpu_device_warning() + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False + + +def get_autotune_config( + multibuffer_list: tuple = (False,), + unit_flag_list: tuple = (False,), + limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,), + limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",), + set_workspace_multibuffer_list: tuple = (2, 4), + enable_hivm_auto_cv_balance_list: tuple = (True,), + tile_mix_vector_loop_num_list: tuple = (2, 4), + tile_mix_cube_loop_num_list: tuple = (2, 4), +): + configs = [] + for ( + multibuffer, + unit_flag, + limit_auto_multi_buffer_only_for_local_buffer, + limit_auto_multi_buffer_of_local_buffer, + ) in itertools.product( + list(multibuffer_list), + list(unit_flag_list), + list(limit_auto_multi_buffer_only_for_local_buffer_list), + list(limit_auto_multi_buffer_of_local_buffer_list), + ): + base_config_dict = { + "multibuffer": multibuffer, + "unit_flag": unit_flag, + "limit_auto_multi_buffer_only_for_local_buffer": limit_auto_multi_buffer_only_for_local_buffer, + "limit_auto_multi_buffer_of_local_buffer": limit_auto_multi_buffer_of_local_buffer, + } + + if limit_auto_multi_buffer_only_for_local_buffer: + configs.append(triton.Config(base_config_dict)) + else: + for ( + set_workspace_multibuffer, + enable_hivm_auto_cv_balance, + tile_mix_vector_loop, + tile_mix_cube_loop, + ) in itertools.product( + list(set_workspace_multibuffer_list), + list(enable_hivm_auto_cv_balance_list), + list(tile_mix_vector_loop_num_list), + list(tile_mix_cube_loop_num_list), + ): + full_config_dict = base_config_dict.copy() + full_config_dict.update( + { + "set_workspace_multibuffer": set_workspace_multibuffer, + "enable_hivm_auto_cv_balance": enable_hivm_auto_cv_balance, + "tile_mix_vector_loop": tile_mix_vector_loop, + "tile_mix_cube_loop": tile_mix_cube_loop, + } + ) + configs.append(triton.Config(full_config_dict)) + return configs + + +def get_npu_properties(): + return driver.active.utils.get_device_properties(torch.npu.current_device()) diff --git a/src/llamafactory/third_party/triton/wy_fast.py b/src/llamafactory/third_party/triton/wy_fast.py new file mode 100644 index 000000000..d22c12f6e --- /dev/null +++ b/src/llamafactory/third_party/triton/wy_fast.py @@ -0,0 +1,387 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import exp, prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def prepare_wy_repr_bwd_kernel( + k, + v, + beta, + g, + A, + dw, + du, + dk, + dv, + dbeta, + dg, + cu_seqlens, + chunk_indices, + T, + B, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + NT: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + core_id = tl.program_id(0) + total_cores = tl.num_programs(0) + T_max = T + + base_chunks_per_pid = NT // total_cores + remainder_chunks = NT % total_cores + + if core_id < remainder_chunks: + chunks_this_pid = base_chunks_per_pid + 1 + start_idx = core_id * chunks_this_pid + else: + chunks_this_pid = base_chunks_per_pid + start_idx = core_id * chunks_this_pid + remainder_chunks + + for idx in range(start_idx, start_idx + chunks_this_pid): + for i_b in range(B): + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + idx * 2).to(tl.int32), + tl.load(chunk_indices + idx * 2 + 1).to(tl.int32), + ) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_t = idx + bos, eos = i_b * T, i_b * T + T + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + for i_h in range(0, H): + if IS_VARLEN: + offset = bos + i_h * T_max + else: + offset = bos * H + i_h * T_max + + p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1) + ) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_exp = tl.exp(b_g) + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + b_dg = tl.zeros([BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_dk = tl.make_block_ptr( + dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_dw = tl.make_block_ptr( + dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g)) + b_dk_beta_g = tl.dot(b_A, b_dw) + b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None] + b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1) + b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_dv = tl.make_block_ptr( + dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_du = tl.make_block_ptr( + du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta)) + b_dv_beta = tl.dot(b_A, b_du) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0) + b_dA = b_dA.to(k.dtype.element_ty) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_dk = tl.make_block_ptr( + dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_A += tl.dot(b_k_beta, tl.trans(b_k)) + b_dk_beta = tl.dot(b_dA, b_k) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA_A = b_dA * b_A + b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0) + p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + gk, + cu_seqlens, + chunk_indices, + T_tmp, + B, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + NT: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + core_id = tl.program_id(0) + total_cores = tl.num_programs(0) + T_max = T_tmp + + base_chunks_per_pid = NT // total_cores + remainder_chunks = NT % total_cores + + if core_id < remainder_chunks: + chunks_this_pid = base_chunks_per_pid + 1 + start_idx = core_id * chunks_this_pid + else: + chunks_this_pid = base_chunks_per_pid + start_idx = core_id * chunks_this_pid + remainder_chunks + + for idx in range(start_idx, start_idx + chunks_this_pid): + for i_b in range(B): + for i_h in range(0, H): + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + idx * 2).to(tl.int32), + tl.load(chunk_indices + idx * 2 + 1).to(tl.int32), + ) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + offset = bos + i_h * T_max + T = eos - bos + else: + T = T_tmp + i_t = idx + bos, eos = i_b * T, i_b * T + T + offset = bos * H + i_h * T_max + + p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + if USE_G: + b_kb *= b_g[:, None] + if USE_GK: + p_gk = tl.make_block_ptr( + gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1))) + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + BK = 128 + BV = 128 + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g = g.transpose(1, 2).contiguous() if g is not None else None + beta = beta.transpose(1, 2).contiguous() + + w = torch.empty_like(k) + u = torch.empty_like(v) + cv_kernel_num = 24 + recompute_w_u_fwd_kernel[(cv_kernel_num,)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T_tmp=T, + B=B, + H=H, + K=K, + V=V, + NT=NT, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +def prepare_wy_repr_bwd( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int = 64, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 128 + BV = 128 + beta = beta.transpose(1, 2).contiguous() + g = g.transpose(1, 2).contiguous() + + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.empty_like(beta) + dg = torch.empty_like(g) + + cv_kernel_num = 24 + prepare_wy_repr_bwd_kernel[(cv_kernel_num,)]( + k=k, + v=v, + beta=beta, + g=g, + A=A, + dw=dw, + du=du, + dk=dk, + dv=dv, + dbeta=dbeta, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + K=K, + V=V, + NT=NT, + BT=BT, + BK=BK, + BV=BV, + ) + + dbeta = dbeta.transpose(1, 2).contiguous() + dg = dg.transpose(1, 2).contiguous() + + return dk, dv, dbeta, dg + + +bwd_prepare_wy_repr = prepare_wy_repr_bwd + +fwd_recompute_w_u = recompute_w_u_fwd