mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[model] Patch GDN for NPU (#10504)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
This commit is contained in:
@@ -36,6 +36,7 @@ COPY . /app
|
|||||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||||
RUN pip uninstall -y torch torchvision torchaudio
|
RUN pip uninstall -y torch torchvision torchaudio
|
||||||
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||||
|
RUN pip install --no-cache-dir -r requirements/triton_ascend.txt
|
||||||
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|||||||
20
examples/accelerate/fsdp2_config_qwen35_moe.yaml
Normal file
20
examples/accelerate/fsdp2_config_qwen35_moe.yaml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: FSDP
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
use_cpu: false
|
||||||
51
examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
Normal file
51
examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Start FSDP2 LoRA fine-tuning on Ascend NPU
|
||||||
|
# Usage:
|
||||||
|
# accelerate launch \
|
||||||
|
# --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \
|
||||||
|
# src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
|
||||||
|
#
|
||||||
|
# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count
|
||||||
|
|
||||||
|
### model
|
||||||
|
model_name_or_path: Qwen/Qwen3.5-35B-A3B
|
||||||
|
trust_remote_code: true
|
||||||
|
use_v1_kernels: false
|
||||||
|
flash_attn: fa2
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: lora
|
||||||
|
lora_rank: 8
|
||||||
|
lora_target: all
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: alpaca_en_demo
|
||||||
|
template: qwen3_5_nothink
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 1000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
packing: false
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/Qwen3.5-35B/lora/sft
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 2000
|
||||||
|
max_steps: 2000
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 1800
|
||||||
|
resume_from_checkpoint: null
|
||||||
|
disable_gradient_checkpointing: true
|
||||||
2
requirements/triton_ascend.txt
Normal file
2
requirements/triton_ascend.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple
|
||||||
|
triton-ascend==3.2.1
|
||||||
@@ -81,6 +81,8 @@ def apply_liger_kernel(
|
|||||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
|
||||||
elif model_type == "qwen3_5":
|
elif model_type == "qwen3_5":
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
|
||||||
|
elif model_type == "qwen3_5_moe":
|
||||||
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel
|
||||||
elif model_type == "gpt_oss":
|
elif model_type == "gpt_oss":
|
||||||
try:
|
try:
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from peft import PeftModel
|
|||||||
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
|
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.modeling_utils import is_fsdp_enabled
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
|
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.misc import infer_optim_dtype
|
from ..extras.misc import infer_optim_dtype
|
||||||
@@ -84,7 +85,60 @@ def _check_fla_dependencies() -> None:
|
|||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None:
    """Swap the gated-delta-rule kernel of Qwen3.5 linear-attention layers for the vendored NPU one.

    The patch is best-effort: every unmet precondition is reported through a
    rank-0 warning and the model is left untouched. It is applied only when

    * the name of NPU device 0 contains ``"Ascend910"``,
    * the ``triton_ascend`` distribution is installed, and
    * ``model.config.architectures[0]`` is one of the architectures listed below.

    For every decoder layer that has a ``linear_attn`` attribute, the attribute
    ``linear_attn.chunk_gated_delta_rule`` is overwritten with the implementation
    from ``third_party/triton/chunk_gated_delta_rule``.

    Args:
        model: The loaded model whose linear-attention layers should be patched in place.
    """
    import importlib.metadata

    # NOTE(review): the check matches any "Ascend910*" device name while the message
    # mentions the 910B series only -- confirm which one is intended.
    if "Ascend910" not in torch.npu.get_device_name(0):
        logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.")
        return

    try:
        importlib.metadata.version("triton_ascend")
    except importlib.metadata.PackageNotFoundError:
        logger.warning_rank0(
            "triton_ascend not installed, skipping NPU GDN patch. "
            "To enable it on NPU, reinstall Triton with the Ascend build: "
            "`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. "
            "Note: triton and triton_ascend cannot coexist — triton must be uninstalled first."
        )
        return

    logger.info_rank0("triton_ascend detected for NPU compatibility.")

    # Imported lazily: the vendored kernels need triton_ascend, which was only verified above.
    from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule

    # architecture name -> (attribute path from ``model.model`` to the decoder layers, name used in logs)
    # NOTE(review): the dense path ``model.model.layers`` is kept from the original code; verify it
    # against the actual Qwen3_5ForConditionalGeneration module tree.
    supported = {
        "Qwen3_5MoeForConditionalGeneration": (("language_model", "layers"), "Qwen3.5-MoE"),
        "Qwen3_5ForConditionalGeneration": (("layers",), "Qwen3.5"),
    }

    # ``architectures`` may be missing, None or empty on hand-built configs; treat that as "unsupported"
    # instead of crashing model loading.
    architectures = getattr(model.config, "architectures", None)
    architecture = architectures[0] if architectures else None
    if architecture not in supported:
        logger.warning_rank0(f"Architecture {architecture} is not covered by the NPU GDN patch, skipping.")
        return

    attr_path, display_name = supported[architecture]
    try:
        layers = model.model
        for attr in attr_path:
            layers = getattr(layers, attr)

        for layer in layers:
            if hasattr(layer, "linear_attn"):
                layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
    except Exception as e:  # deliberate best-effort: never fail model loading because of the patch
        logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
        return

    logger.info_rank0(
        f"Replaced chunk_gated_delta_rule with NPU-compatible implementation for {display_name} model."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None:
|
||||||
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training.
|
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training.
|
||||||
|
|
||||||
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
|
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
|
||||||
@@ -421,8 +475,12 @@ def patch_model(
|
|||||||
autocast_projector_dtype(model, model_args)
|
autocast_projector_dtype(model, model_args)
|
||||||
add_z3_leaf_module(model)
|
add_z3_leaf_module(model)
|
||||||
|
|
||||||
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
|
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]:
|
||||||
patch_qwen3_5_forward(model)
|
if is_torch_npu_available():
|
||||||
|
patch_qwen3_5_forward_npu(model)
|
||||||
|
elif is_torch_cuda_available() and model_args.flash_attn == "fa2":
|
||||||
|
# this is the patch for packing/neat_packing for GPU GDN. And when setting packing, flash_attn must be fa2.
|
||||||
|
patch_qwen3_5_forward_gpu(model)
|
||||||
|
|
||||||
if not model_args.use_unsloth:
|
if not model_args.use_unsloth:
|
||||||
print_attn_implementation(model.config)
|
print_attn_implementation(model.config)
|
||||||
|
|||||||
594
src/llamafactory/third_party/triton/chunk_delta_h.py
vendored
Normal file
594
src/llamafactory/third_party/triton/chunk_delta_h.py
vendored
Normal file
@@ -0,0 +1,594 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets
|
||||||
|
|
||||||
|
|
||||||
|
CUBE_CORE_NUM = get_npu_properties()["num_aicore"]
|
||||||
|
|
||||||
|
|
||||||
|
# Boolean kernel switches are derived from which optional tensors were passed (None disables a feature).
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=get_autotune_config(multibuffer_list=(False,)),
    key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Forward chunk recurrence for one (sequence, head) pair and one BV-wide slice of V.
    #
    # For every chunk of BT tokens the kernel
    #   1. writes the state that is valid *before* the chunk into `h`,
    #   2. computes v - w @ state (optionally saved to `v_new`),
    #   3. applies the gates `g` / `gk`,
    #   4. accumulates k^T @ v into the state.
    # The K x BV state is kept as up to four 64-row tiles (b_h1..b_h4), hence K <= 256.
    #
    # Grid: program_id(0) indexes the BV block of V, program_id(1) indexes n * H + head.
    # The pointer arithmetic below addresses k/v/w/g/h/v_new as head-major contiguous
    # buffers (offset (n * H + head) * T * dim, row stride dim).
    #
    # Keep the full (padded) length / chunk count: T and NT are overwritten per sequence below.
    T_all = T
    NT_all = NT
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Packed sequences: token range and first chunk index come from cu_seqlens / chunk_offsets.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Initialize hidden states
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # Advance the base pointers to the first token / chunk handled by this program.
    # NOTE(review): `g` is advanced unconditionally in both branches, so calling the kernel with
    # g=None (USE_G False) presumably fails at compile time -- confirm whether g is meant to be optional.
    if IS_VARLEN:
        v = v + (i_h * T_all + bos) * V
        k = k + (i_h * T_all + bos) * K
        w = w + (i_h * T_all + bos) * K
        g = g + i_h * T_all + bos
        h = h + (i_h * NT_all + boh) * K * V
        if SAVE_NEW_VALUE:
            v_new_base = v_new + (i_h * T_all + bos) * V
    else:
        v = v + (i_n * H + i_h) * T * V
        k = k + (i_n * H + i_h) * T * K
        w = w + (i_n * H + i_h) * T * K
        g = g + (i_n * H + i_h) * T
        h = h + (i_n * H + i_h) * NT * K * V
        if SAVE_NEW_VALUE:
            v_new_base = v_new + (i_n * H + i_h) * T * V

    # Initial / final states are stored as one K x V matrix per (sequence, head).
    if USE_INITIAL_STATE:
        h0_ptr = h0 + i_nh * K * V
    if STORE_FINAL_STATE:
        ht_ptr = ht + i_nh * K * V

    # Load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # Main recurrence over chunks
    for i_t in range(NT):
        # Store current hidden state h_t
        p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        # Compute v_residual = v - w @ h
        # b_v first accumulates w @ h over the 64-wide K tiles, then is replaced by v - (w @ h).
        p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h4.to(b_w.dtype))

        p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v

        # The residual is saved before any gating is applied to it.
        if SAVE_NEW_VALUE:
            p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))

        # Index of the last valid token of this chunk (the final chunk may be shorter than BT).
        last_idx = min((i_t + 1) * BT, T) - 1

        # Apply output gate g
        if USE_G:
            # Mask for rows beyond the sequence end; the float32 cast before the comparison is
            # presumably a backend workaround -- TODO confirm.
            m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
            b_g_last = tl.load(g + last_idx)
            p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            # Scale each row by exp(g_last - g_row) and zero out padded rows.
            b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None]
            # Scale the carried state by exp(g_last).
            b_g_last_exp = tl.exp(b_g_last)
            b_h1 *= b_g_last_exp
            if K > 64:
                b_h2 *= b_g_last_exp
            if K > 128:
                b_h3 *= b_g_last_exp
            if K > 192:
                b_h4 *= b_g_last_exp

        # Apply key gate gk
        # NOTE(review): o_k1 is cast to float32 but then used as a pointer offset, and gk_base_ptr
        # uses the per-sequence T even when IS_VARLEN -- this branch looks untested; verify before
        # passing a non-None gk.
        if USE_GK:
            o_k1 = tl.arange(0, 64).to(tl.float32)
            gk_base_ptr = gk + (i_n * H + i_h) * T * K
            b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0)
            b_h1 *= tl.exp(b_gk_last1)[:, None]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0)
                b_h2 *= tl.exp(b_gk_last2)[:, None]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0)
                b_h3 *= tl.exp(b_gk_last3)[:, None]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0)
                b_h4 *= tl.exp(b_gk_last4)[:, None]

        # Cast the (gated) residual to k's element type so both tl.dot operands match.
        b_v = b_v.to(k.dtype.element_ty)

        # Update hidden state: h += k @ v
        # k is read through a transposed (K, T) view so that each tile is 64 x BT.
        p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        if USE_GK:
            p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
            b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
        b_h1 += tl.dot(b_k, b_v)

        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
                b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
            b_h2 += tl.dot(b_k, b_v)

        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
                b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
            b_h3 += tl.dot(b_k, b_v)

        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
                b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
            b_h4 += tl.dot(b_k, b_v)

    # Store final state
    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
    save_new_value: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Launch the forward chunk-state kernel and return per-chunk states.

    The inputs are indexed as ``[B, T, H, dim]`` (``k.shape`` is unpacked as ``B, T, H, K`` and
    ``V = u.shape[-1]``). They are copied into head-major ``[B, H, T, dim]`` buffers for the
    kernel, and the outputs are permuted back before returning.

    Args:
        k: Keys, unpacked as ``[B, T, H, K]`` with ``K <= 256``.
        w: Tensor permuted like ``k``; the kernel multiplies it with the running state.
        u: Values, permuted like ``k``; the kernel subtracts ``w @ state`` from it.
        g: Optional gate, permuted from ``[B, T, H]`` to ``[B, H, T]`` before the launch.
        gk: Optional per-key gate, forwarded to the kernel unchanged.
        initial_state: Optional initial state, forwarded to the kernel as ``h0``.
        output_final_state: Whether to allocate and return the final ``[N, H, K, V]`` float32 state.
        chunk_size: Number of tokens per chunk (``BT`` in the kernel).
        save_new_value: Whether to allocate and return the residual values ``v_new``.
        cu_seqlens: Optional cumulative sequence lengths for packed variable-length input.

    Returns:
        ``(h, v_new, final_state)`` where ``h`` is ``[B, NT, H, K, V]``, ``v_new`` has the layout
        of ``u`` (``None`` if ``save_new_value`` is false) and ``final_state`` is ``None`` unless
        ``output_final_state`` is true.
    """
    B, T, H, K, V = *k.shape, u.shape[-1]
    BT = chunk_size

    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        # Chunk indices are only needed to count the chunks of a packed batch.
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    # Allocate the state buffer directly in the head-major layout the kernel writes to,
    # instead of allocating token-major and copying uninitialized memory.
    h = k.new_empty(B, H, NT, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    BV = 128

    k = k.permute(0, 2, 1, 3).contiguous()
    w = w.permute(0, 2, 1, 3).contiguous()
    u = u.permute(0, 2, 1, 3).contiguous()
    # `u` is head-major and contiguous at this point, so empty_like yields the matching buffer.
    v_new = torch.empty_like(u) if save_new_value else None
    # The kernel's USE_G switch keys off `g is None`, so do not dereference a missing gate here.
    # NOTE(review): the kernel currently offsets `g` unconditionally; confirm g=None is supported there.
    if g is not None:
        g = g.permute(0, 2, 1).contiguous()
    # NOTE(review): `gk` is passed through without the head-major permute applied to the other
    # inputs -- verify the expected layout before using a non-None gk.

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BV=BV,
        NT=NT,
    )

    # Back to the token-major layout expected by the callers.
    h = h.permute(0, 2, 1, 3, 4).contiguous()
    if v_new is not None:
        v_new = v_new.permute(0, 2, 1, 3).contiguous()
    return h, v_new, final_state
|
||||||
|
|
||||||
|
|
||||||
|
# Boolean kernel switches are derived from which optional tensors were passed (None disables a feature).
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
        "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=get_autotune_config(multibuffer_list=(True, False)),
    key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
    q,
    k,
    w,
    g,
    gk,
    dht,
    dh0,
    do,
    dh,
    dv,
    dv2,
    cu_seqlens,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Backward chunk recurrence for one (sequence, head) pair and one BV-wide slice of V.
    #
    # Chunks are visited from last to first. For every chunk the kernel
    #   1. writes the current state gradient into `dh`,
    #   2. computes dv2 = gate * (k @ dh) + dv,
    #   3. updates dh <- gate * dh + scale * q @ do - w @ dv2.
    # The K x BV state gradient is kept as up to four 64-row tiles (b_dh1..b_dh4), hence K <= 256.
    #
    # Grid: program_id(0) indexes the BV block of V, program_id(1) indexes n * H + head.
    # q/k/w/do/dv/dv2/dh are addressed token-major (offset (token * H + head) * dim, row stride
    # H * dim), whereas `g` is addressed head-major (see bos_g below).
    #
    # Keep the full (padded) length: T is overwritten per sequence in the varlen branch.
    T_all = T
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Packed sequences: token range and first chunk index come from cu_seqlens / chunk_offsets.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # State-gradient accumulators, one 64 x BV tile per 64 rows of K.
    b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([64, BV], dtype=tl.float32)

    # Advance the base pointers to the first token / chunk of this (sequence, head).
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    w += (bos * H + i_h) * K
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V
    dv2 += (bos * H + i_h) * V
    dh += (boh * H + i_h) * K * V
    if USE_GK:
        gk += (bos * H + i_h) * K

    # Initial-state / final-state gradients are stored as one K x V matrix per (sequence, head).
    if USE_INITIAL_STATE:
        dh0 += i_nh * K * V
    if USE_FINAL_STATE_GRADIENT:
        dht += i_nh * K * V

    # Distance (in elements) between consecutive tokens / chunks of the same head.
    stride_v = H * V
    stride_h = H * K * V
    stride_k = H * K

    # Seed the recurrence with the gradient flowing into the final state, if any.
    if USE_FINAL_STATE_GRADIENT:
        p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
        if K > 64:
            p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
        if K > 128:
            p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
        if K > 192:
            p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))

    # Walk the chunks in reverse order.
    for i_t in range(NT - 1, -1, -1):
        # Store the state gradient that applies to this chunk before it is updated.
        p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))

        # Index of the last valid token of this chunk (the final chunk may be shorter than BT).
        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            # `g` is indexed head-major: one contiguous run of tokens per head.
            if IS_VARLEN:
                bos_g = i_h * T_all + bos
            else:
                bos_g = (i_n * H + i_h) * T_all
            bg_last = tl.load(g + bos_g + last_idx)
            bg_last_exp = tl.exp(bg_last)
            p_g = tl.make_block_ptr(
                base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_g_exp = tl.exp(b_g)

        p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        b_do = tl.load(p_do, boundary_check=(0, 1))

        # Update dv
        # b_dv accumulates k @ dh over the 64-wide K tiles; the gk "last" rows are loaded here
        # because they are reused in the dh update below.
        p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        if USE_GK:
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0)
        b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))

        if K > 64:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0)
            b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))

        if K > 128:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0)
            b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))

        if K > 192:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0)
            b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))

        if USE_G:
            # Scale each row by exp(g_last - g_row) and zero out rows beyond the sequence end; the
            # float32 cast before the comparison is presumably a backend workaround -- TODO confirm.
            m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
            b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None]
        # Add the incoming dv contribution and write the result to the separate dv2 buffer.
        b_dv += tl.load(p_dv, boundary_check=(0, 1))

        tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        # Update dh
        # q and w are read through transposed (K, T) views so that each tile is 64 x BT.
        p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        if USE_G:
            b_dh1 *= bg_last_exp
            b_q = b_q * b_g_exp[None, :]
        if USE_GK:
            b_dh1 *= tl.exp(b_gk_last1[:, None])
        # `b_q.to(b_q.dtype)` is a no-op cast; `b_do` is cast to whatever dtype b_q has after gating.
        b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
        if K > 64:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if USE_G:
                b_dh2 *= bg_last_exp
                b_q = b_q * b_g_exp[None, :]
            if USE_GK:
                b_dh2 *= tl.exp(b_gk_last2[:, None])
            b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
        if K > 128:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if USE_G:
                b_dh3 *= bg_last_exp
                b_q = b_q * b_g_exp[None, :]
            if USE_GK:
                b_dh3 *= tl.exp(b_gk_last3[:, None])
            b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
        if K > 192:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if USE_G:
                b_dh4 *= bg_last_exp
                b_q = b_q * b_g_exp[None, :]
            if USE_GK:
                b_dh4 *= tl.exp(b_gk_last4[:, None])
            b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))

    # After the first chunk has been processed, the accumulators hold the gradient of the initial state.
    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_gated_delta_rule_bwd_dhu(
    q: torch.Tensor,
    k: torch.Tensor,
    w: torch.Tensor,
    do: torch.Tensor,
    dv: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    h0: torch.Tensor | None = None,
    dht: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Launch the backward kernel that produces the per-chunk state gradients.

    Args:
        q: Tensor whose shape is unpacked as ``(B, T, H, K)``.
        k: Forwarded to the kernel unchanged.
        w: Forwarded to the kernel unchanged.
        do: Output gradient; its last dimension defines ``V``.
        dv: Incoming value gradient; ``dv2`` is allocated with the same shape and dtype.
        g: Optional gate tensor. When given it is permuted with ``(0, 2, 1)`` and made
            contiguous before the launch (presumably ``[B, T, H] -> [B, H, T]`` -- confirm
            against the kernel).
        gk: Optional tensor forwarded to the kernel unchanged.
        h0: Optional initial state. When given, a float32 ``dh0`` buffer of the same shape
            is allocated for its gradient.
        dht: Optional gradient of the final state, forwarded to the kernel.
        scale: Scaling factor forwarded to the kernel.
        cu_seqlens: Optional cumulative sequence lengths; switches to the varlen path.
        chunk_size: Only used to build ``chunk_indices`` when they are not supplied.
        chunk_indices: Optional precomputed chunk indices for the varlen path.
        use_exp2: Accepted for interface compatibility; not used in this function.

    Returns:
        ``(dh, dh0, dv2)`` where ``dh`` has shape ``(B, NT, H, K, V)`` and ``dh0`` is
        ``None`` when ``h0`` is ``None``.
    """
    B, T, H, K, V = *q.shape, do.shape[-1]
    # The kernel is launched with a fixed time-block size, independent of `chunk_size`.
    BT = 64
    assert K <= 256, "current kernel does not support head dimension being larger than 256."

    # NOTE(review): chunk_indices are derived from `chunk_size` while chunk_offsets use BT;
    # the two only agree when chunk_size == 64 -- confirm callers never pass another value.
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

    dh = q.new_empty(B, NT, H, K, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.empty_like(dv)

    BV = 128

    # `g` is optional in the signature, so only re-layout it when it was actually passed;
    # calling .permute on None would raise AttributeError before the kernel is reached.
    if g is not None:
        g = g.permute(0, 2, 1).contiguous()

    chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)](
        q=q,
        k=k,
        w=w,
        g=g,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dh, dh0, dv2
|
||||||
347
src/llamafactory/third_party/triton/chunk_gated_delta_rule.py
vendored
Normal file
347
src/llamafactory/third_party/triton/chunk_gated_delta_rule.py
vendored
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
|
||||||
|
from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
|
||||||
|
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||||
|
from .cumsum import chunk_local_cumsum
|
||||||
|
from .solve_tril import solve_tril
|
||||||
|
from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
|
||||||
|
from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
):
    """Run the forward pipeline of the chunked gated delta rule.

    The steps are: chunk-local cumulative sum of ``g``, the scaled ``k k^T`` matrix and its
    triangular solve, the WY representation ``(w, u)``, the chunk states ``h``, and finally
    the output ``o``.

    Returns:
        ``(g_cumsum, o, A, final_state)`` -- the cumulated gate, the output, the solved
        matrix ``A`` and whatever final state the state kernel returned.
    """
    g_cumsum = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False)

    # Obtain the WY representation; `u` plays the role of the new `v`.
    kkt = chunk_scaled_dot_kkt_fwd(
        k=k,
        g=g_cumsum,
        beta=beta,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        output_dtype=torch.float32,
    )
    A = solve_tril(A=kkt, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g_cumsum, cu_seqlens=cu_seqlens)

    # Chunk-level recurrent states plus the corrected values consumed by the output kernel.
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g_cumsum,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
    )

    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g_cumsum,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    return g_cumsum, o, A, final_state
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
):
    """Run the backward pipeline of the chunked gated delta rule.

    ``w``, ``u``, ``h`` and ``v_new`` are recomputed from the saved inputs rather than
    being stored by the forward pass.

    Returns:
        ``(dq, dk, dv, db, dg, dh0)``.

    Raises:
        ValueError: if the accumulated gate gradient is not float32.
    """
    # Recompute the forward intermediates needed by the gradient kernels.
    w, u = recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g, cu_seqlens=cu_seqlens)
    h, v_new, _ = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=False,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )

    # Chunk-local value gradient, then refined together with the state gradients.
    dv_local = chunk_bwd_dv_local(
        q=q,
        k=k,
        g=g,
        do=do,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dh, dh0, du = chunk_gated_delta_rule_bwd_dhu(
        q=q,
        k=k,
        w=w,
        g=g,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv_local,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )

    dq, dk, dw, dg = chunk_bwd_dqkwg(
        q=q,
        k=k,
        v=v_new,
        w=w,
        g=g,
        h=h,
        dv=du,
        do=do,
        dh=dh,
        chunk_size=chunk_size,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )

    # Gradients flowing back through the WY representation.
    dk_wy, dv, db, dg_wy = prepare_wy_repr_bwd(
        k=k,
        v=v,
        beta=beta,
        g=g,
        A=A,
        dw=dw,
        du=du,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dk.add_(dk_wy)
    dg.add_(dg_wy)

    if dg.dtype != torch.float32:
        raise ValueError(f"dg current type is {dg.dtype} , should be float32")

    # The forward pass cumulated `g` per chunk, so its gradient is a reversed cumsum.
    dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False)
    return dq, dk, dv, db, dg, dh0
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper tying the chunked gated delta rule forward and backward together."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
        chunk_size: int = 64,
    ):
        """Compute the output and final state, saving what backward needs on ``ctx``."""
        # No in-kernel normalisation statistics are produced here; the two slots are kept
        # as None so the saved-tensor layout stays the same.
        q_rstd = k_rstd = None

        g_cumsum, out, A, final_state = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size,
        )

        # The cumulated gate (not the raw input gate) is what gets saved for backward.
        ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g_cumsum, beta, A, initial_state, cu_seqlens)
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        ctx.chunk_size = chunk_size
        return out.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
        """Return one gradient per forward input; non-tensor arguments get ``None``."""
        q, _q_rstd, k, _k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors

        grads = chunk_gated_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            A=A,
            scale=ctx.scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            cu_seqlens=cu_seqlens,
            chunk_size=ctx.chunk_size,
        )
        dq, dk, dv, db, dg, dh0 = grads

        # Order mirrors forward's inputs: q, k, v, g, beta, scale, initial_state,
        # output_final_state, cu_seqlens, use_qk_l2norm_in_kernel, chunk_size.
        return (
            dq.to(q),
            dk.to(k),
            dv.to(v),
            dg.to(g),
            db.to(beta),
            None,
            dh0,
            None,
            None,
            None,
            None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compiler.disable
|
||||||
|
def chunk_gated_delta_rule(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
g: torch.Tensor,
|
||||||
|
beta: torch.Tensor,
|
||||||
|
scale: float = None,
|
||||||
|
initial_state: torch.Tensor = None,
|
||||||
|
output_final_state: bool = False,
|
||||||
|
use_qk_l2norm_in_kernel: bool = False,
|
||||||
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||||
|
chunk_size: int = 64,
|
||||||
|
head_first: bool = False,
|
||||||
|
):
|
||||||
|
r"""Args:
|
||||||
|
q (torch.Tensor):
|
||||||
|
queries of shape `[B, T, H, K]`.
|
||||||
|
k (torch.Tensor):
|
||||||
|
keys of shape `[B, T, H, K]`.
|
||||||
|
v (torch.Tensor):
|
||||||
|
values of shape `[B, T, H, V]`.
|
||||||
|
g (torch.Tensor):
|
||||||
|
(forget) gating tensor (in log space!) of shape `[B, T, H]`.
|
||||||
|
beta (torch.Tensor):
|
||||||
|
betas of shape `[B, T, H]`.
|
||||||
|
scale (Optional[float]):
|
||||||
|
Scale factor for the RetNet attention scores.
|
||||||
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||||
|
initial_state (Optional[torch.Tensor]):
|
||||||
|
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||||
|
For equal-length input sequences, `N` equals the batch size `B`.
|
||||||
|
Default: `None`.
|
||||||
|
output_final_state (Optional[bool]):
|
||||||
|
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||||
|
use_qk_l2norm_in_kernel (bool):
|
||||||
|
Whether to apply L2norm to the q/k tensor internally. Default: `False`.
|
||||||
|
cu_seqlens (torch.LongTensor):
|
||||||
|
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||||
|
consistent with the FlashAttention API.
|
||||||
|
head_first (Optional[bool]):
|
||||||
|
Whether the inputs are in the head-first format. Default: `False`.
|
||||||
|
This argument has been deprecated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
o (torch.Tensor):
|
||||||
|
Outputs of shape `[B, T, H, V]`.
|
||||||
|
final_state (torch.Tensor):
|
||||||
|
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
>>> import torch
|
||||||
|
>>> import torch.nn.functional as F
|
||||||
|
>>> from einops import rearrange
|
||||||
|
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||||
|
# inputs with equal lengths
|
||||||
|
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||||
|
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||||
|
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||||
|
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||||
|
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||||
|
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||||
|
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||||
|
>>> o, ht = chunk_gated_delta_rule(
|
||||||
|
q, k, v, g, beta,
|
||||||
|
initial_state=h0,
|
||||||
|
output_final_state=True
|
||||||
|
)
|
||||||
|
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||||
|
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||||
|
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||||
|
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||||
|
>>> o, ht = chunk_gated_delta_rule(
|
||||||
|
q, k, v, g, beta,
|
||||||
|
initial_state=h0,
|
||||||
|
output_final_state=True,
|
||||||
|
cu_seqlens=cu_seqlens
|
||||||
|
)
|
||||||
|
""" # noqa: D205
|
||||||
|
if q.dtype != k.dtype or k.dtype != v.dtype:
|
||||||
|
raise ValueError(
|
||||||
|
f"q current type is {q.dtype} , k current type is {k.dtype} ,v current type is {v.dtype} , they should are equal"
|
||||||
|
)
|
||||||
|
if q.dtype == torch.float32:
|
||||||
|
raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.")
|
||||||
|
if len(beta.shape) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||||
|
)
|
||||||
|
|
||||||
|
if head_first:
|
||||||
|
warnings.warn(
|
||||||
|
"head_first is deprecated and will be removed in a future version. "
|
||||||
|
"Please use head_first=False for now instead."
|
||||||
|
)
|
||||||
|
if not head_first and q.shape[1] < q.shape[2]:
|
||||||
|
warnings.warn(
|
||||||
|
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||||
|
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||||
|
"when head_first=False was specified. "
|
||||||
|
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
|
||||||
|
)
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
if q.shape[0] != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||||
|
f"Please flatten variable-length inputs before processing."
|
||||||
|
)
|
||||||
|
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||||
|
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||||
|
)
|
||||||
|
if scale is None:
|
||||||
|
scale = k.shape[-1] ** -0.5
|
||||||
|
|
||||||
|
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
|
||||||
|
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
||||||
|
original_dtype = x.dtype
|
||||||
|
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||||
|
# Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype
|
||||||
|
return (x * inv_norm).to(original_dtype)
|
||||||
|
|
||||||
|
if use_qk_l2norm_in_kernel:
|
||||||
|
q = l2norm(q, dim=-1, eps=1e-6)
|
||||||
|
k = l2norm(k, dim=-1, eps=1e-6)
|
||||||
|
|
||||||
|
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||||
|
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size
|
||||||
|
)
|
||||||
|
return o, final_state
|
||||||
617
src/llamafactory/third_party/triton/chunk_o.py
vendored
Normal file
617
src/llamafactory/third_party/triton/chunk_o.py
vendored
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
        "USE_DW": lambda args: args["dw"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dqkwg(
    q,
    k,
    v,
    h,
    g,
    g_gamma,
    do,
    dh,
    dq,
    dk,
    dg,
    w,
    dv,
    dw,
    cu_seqlens,
    chunk_indices,
    scale,
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    USE_DW: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    gdiff,
):
    # Backward kernel for one time chunk: accumulates and stores dq and dk, plus dw when
    # USE_DW is set and dg when USE_G is set. program_id(0) selects the chunk and
    # program_id(1) the batch entry; key blocks (i_k) and heads (i_h) are looped over
    # inside the program.
    i_t, i_b = tl.program_id(0), tl.program_id(1)
    # Keep the launch-time T: it is the row pitch of the head-major g / dg buffers below,
    # while T itself is overwritten with the sequence length in the varlen case.
    T_max = T
    if IS_VARLEN:
        # chunk_indices stores (sequence index, chunk index within the sequence) pairs.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        total = B * T_max
        T = eos - bos
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        total = B * T_max

    NK = tl.cdiv(K, BK)
    for i_k in range(NK):
        if USE_G:
            # dg holds one slab of total * H elements per key block.
            dg_k = dg + i_k * total * H

        for i_h in range(H):
            # Per-head base pointers. q/k/v/do and their gradients are addressed with a
            # row stride of H * K (or H * V); h/dh are addressed per (global chunk, head).
            v_h = v + (bos * H + i_h) * V
            do_h = do + (bos * H + i_h) * V
            h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V
            dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V
            q_h = q + (bos * H + i_h) * K
            k_h = k + (bos * H + i_h) * K
            dq_h = dq + (bos * H + i_h) * K
            dk_h = dk + (bos * H + i_h) * K

            if USE_DW:
                w_h = w + (bos * H + i_h) * K  # noqa: F841
                dw_h = dw + (bos * H + i_h) * K
                dv_h = dv + (bos * H + i_h) * V

            if USE_G:
                # g / dg are addressed head-major with T_max elements per (batch, head) row.
                if IS_VARLEN:
                    dg_h = dg_k + i_h * T_max + bos
                    g_h = g + i_h * T_max + bos
                else:
                    dg_h = dg_k + (i_b * H + i_h) * T_max
                    g_h = g + (i_b * H + i_h) * T_max
                b_dg_last = tl.zeros(
                    [
                        1,
                    ],
                    dtype=tl.float32,
                )

            if USE_G_GAMMA:
                # Per-head scalar decay: b_g grows linearly with the position in the chunk.
                b_gamma = tl.load(g_gamma + i_h)
                b_g = b_gamma * (tl.arange(0, BT) + 1)
                b_g_last = b_gamma * min(BT, T - i_t * BT)

            b_dq = tl.zeros([BT, BK], dtype=tl.float32)
            b_dk = tl.zeros([BT, BK], dtype=tl.float32)
            b_ds = tl.zeros([BT, BT], dtype=tl.float32)
            b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None

            # Accumulate over value blocks.
            for i_v in range(tl.cdiv(V, BV)):
                p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                # h and dh are read as [BV, BK] blocks, i.e. transposed w.r.t. their (K, V) storage.
                p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
                p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))

                b_v = tl.load(p_v, boundary_check=(0, 1))
                b_do = tl.load(p_do, boundary_check=(0, 1))
                b_h = tl.load(p_h, boundary_check=(0, 1))
                b_dh = tl.load(p_dh, boundary_check=(0, 1))

                if USE_G:
                    b_dg_last += tl.sum(b_h * b_dh)

                b_ds += tl.dot(b_do, tl.trans(b_v))
                b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
                b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))

                if USE_DW:
                    p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                    b_dv = tl.load(p_dv, boundary_check=(0, 1))
                    b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

            if USE_DW:
                # dw is stored with a negative sign.
                p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
                tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))

            tl.debug_barrier()

            p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))

            p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

            # m_t marks in-range time steps; m_A is the lower-triangular (row >= col) mask
            # restricted to in-range rows and columns.
            o_t = i_t * BT + tl.arange(0, BT)
            m_t = o_t < T
            m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)

            if USE_G:
                b_dg = tl.zeros(
                    [
                        BT,
                    ],
                    dtype=tl.float32,
                )
                p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
                b_g = tl.load(p_g, boundary_check=(0,))
                # Gate value at the last valid position of this chunk.
                b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1)
                b_dg_last *= tl.exp(b_g_last)

                b_dq = b_dq * tl.exp(b_g)[:, None] * scale
                b_dg += tl.sum(b_dq * b_q, axis=1)

                b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None]
                b_dg -= tl.sum(b_k * b_dk, axis=1)
                b_dg_last += tl.sum(b_dk * b_k)

                if IS_VARLEN:
                    b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
                else:
                    # Fixed-length path: the decay factors are read from `gdiff`, indexed as
                    # one BT x BT tile per (batch, head, chunk), instead of being recomputed.
                    p_gdiff = tl.make_block_ptr(
                        gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT,
                        (BT, BT),
                        (BT, 1),
                        (0, 0),
                        (BT, BT),
                        (1, 0),
                    )
                    gdiff_ = tl.load(p_gdiff)
                    b_ds = b_ds * gdiff_ * scale

                b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
                b_dg += tl.sum(b_ds2, axis=1)
                b_dg -= tl.sum(b_ds2, axis=0)

                b_ds = b_ds.to(b_k.dtype)
                b_dq += tl.dot(b_ds, b_k)
                b_dk += tl.dot(tl.trans(b_ds), b_q)
                p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,))

                # Add the accumulated b_dg_last onto the last valid position of the chunk.
                last_index_local = min(BT, T - i_t * BT) - 1
                if last_index_local >= 0:
                    is_last_mask = tl.arange(0, BT) == last_index_local
                    b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg)
                else:
                    pass

                tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))

            elif USE_G_GAMMA:
                # Same structure as the USE_G branch but without a gate gradient.
                b_dq = b_dq * exp(b_g)[:, None] * scale
                b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
                b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
                b_ds = b_ds.to(b_k.dtype)
                b_dq += tl.dot(b_ds, b_k)
                b_dk += tl.dot(tl.trans(b_ds), b_q)
                tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

            else:
                # No gating: only the causal mask and the scale are applied.
                b_ds = tl.where(m_A, b_ds, 0)
                b_ds = b_ds.to(b_k.dtype)
                b_dq += tl.dot(b_ds, b_k)
                b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
                b_dq *= scale
                tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dv_local(
    q,
    k,
    g,
    g_gamma,
    do,
    dv,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Chunk-local part of the value gradient: for each head, builds the masked and scaled
    # [BT, BT] matrix k @ q^T of one chunk and stores dv = A @ do for every value block.
    # program_id(0) selects the chunk, program_id(1) the batch entry.
    i_t, i_b = tl.program_id(0), tl.program_id(1)
    # Launch-time T, used as the row pitch of the head-major g buffer.
    T_max = T

    if IS_VARLEN:
        # chunk_indices stores (sequence index, chunk index within the sequence) pairs.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    for i_h in range(H):
        offset_kh = (bos * H + i_h) * K
        offset_vh = (bos * H + i_h) * V

        b_A = tl.zeros([BT, BT], dtype=tl.float32)
        for i_k in range(tl.cdiv(K, BK)):
            p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            # q is read as a [BK, BT] block, i.e. already transposed for the dot below.
            p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_A += tl.dot(b_k, b_q)

        if USE_G:
            # g is addressed head-major with T_max elements per (batch, head) row.
            if IS_VARLEN:
                offset_g = i_h * T_max + bos
            else:
                offset_g = i_b * H * T_max + i_h * T_max

            p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))

        # NOTE(review): b_g built from g_gamma is only consumed below when USE_G is also
        # set; with g_gamma alone the decay is not applied -- confirm this is intended.
        if USE_G_GAMMA:
            b_gamma = tl.load(g_gamma + i_h)
            b_g = b_gamma * (tl.arange(0, BT) + 1)

        # Upper-triangular (row <= col) mask restricted to in-range time steps.
        o_t = i_t * BT + tl.arange(0, BT)
        m_t = o_t < T
        m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)

        if USE_G:
            b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
        else:
            b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)

        for i_v in range(tl.cdiv(V, BV)):
            p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            b_do = tl.load(p_do, boundary_check=(0, 1))
            b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
            tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    g_gamma,
    o,
    cu_seqlens,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    N: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Forward output kernel: o = (q @ h + causal(q @ k^T) @ v) * scale per chunk, with the
    # gate decay applied when USE_G is set. Value blocks, sequences and heads are looped
    # over inside the program; only the chunks of a sequence are split across the programs
    # of grid axis 0.
    # Launch-time T, used as the row pitch of the head-major g buffer.
    T_max = T
    for i_v in range(tl.cdiv(V, BV)):
        for i_n in range(N):
            if IS_VARLEN:
                bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
                T = eos - bos
                NT = tl.cdiv(T, BT)
                # First global chunk index of this sequence.
                boh = tl.load(chunk_offsets + i_n).to(tl.int64)
            else:
                bos, eos = i_n * T, i_n * T + T
                NT = tl.cdiv(T, BT)
                boh = i_n * NT

            # Split the NT chunks into contiguous ranges, one per program; the first
            # `remainder` programs take one extra chunk.
            core_id = tl.program_id(0)
            total_cores = tl.num_programs(0)
            base_chunks_per_pid = NT // total_cores
            remainder = NT % total_cores

            if core_id < remainder:
                chunks_this_pid = base_chunks_per_pid + 1
                start_idx = core_id * chunks_this_pid
            else:
                chunks_this_pid = base_chunks_per_pid
                start_idx = core_id * base_chunks_per_pid + remainder

            # offset calculation
            for i_h in range(0, H):
                # q/k are indexed with Hg heads: each group of H // Hg value heads maps to
                # the same q/k head.
                q_offset = (bos * Hg + i_h // (H // Hg)) * K
                k_offset = (bos * Hg + i_h // (H // Hg)) * K
                v_offset = (bos * H + i_h) * V
                o_offset = (bos * H + i_h) * V

                for i_t in range(start_idx, start_idx + chunks_this_pid):
                    i_tg = boh + i_t
                    h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
                    b_o = tl.zeros([BT, BV], dtype=tl.float32)
                    b_A = tl.zeros([BT, BT], dtype=tl.float32)
                    for i_k in range(tl.cdiv(K, BK)):
                        p_q = tl.make_block_ptr(
                            q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                        )
                        # k is read as a [BK, BT] block, i.e. already transposed.
                        p_k = tl.make_block_ptr(
                            k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
                        )
                        p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
                        b_q = tl.load(p_q, boundary_check=(0, 1))
                        b_k = tl.load(p_k, boundary_check=(0, 1))
                        b_h = tl.load(p_h, boundary_check=(0, 1))

                        # [BT, BK] @ [BK, BV] -> [BT, BV]
                        b_o += tl.dot(b_q, b_h)
                        # [BT, BK] @ [BK, BT] -> [BT, BT]
                        b_A += tl.dot(b_q, b_k)

                    if USE_G:
                        # g is addressed head-major with T_max elements per (batch, head) row.
                        if IS_VARLEN:
                            p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
                        else:
                            p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
                        b_g = tl.load(p_g, boundary_check=(0,))
                        b_o = b_o * exp(b_g)[:, None]
                        b_A = b_A * exp(b_g[:, None] - b_g[None, :])
                    # NOTE(review): b_gamma / b_g computed here are not used afterwards in
                    # this kernel, so g_gamma has no effect on the output -- confirm whether
                    # b_o / b_A should be decayed as in the USE_G branch.
                    if USE_G_GAMMA:
                        b_gamma = tl.load(g_gamma + i_h)
                        b_g = b_gamma * (tl.arange(0, BT) + 1)

                    # Keep only the lower triangle (row >= col) of the intra-chunk scores.
                    o_i = tl.arange(0, BT)
                    m_A = o_i[:, None] >= o_i[None, :]
                    b_A = tl.where(m_A, b_A, 0)

                    p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                    p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                    b_v = tl.load(p_v, boundary_check=(0, 1))

                    # to fix mma -> mma layout conversion
                    # already solved by triton v3.2 or higher
                    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
                    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_bwd_dqkwg(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    g_gamma: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    w: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Launch the chunked backward kernel that produces ``dq``, ``dk``, ``dw`` and ``dg``.

    Args:
        q: Query tensor; ``dq`` is allocated with the same shape and dtype.
        k: Key tensor of shape ``[B, T, H, K]``; ``dk`` is allocated like it.
        v: Value tensor; only its last dimension ``V`` is read on the host side.
        do: Gradient of the output, forwarded to the kernel.
        h: Chunk states, forwarded to the kernel.
        dh: Gradient of the chunk states, forwarded to the kernel.
        g: Optional gate tensor laid out ``[B, T, H]``. It is transposed to ``[B, H, T]``
            (contiguous) before being handed to the kernel.
        g_gamma: Optional per-head decay, forwarded to the kernel.
        dv: Optional value gradient, forwarded to the kernel.
        w: Optional tensor; when given, ``dw`` is allocated like it.
        cu_seqlens: Optional cumulative sequence lengths for variable-length inputs.
        chunk_size: Upper bound for the chunk length ``BT``.
        scale: Scaling factor forwarded to the kernel.

    Returns:
        ``(dq, dk, dw, dg)``. ``dw`` is ``None`` when ``w`` is ``None`` and ``dg`` is ``None``
        when ``g`` is ``None``. ``dg`` is returned in the ``[B, T, H]`` layout.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    BK = 128 if cu_seqlens is None else 64
    BV = 64
    NK = triton.cdiv(K, BK)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w) if w is not None else None

    dg = None
    g_diff = None
    if g is not None:
        # The kernel consumes the gate head-major: [B, T, H] -> [B, H, T].
        g = g.transpose(1, 2).contiguous()
        # One partial gradient per K tile; reduced over the tile axis after the launch.
        dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device)

        if cu_seqlens is None:
            # Pre-compute exp(g_i - g_j) for every (i, j) pair inside a chunk on the host.
            pad = NT * BT - T
            g_chunks = g
            if pad > 0:
                # Zero-pad the time axis so that it splits evenly into NT chunks of BT.
                g_chunks = torch.cat((g, torch.zeros(B, H, pad, device=g.device)), dim=-1)
            g_chunks = g_chunks.reshape(B, H, NT, BT)
            g_diff = (g_chunks.unsqueeze(-1) - g_chunks.unsqueeze(-2)).clamp(-60, 60).exp()
            # Keep the causal part only (lower triangle including the diagonal).
            g_diff *= torch.tril(torch.ones(BT, BT, device=g.device))
            if pad > 0:
                # Drop rows/columns of the last chunk that fall into the padding.
                in_range = (NT - 1) * BT + torch.arange(BT, device=g.device) < T
                g_diff[:, :, -1] *= in_range[:, None] & in_range

    grid = (NT, B)
    chunk_bwd_kernel_dqkwg[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        g_gamma=g_gamma,
        do=do,
        dh=dh,
        dv=dv,
        w=w,
        dw=dw,
        dq=dq,
        dk=dk,
        dg=dg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        gdiff=g_diff,
    )

    if dg is not None:
        # Reduce the per-K-tile partials and go back to the caller's [B, T, H] layout.
        dg = dg.sum(0).transpose(1, 2).contiguous()
    return dq, dk, dw, dg
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_bwd_dv_local(
    q: torch.Tensor,
    k: torch.Tensor,
    do: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    g_gamma: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch the kernel computing the chunk-local part of the value gradient.

    Args:
        q: Query tensor, forwarded to the kernel.
        k: Key tensor of shape ``[B, T, H, K]``.
        do: Gradient of the output; ``dv`` is allocated with the same shape and dtype.
        g: Optional gate tensor laid out ``[B, T, H]``; transposed to ``[B, H, T]``
            (contiguous) before the launch. ``None`` is passed through unchanged.
        g_gamma: Optional per-head decay, forwarded to the kernel.
        scale: Scaling factor forwarded to the kernel as-is (no default is substituted here).
        cu_seqlens: Optional cumulative sequence lengths for variable-length inputs.
        chunk_size: Upper bound for the chunk length ``BT``.

    Returns:
        The value gradient ``dv`` with the shape of ``do``.
    """
    B, T, H, K, V = *k.shape, do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None

    BK = 128
    BV = 128
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    if g is not None:
        # The kernel consumes the gate head-major: [B, T, H] -> [B, H, T].
        g = g.transpose(1, 2).contiguous()
    dv = torch.empty_like(do)
    grid = (NT, B)
    chunk_bwd_kernel_dv_local[grid](
        q=q,
        k=k,
        g=g,
        g_gamma=g_gamma,
        do=do,
        dv=dv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dv
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    g_gamma: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch the chunked forward kernel that produces the attention output.

    Args:
        q: Query tensor of shape ``[B, T, Hg, K]``.
        k: Key tensor; its last dimension provides the default ``scale``.
        v: Value tensor whose last two dimensions are ``[H, V]``; the output is allocated like it.
        h: Chunk states; made contiguous before the launch.
        g: Optional gate tensor laid out ``[B, T, H]``; transposed to ``[B, H, T]``
            (contiguous) before the launch. ``None`` is passed through unchanged.
        g_gamma: Optional per-head decay, forwarded to the kernel.
        scale: Scaling factor; defaults to ``k.shape[-1] ** -0.5``.
        cu_seqlens: Optional cumulative sequence lengths for variable-length inputs.
        chunk_size: Upper bound for the chunk length ``BT``.

    Returns:
        The output tensor ``o`` with the shape and dtype of ``v``.
    """
    B, T, Hg, K, V = *q.shape, v.shape[-1]
    H = v.shape[-2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)
    if cu_seqlens is None:
        N, chunk_offsets = B, None
    else:
        N = len(cu_seqlens) - 1
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)

    if g is not None:
        # The kernel consumes the gate head-major: [B, T, H] -> [B, H, T].
        g = g.transpose(1, 2).contiguous()
    h = h.contiguous()

    # The kernel is launched on a fixed 1-D grid instead of one program per tile.
    # NOTE(review): 24 presumably matches the core count of the target NPU - confirm.
    CV_kernel_num = 24
    chunk_fwd_kernel_o[(CV_kernel_num,)](
        q,
        k,
        v,
        h,
        g,
        g_gamma,
        o,
        cu_seqlens,
        chunk_offsets,
        scale,
        T=T,
        H=H,
        N=N,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=128,
        BV=128,
    )
    return o
|
||||||
|
|
||||||
|
|
||||||
|
bwd_chunk_dqkwg = chunk_bwd_dqkwg
|
||||||
|
bwd_chunk_dv_local = chunk_bwd_dv_local
|
||||||
359
src/llamafactory/third_party/triton/chunk_scaled_dot_kkt.py
vendored
Normal file
359
src/llamafactory/third_party/triton/chunk_scaled_dot_kkt.py
vendored
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .utils import prepare_chunk_indices
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    g,
    beta,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
    NT,
    B,
    TOTAL_TASKS,
):
    # Computes, per chunk and head, the strictly lower-triangular part of beta * (K @ K^T),
    # optionally scaled by exp(g_i - g_j), and stores it into A.
    # `beta` and `g` are addressed head-major (offset i_h * T, unit stride along time);
    # `k` and `A` are addressed token-major (row stride H * K and H * BT respectively).
    pid = tl.program_id(0)
    n_programs = tl.num_programs(0)

    # TOTAL_TASKS (batch, chunk) pairs are split into contiguous ranges, one per program;
    # the first `tasks_extra` programs take one additional task.
    tasks_floor = TOTAL_TASKS // n_programs
    tasks_extra = TOTAL_TASKS % n_programs

    if pid < tasks_extra:
        n_tasks = tasks_floor + 1
        first_task = pid * n_tasks
    else:
        n_tasks = tasks_floor
        first_task = pid * tasks_floor + tasks_extra

    for task in range(first_task, first_task + n_tasks):
        i_b = task // NT
        chunk_id = task % NT

        if IS_VARLEN:
            # chunk_indices stores (sequence index, chunk index within the sequence) pairs.
            i_n = tl.load(chunk_indices + chunk_id * 2).to(tl.int32)
            i_t = tl.load(chunk_indices + chunk_id * 2 + 1).to(tl.int32)
            bos = tl.load(cu_seqlens + i_n).to(tl.int32)
            eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
            T_local = eos - bos
        else:
            bos, eos = 0, T
            i_t = chunk_id
            T_local = T

        for i_h in range(H):
            k_batch_off = i_b * T * H * K
            beta_batch_off = i_b * H * T
            g_batch_off = i_b * H * T
            A_batch_off = i_b * T * H * BT

            p_beta = tl.make_block_ptr(
                beta + beta_batch_off + bos + i_h * T, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
            )
            b_beta = tl.load(p_beta, boundary_check=(0,))

            b_A = tl.zeros([BT, BT], dtype=tl.float32)
            for i_k in range(tl.cdiv(K, BK)):
                p_k = tl.make_block_ptr(
                    k + k_batch_off + (bos * H + i_h) * K,
                    (T_local, K),
                    (H * K, 1),
                    (i_t * BT, i_k * BK),
                    (BT, BK),
                    (1, 0),
                )
                b_k = tl.load(p_k, boundary_check=(0, 1))
                kkt = tl.dot(b_k, tl.trans(b_k))

                # Row validity mask, built as a float 0/1 tensor.
                o_t = i_t * BT + tl.arange(0, BT)
                o_t = o_t.to(tl.float32)
                row_valid = (o_t < T_local).to(tl.float32)

                # Strictly-lower-triangular mask (row > col), also as float 0/1.
                rows = tl.arange(0, BT)[:, None]
                cols = tl.arange(0, BT)[None, :]
                keep = (rows > cols).to(tl.float32)
                keep = keep * row_valid[:, None]
                b_A += kkt * keep

            if USE_G:
                p_g = tl.make_block_ptr(
                    g + g_batch_off + bos + i_h * T, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
                )
                b_g = tl.load(p_g, boundary_check=(0,))
                b_g_diff = b_g[:, None] - b_g[None, :]
                # Clamp the exponent before exponentiating.
                b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0)
                b_A *= tl.exp(b_g_diff)
            b_A *= b_beta[:, None]

            p_A = tl.make_block_ptr(
                A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
            )
            tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"])
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
    k,
    g,
    beta,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Fills the off-diagonal (i_i > i_j) BC x BC sub-blocks of each BT x BT chunk of
    # beta * (K @ K^T) with a per-channel gate `g`. Every pointer here is token-major:
    # k/g rows have stride H * K, beta has stride H, A rows have stride H * BT.
    i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_i, i_j = i_c // NC, i_c % NC

    # The chunk lookup does not depend on the head, so it is resolved once, before the head
    # loop. Doing it inside the loop would overwrite `i_t` with the per-sequence chunk index
    # and make every head after the first read `chunk_indices` at the wrong position.
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_val = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_val = T

    # Skip sub-blocks that start past the sequence end or that are not strictly below the diagonal.
    should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j)

    for i_h in range(H):
        if should_compute:
            k_ptr = k + (bos * H + i_h) * K
            g_ptr = g + (bos * H + i_h) * K
            A_ptr = A + (bos * H + i_h) * BT

            p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
            b_beta = tl.load(p_beta, boundary_check=(0,))

            b_A = tl.zeros([BC, BC], dtype=tl.float32)
            for i_k in range(tl.cdiv(K, BK)):
                p_k = tl.make_block_ptr(
                    k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
                )
                p_g = tl.make_block_ptr(
                    g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
                )
                p_kt = tl.make_block_ptr(
                    k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
                )
                p_gk = tl.make_block_ptr(
                    g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
                )

                o_k = i_k * BK + tl.arange(0, BK)
                m_k = o_k < K
                # Gate at the first row of the (i_i) row block, used as the common reference
                # point for both exponentials below.
                b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
                b_g = tl.load(p_g, boundary_check=(0, 1))
                b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :])
                b_gk = tl.load(p_gk, boundary_check=(0, 1))
                b_kt = tl.load(p_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk)
                b_A += tl.dot(b_k, b_kt)
            b_A *= b_beta[:, None]

            p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
    k,
    g,
    beta,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Fills the diagonal BC x BC sub-blocks of each BT x BT chunk of beta * (K @ K^T) with a
    # per-channel gate `g`, one column at a time, keeping only entries strictly below the
    # diagonal. Every pointer here is token-major (k/g stride H * K, beta stride H, A stride H * BT).
    i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # The chunk lookup does not depend on the head, so it is resolved once, before the head
    # loop. Doing it inside the loop would overwrite `i_t` with the per-sequence chunk index
    # and make every head after the first read `chunk_indices` at the wrong position.
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_val = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_val = T

    # Skip sub-blocks that start past the end of the sequence.
    should_compute = i_t * BT + i_i * BC < T_val

    for i_h in range(H):
        if should_compute:
            o_i = tl.arange(0, BC)
            o_k = tl.arange(0, BK)
            m_k = o_k < K
            m_A = (i_t * BT + i_i * BC + o_i) < T_val
            # Flat offsets of the first column of this sub-block, one per row.
            o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC

            p_k = tl.make_block_ptr(
                k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
            )
            p_g = tl.make_block_ptr(
                g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
            )
            p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h

            b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None]
            b_g = tl.load(p_g, boundary_check=(0, 1))

            # Pointers to row j of k / g inside the sub-block; advanced by one token per iteration.
            p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
            p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k

            for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)):
                b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
                b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
                b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1)
                # Convert the row indices to f32 before comparing.
                o_i_tmp = o_i.to(tl.float32)
                b_A = tl.where(o_i_tmp > j, b_A, 0.0)

                tl.store(A + o_A + j, b_A, mask=m_A)
                p_kt += H * K
                p_gk += H * K
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    beta: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`. Required.
        g (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.

    Raises:
        ValueError: If `beta` is `None`.
    """
    if beta is None:
        raise ValueError("`beta` must be provided to chunk_scaled_dot_kkt_fwd")

    B, T, H, K = k.shape
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    if gk is None:
        # The scalar-gate kernel addresses `beta` and `g` head-major ([B, H, T]).
        beta_hm = beta.transpose(1, 2).contiguous()
        g_hm = g.transpose(1, 2).contiguous() if g is not None else None
        BK = 128
        # The kernel is launched on a fixed 1-D grid and loops over its share of B * NT tasks.
        # NOTE(review): 24 presumably matches the core count of the target NPU - confirm.
        kernel_num = 24
        A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
        chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)](
            k=k,
            g=g_hm,
            beta=beta_hm,
            A=A,
            cu_seqlens=cu_seqlens,
            chunk_indices=chunk_indices,
            T=T,
            H=H,
            K=K,
            BT=BT,
            BK=BK,
            NT=NT,
            B=B,
            TOTAL_TASKS=B * NT,
        )
        return A

    # The per-channel-gate kernels address `beta` token-major ([B, T, H], stride H), so it
    # must NOT be transposed on this path.
    beta = beta.contiguous()
    BC = min(16, BT)
    NC = triton.cdiv(BT, BC)
    BK = max(triton.next_power_of_2(K), 16)
    # Zero-initialised because the two kernels below only write the lower-triangular sub-blocks.
    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)

    # Off-diagonal sub-blocks (BK is chosen by the kernel's autotuner).
    grid = (NT, NC * NC, B)
    chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
        k=k,
        g=gk,
        beta=beta,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    # Diagonal sub-blocks.
    grid = (NT, NC, B)
    chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
        k=k,
        g=gk,
        beta=beta,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
    )
    return A
|
||||||
147
src/llamafactory/third_party/triton/cumsum.py
vendored
Normal file
147
src/llamafactory/third_party/triton/cumsum.py
vendored
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .utils import prepare_chunk_indices
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
    {
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BLOCK_T: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    CHUNK_SIZE: tl.constexpr = 64,
):
    # Chunk-local cumulative sum over time for all H heads at once. Each program handles a
    # [BLOCK_T, H] tile holding BLOCK_T // CHUNK_SIZE chunks. `s` and `o` are addressed
    # token-major (row stride H); HEAD_FIRST is accepted but not read.
    i_block = tl.program_id(0)
    i_b = tl.program_id(1)
    N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE

    if IS_VARLEN:
        # chunk_indices stores (sequence index, block index within the sequence) pairs;
        # both loads use the original program id before `i_block` is rebound.
        seq_slot = chunk_indices + i_block * 2
        i_s = tl.load(seq_slot).to(tl.int32)
        i_block = tl.load(seq_slot + 1).to(tl.int32)

        bos = tl.load(cu_seqlens + i_s).to(tl.int32)
        eos = tl.load(cu_seqlens + i_s + 1).to(tl.int32)
        T = eos - bos
    else:
        bos = i_b * T
        eos = i_b * T + T

    ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
    ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))

    # [BLOCK_T, H] -> [N_CHUNKS, CHUNK_SIZE, H] -> [CHUNK_SIZE, N_CHUNKS, H] so that the scan
    # along axis 0 runs inside each chunk independently.
    b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
    b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
    b_s = tl.trans(b_s, (1, 0, 2))
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Suffix sum derived from the prefix sum: total - prefix + current element.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale

    # Undo the permutation and flatten back to [BLOCK_T, H] for the store.
    b_o = tl.trans(b_o, (1, 0, 2))
    b_o = tl.reshape(b_o, (BLOCK_T, H))
    tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,))
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
    """Compute the chunk-local cumulative sum of a scalar gate of shape ``[B, T, H]``.

    Args:
        g: Input tensor of shape ``[B, T, H]``.
        chunk_size: Length of the chunks the cumulative sum restarts at; must be a power of 2.
        reverse: Compute a suffix sum inside each chunk instead of a prefix sum.
        scale: Optional factor applied to the result.
        cu_seqlens: Optional cumulative sequence lengths for variable-length inputs.
        head_first: Forwarded to the kernel as ``HEAD_FIRST``.
            NOTE(review): the kernel in this file never reads that flag and always addresses
            the input as ``[B, T, H]`` - confirm no caller relies on ``head_first=True``.
        output_dtype: Dtype of the result; falls back to ``g.dtype`` when ``None``.

    Returns:
        A new tensor with the shape of ``g`` holding the chunk-local cumulative sums.

    Raises:
        ValueError: If ``chunk_size`` is not a positive power of 2.
    """
    B, T, H = g.shape
    if chunk_size <= 0 or chunk_size & (chunk_size - 1):
        raise ValueError(f"chunk_size must be a power of 2, chunk_size is {chunk_size}")
    # We adjust the tiling strategy to prevent overflow in backward passes and context parallel
    # scenarios while maximizing UB utilization where possible.
    # The tiling strategy is as follows:
    # 1. BT must be greater than or equal to chunk_size.
    # 2. UB estimation varies directly with H.
    # 3. BT in reverse mode is smaller than in forward mode.
    tile_budget = 1 << 11 if reverse else 1 << 12
    BT = max(chunk_size, triton.next_power_of_2(tile_budget // H))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        BLOCK_T=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        CHUNK_SIZE=chunk_size,
    )
    return g
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
    **kwargs,
) -> torch.Tensor:
    """Validate the input and dispatch to the chunk-local cumulative-sum implementation.

    Only 3-D (scalar gate) inputs are supported here; they are forwarded to
    ``chunk_local_cumsum_scalar``.

    Args:
        g: Input tensor; must be 3-D.
        chunk_size: Length of the chunks the cumulative sum restarts at.
        reverse: Compute a suffix sum inside each chunk instead of a prefix sum.
        scale: Optional factor applied to the result.
        cu_seqlens: Optional cumulative sequence lengths; requires a batch size of 1.
        head_first: Forwarded unchanged to ``chunk_local_cumsum_scalar``.
        output_dtype: Dtype of the result.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The chunk-local cumulative sum of ``g``.

    Raises:
        ValueError: If ``cu_seqlens`` is given with a batch size other than 1, or if ``g``
            is not 3-D.
    """
    if cu_seqlens is not None and g.shape[0] != 1:
        raise ValueError(
            f"Only batch size 1 is supported when cu_seqlens are provided, current size is {g.shape[0]}"
        )
    if len(g.shape) != 3:
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H) if `head_first=False` "
            f"or (B, H, T) otherwise"
        )
    return chunk_local_cumsum_scalar(
        g=g,
        chunk_size=chunk_size,
        reverse=reverse,
        scale=scale,
        cu_seqlens=cu_seqlens,
        head_first=head_first,
        output_dtype=output_dtype,
    )
|
||||||
272
src/llamafactory/third_party/triton/solve_tril.py
vendored
Normal file
272
src/llamafactory/third_party/triton/solve_tril.py
vendored
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices
|
||||||
|
|
||||||
|
|
||||||
|
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Computes (I + A)^-1 for the strictly lower-triangular 16 x 16 diagonal tiles of A by
    # forward substitution. Each program processes TPP consecutive tiles.
    # `A` rows have stride H * BT; `Ai` rows have stride H * 16.
    pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = pid_bh // H, pid_bh % H

    base_t = pid_t * TPP

    if IS_VARLEN:
        # NOTE(review): only the sequence index of the first tile is looked up, and tile rows
        # below are derived from the global tile index rather than the per-sequence chunk
        # index stored in chunk_indices - confirm against the host launcher.
        i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_eff = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_eff = T

    # Index comparisons are done on f32 copies of the row/column indices.
    o_i_fp32 = tl.arange(0, 16).to(tl.float32)
    m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
    m_I = o_i_fp32[:, None] == o_i_fp32[None, :]

    A = A + (bos * H + i_h) * BT
    # Ai is addressed below with a row stride of H * 16, so the per-(token, head) base offset
    # must be scaled by 16 as well (scaling by BT is only equivalent when BT == 16).
    Ai = Ai + (bos * H + i_h) * 16

    for tpp in tl.static_range(0, TPP):
        tile_t = base_t + tpp
        tile_row = tile_t * 16

        # Column of the diagonal tile inside its BT-wide chunk.
        offset = (tile_t * 16) % BT

        if not USE_TMA:
            p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0))
            b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
        else:
            desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
            desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
            b_A_raw = desc.load([tile_row, offset]).to(tl.float32)

        b_A_neg = -b_A_raw
        # Keep the strictly lower triangle only.
        b_A = b_A_neg * m_A
        # Forward substitution, one row at a time; rows 0 and 1 need no update.
        for i in range(2, min(16, T_eff - tile_row)):
            slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
            b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
            dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
            b_a_update = b_a_val + dot_prod
            b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
        # Add the identity to complete the inverse.
        b_A += m_I

        if not USE_TMA:
            p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0))
            tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        else:
            desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Inverts (I + A) for one 32 x 32 strictly lower-triangular chunk: the two 16 x 16 diagonal
    # blocks are inverted by forward substitution and the lower-left block is obtained as
    # -Ai_22 @ A_21 @ Ai_11. `A` and `Ai` are both addressed with row stride H * BT.
    i_t = tl.program_id(0)
    i_bh = tl.program_id(1)
    i_b = i_bh // H
    i_h = i_bh % H
    if IS_VARLEN:
        # chunk_indices stores (sequence index, chunk index within the sequence) pairs;
        # both loads use the original program id before `i_t` is rebound.
        pair = chunk_indices + i_t * 2
        i_n = tl.load(pair).to(tl.int32)
        i_t = tl.load(pair + 1).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos = i_b * T
        eos = i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    # Load the two diagonal blocks.
    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)

    # Negate and keep the strictly lower triangle only.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)

    # Forward substitution on the upper-left block (rows 2..15 of the chunk).
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    # Forward substitution on the lower-right block (rows 18..31 of the chunk).
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)

    # Add the identity to complete both diagonal inverses.
    b_Ai_11 += m_I
    b_Ai_22 += m_I

    # Lower-left block of the input.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)

    b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)

    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
|
||||||
|
|
||||||
|
|
||||||
|
# Triton kernel: for one 64-row chunk of one (batch, head) pair, build a 64x64 block
# row by row from the negated input block, add the identity, and store it to `Ai`.
# The host wrapper in this file documents the result as (I + A)^-1.
#
# Grid: program_id(0) indexes the chunk, program_id(1) indexes the flattened (batch, head) pair.
# NOTE: DOT_PRECISION is accepted for signature parity but is not referenced in this body.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_64x64_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices stores (sequence index, chunk index within sequence) pairs;
        # T is rebound to the length of the current sequence.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 64)
    # Identity mask, added to the result after the row loop.
    m_I = o_i[:, None] == o_i[None, :]

    # Advance the base pointers to the first row of this sequence for head i_h.
    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * 64

    # Column offset inside the BT-wide rows of A (evaluates to 0 when BT == 64).
    offset = (i_t * 64) % BT
    if not USE_TMA:
        p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0))
        b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64])
        desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64])
        b_A = -desc.load([i_t * 64, offset]).to(tl.float32)

    # Row-by-row update: row i is reloaded from A, negated, combined with the rows
    # already held in b_A, and written back into row i of b_A. The upper bound keeps
    # the loop inside the valid rows of a partial final chunk.
    for i in range(2, min(64, T - i_t * 64)):
        b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        b_A = tl.where((o_i == i)[:, None], b_a, b_A)
    b_A += m_I
    if not USE_TMA:
        p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0))
        tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
|
||||||
|
|
||||||
|
|
||||||
|
@input_guard
def solve_tril(
    A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
    """Compute the inverse of the matrix I + A.

    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A

    Raises:
        ValueError: If the last dimension of `A` is not 16, 32, or 64.
    """
    if A.shape[-1] not in [16, 32, 64]:
        raise ValueError(f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}")
    if output_dtype is None:
        output_dtype = A.dtype

    B, T, H, BT = A.shape
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        # Variable-length input: one grid row per (sequence, chunk) pair.
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    # Zero-initialised output buffer that the selected kernel writes into.
    Ai = torch.zeros_like(A, dtype=output_dtype)

    # Pick the kernel matching the chunk width (BT was validated above).
    if BT == 64:
        merge_fn = solve_tril_64x64_kernel
    elif BT == 32:
        merge_fn = merge_16x16_to_32x32_inverse_kernel
    elif BT == 16:
        merge_fn = solve_tril_16x16_kernel

    grid = (NT, B * H)
    merge_fn[grid](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        USE_TMA=False,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai
|
||||||
359
src/llamafactory/third_party/triton/utils.py
vendored
Normal file
359
src/llamafactory/third_party/triton/utils.py
vendored
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
import triton.language.extra.libdevice as tldevice
|
||||||
|
import triton.runtime.driver as driver
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any:
    """A decorator that caches the most recent results of a function with tensor inputs.

    This decorator will store the outputs of the decorated function for the most recent
    set of input tensors, up to `maxsize` entries. If the function is called again with
    the same input tensors, it will return the cached result.

    Tensors are matched by identity (`is`), so a tensor modified in place still hits the
    cache. Other arguments are matched with `==`; if that comparison raises or does not
    yield an unambiguous truth value, matching falls back to identity.

    When maxsize=1 (default), the behavior is identical to caching only the most recent result.
    Can be used as @tensor_cache or @tensor_cache(maxsize=n).

    Args:
        fn (Callable[..., torch.Tensor], optional):
            The function to be decorated when used without parentheses.
        maxsize (int):
            Maximum number of input combinations to cache. Default is 1.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with caching.

    Raises:
        ValueError: If `maxsize` is smaller than 1.
    """
    if maxsize < 1:
        raise ValueError("maxsize must be at least 1")

    def _is_match(a: Any, b: Any) -> bool:
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            return a is b
        try:
            # Coerce inside the try: `==` on array-likes (or tensor vs. non-tensor) returns
            # an object whose truth value may be ambiguous and would otherwise raise later.
            return bool(a == b)
        except Exception:
            return a is b

    def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        # Most-recently-used entry first; each entry is (args, kwargs, result).
        cache: list = []

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache):
                if len(args) != len(cached_args) or len(kwargs) != len(cached_kwargs):
                    continue
                if not all(_is_match(a, b) for a, b in zip(args, cached_args)):
                    continue
                if not all(k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items()):
                    continue
                if i != 0:
                    # Move the hit to the front so eviction drops the least recently used entry.
                    cache.insert(0, cache.pop(i))
                return cached_result

            result = fn(*args, **kwargs)
            cache.insert(0, (args, kwargs, result))
            if len(cache) > maxsize:
                cache.pop()
            return result

        return wrapper

    if fn is not None:
        return _make_wrapper(fn)
    return _make_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return per-sequence lengths as the differences of consecutive `cu_seqlens` entries."""
    starts = cu_seqlens[:-1]
    ends = cu_seqlens[1:]
    return ends - starts
|
||||||
|
|
||||||
|
|
||||||
|
@tensor_cache(maxsize=3)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Build a two-column table with one row per chunk of every sequence.

    Column 1 is the chunk index within its sequence; column 0 is the number of
    sequence starts (chunk index 0) seen so far, minus one. The result is cast to
    the dtype and device of `cu_seqlens`.

    NOTE: a zero-length sequence contributes no rows, so it is not counted in column 0.
    """
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    indices = torch.cat([torch.arange(n) for n in chunks_per_seq])
    seq_ids = indices.eq(0).cumsum(0) - 1
    return torch.stack([seq_ids, indices], 1).to(cu_seqlens)
|
||||||
|
|
||||||
|
|
||||||
|
def get_abs_err(x, y):
    """Return the largest absolute element-wise difference between `x` and `y` as a Python number."""
    diff = x.detach() - y.detach()
    return diff.flatten().abs().max().item()
|
||||||
|
|
||||||
|
|
||||||
|
def get_err_ratio(x, y):
    """Return RMS(x - y) / (RMS(x) + 1e-8), i.e. the RMS error relative to `x`."""

    def _rms(t):
        return t.flatten().square().mean().sqrt().item()

    reference = x.detach()
    err = _rms(reference - y.detach())
    base = _rms(reference)
    # The small epsilon guards against division by zero when `x` is all zeros.
    return err / (base + 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
    """Check that `tri` is close to `ref`, logging the observed errors.

    Args:
        prefix: Label placed at the start of the log/error message.
        ref: Reference tensor.
        tri: Tensor under test.
        ratio: Upper bound for the relative RMS error (see `get_err_ratio`).
        warning: If True, only warn when the bound is exceeded instead of failing.
        err_atol: If the maximum absolute error is at most this value, the check passes
            regardless of the relative error.

    Raises:
        AssertionError: If the relative error is not below `ratio` and neither `warning`
            nor the relaxed CI condition applies.
    """
    abs_atol = get_abs_err(ref, tri)
    # Compute the ratio once; it is needed for both the message and the decision.
    error_rate = get_err_ratio(ref, tri)
    msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {error_rate:.6f}"
    logger.info(msg)
    if abs_atol <= err_atol:
        return
    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
        if error_rate > ratio:
            warnings.warn(msg)
    elif not error_rate < ratio:
        # Raised explicitly (rather than via `assert`) so the check survives `python -O`.
        raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
# Resolve `make_tensor_descriptor` from whichever name the installed Triton exposes.
if hasattr(tl, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = tl._experimental_make_tensor_descriptor
elif hasattr(tl, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = tl.make_tensor_descriptor
else:
    # Fallback implementation when TMA is not supported.
    # Returns None to indicate TMA descriptors are unavailable.
    # Just make triton compiler happy.

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        return None
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def get_available_device() -> str:
    """Return the backend name of Triton's active target, or "cpu" if it cannot be queried.

    The result is cached for the lifetime of the process. On failure a warning is
    emitted via `_cpu_device_warning`.
    """
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except Exception:
        # Catch `Exception` rather than `BaseException` so KeyboardInterrupt and
        # SystemExit are not silently converted into a CPU fallback.
        _cpu_device_warning()
        return "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def map_triton_backend_to_torch_device() -> str:
    """Translate the Triton backend name into the matching torch device type.

    "hip" maps to "cuda"; every other backend name is returned unchanged.
    """
    backend = get_available_device()  # 'cuda' | 'hip' | 'xpu' | 'cpu' | ...
    if backend == "hip":
        return "cuda"
    return backend
|
||||||
|
|
||||||
|
|
||||||
|
# Platform detection, evaluated once at import time.
device_platform = get_available_device()
# torch exposes AMD GPUs under the "cuda" device type, so "hip" is folded into "cuda".
device = "cuda" if device_platform == "hip" else device_platform
device_torch_lib = getattr(torch, device)
is_amd = device_platform == "hip"
is_nvidia = device_platform == "cuda"
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
)

is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
# TMA needs an NVIDIA device of compute capability >= 9, no FLA_NO_USE_TMA=1 opt-out,
# and a Triton build that exposes a tensor-descriptor constructor.
is_tma_supported = (
    (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9)
    and os.environ.get("FLA_NO_USE_TMA", "0") != "1"
    and (
        hasattr(triton.language, "_experimental_make_tensor_descriptor")
        or hasattr(triton.language, "make_tensor_descriptor")
    )
)

if is_nvidia and not is_tf32_supported:
    # Make old card happy, since triton will use tf32 by default.
    # This is a workaround for old nvidia card.
    os.environ["TRITON_F32_DEFAULT"] = "ieee"
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def check_pytorch_version(version_s: str = "2.4") -> bool:
    """Return True if the installed torch version is at least `version_s` (cached per argument)."""
    installed = version.parse(torch.__version__)
    required = version.parse(version_s)
    return installed >= required
|
||||||
|
|
||||||
|
|
||||||
|
# Select the autocast decorators and the device-context factory for the installed torch.
if check_pytorch_version("2.4"):
    # When no accelerator backend was detected, "cuda" is used as the autocast device type.
    if device == "cpu":
        device = "cuda"
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        """Return a device context manager for `index` from the detected torch device module."""
        return device_torch_lib.device(index)

else:
    assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0."
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        """Return a CUDA device context manager for `index`."""
        return torch.cuda.device(index)
|
||||||
|
|
||||||
|
|
||||||
|
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """A decorator to make sure all input tensors are contiguous and set the device based on input tensors.

    The device context is taken from the first tensor found among the positional
    arguments, then among the keyword arguments. With no tensor argument the wrapped
    function runs inside a no-op context.
    """

    def _as_contiguous(obj):
        return obj.contiguous() if isinstance(obj, torch.Tensor) else obj

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Positional arguments stay lazy (generator), so they are converted when `fn` is called.
        contiguous_args = (_as_contiguous(item) for item in args)
        contiguous_kwargs = {name: _as_contiguous(value) for name, value in kwargs.items()}

        first_tensor = next(
            (obj for obj in itertools.chain(args, kwargs.values()) if isinstance(obj, torch.Tensor)),
            None,
        )
        if first_tensor is None:
            ctx = contextlib.nullcontext()
        else:
            ctx = custom_device_ctx(first_tensor.device.index)

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _cpu_device_warning():
    """Warn that Triton is unavailable on this platform and CPU is used instead."""
    message = "Triton is not supported on current platform, roll back to CPU."
    warnings.warn(message, stacklevel=1)
|
||||||
|
|
||||||
|
|
||||||
|
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return the cumulative number of chunks before each sequence, with a leading 0.

    The number of chunks per sequence is `ceil(length / chunk_size)`.
    """
    leading_zero = cu_seqlens.new_tensor([0])
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    return torch.cat([leading_zero, chunks_per_seq]).cumsum(-1)
|
||||||
|
|
||||||
|
|
||||||
|
# Math primitives used inside kernels; FLA_USE_FAST_OPS=1 switches to the libdevice variants.
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp, exp2 = tldevice.fast_expf, tldevice.exp2
    log, log2 = tldevice.fast_logf, tldevice.fast_log2f
else:
    exp, exp2 = tl.exp, tl.math.exp2
    log, log2 = tl.log, tl.log2
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_max_shared_mem():
    """Return the `max_shared_mem` device property for every visible device.

    Returns:
        A list with one entry per device reported by the detected torch device module,
        or `[-1]` (after emitting a warning) if the properties cannot be queried.
    """
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
            for i in range(device_torch_lib.device_count())
        ]
    except Exception:
        # Catch `Exception` rather than `BaseException` so KeyboardInterrupt and
        # SystemExit still propagate.
        _cpu_device_warning()
        return [-1]
|
||||||
|
|
||||||
|
|
||||||
|
class Backend(Enum):
    """Shared-memory thresholds (member values) keyed by architecture name."""

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        """Return the value for `arch` (case-insensitive), or the DEFAULT value if unknown."""
        member = cls.__members__.get(arch.upper(), cls.DEFAULT)
        return member.value
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Report whether device `tensor_idx` has at least the shared memory required for `arch`.

    Any exception raised during the lookup (for example an out-of-range `tensor_idx`)
    results in `False`. Results are cached per argument combination.
    """
    try:
        available = get_all_max_shared_mem()[tensor_idx]
        required = Backend.get_shared_memory(arch)
        return available >= required
    except Exception:
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_autotune_config(
    multibuffer_list: tuple = (False,),
    unit_flag_list: tuple = (False,),
    limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,),
    limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",),
    set_workspace_multibuffer_list: tuple = (2, 4),
    enable_hivm_auto_cv_balance_list: tuple = (True,),
    tile_mix_vector_loop_num_list: tuple = (2, 4),
    tile_mix_cube_loop_num_list: tuple = (2, 4),
):
    """Enumerate `triton.Config` candidates from the cartesian product of the option lists.

    The first four lists form the base options. When
    `limit_auto_multi_buffer_only_for_local_buffer` is truthy for a combination, only the
    base options are emitted; otherwise the combination is further expanded with the
    product of the last four lists.

    Returns:
        list of `triton.Config` objects, in product order.
    """
    base_keys = (
        "multibuffer",
        "unit_flag",
        "limit_auto_multi_buffer_only_for_local_buffer",
        "limit_auto_multi_buffer_of_local_buffer",
    )
    extra_keys = (
        "set_workspace_multibuffer",
        "enable_hivm_auto_cv_balance",
        "tile_mix_vector_loop",
        "tile_mix_cube_loop",
    )

    configs = []
    for base_values in itertools.product(
        multibuffer_list,
        unit_flag_list,
        limit_auto_multi_buffer_only_for_local_buffer_list,
        limit_auto_multi_buffer_of_local_buffer_list,
    ):
        base_config_dict = dict(zip(base_keys, base_values))

        if base_config_dict["limit_auto_multi_buffer_only_for_local_buffer"]:
            configs.append(triton.Config(base_config_dict))
            continue

        for extra_values in itertools.product(
            set_workspace_multibuffer_list,
            enable_hivm_auto_cv_balance_list,
            tile_mix_vector_loop_num_list,
            tile_mix_cube_loop_num_list,
        ):
            full_config_dict = {**base_config_dict, **dict(zip(extra_keys, extra_values))}
            configs.append(triton.Config(full_config_dict))
    return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_npu_properties():
    """Return the Triton driver's device properties for the current torch NPU device."""
    current = torch.npu.current_device()
    return driver.active.utils.get_device_properties(current)
|
||||||
387
src/llamafactory/third_party/triton/wy_fast.py
vendored
Normal file
387
src/llamafactory/third_party/triton/wy_fast.py
vendored
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||||
|
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .utils import exp, prepare_chunk_indices
|
||||||
|
|
||||||
|
|
||||||
|
# Triton kernel: backward pass of the WY representation. For every (chunk, batch, head)
# it reads k, v, beta, g, A and the upstream gradients dw, du, and writes dk, dv, dbeta, dg.
#
# Work distribution: the grid is 1-D; the NT chunks are split into contiguous ranges over
# tl.num_programs(0) programs, the first `NT % total_cores` programs taking one extra chunk.
# Each program then loops over all batches and heads itself.
#
# Layout notes (from the pointer arithmetic below): k/dk/dw and v/dv/du are addressed with
# row stride H*K / H*V; beta, g, dbeta and dg are addressed with unit stride along time and
# a per-head stride of T_max (presumably a [B, H, T] layout -- confirm against the caller).
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
    k,
    v,
    beta,
    g,
    A,
    dw,
    du,
    dk,
    dv,
    dbeta,
    dg,
    cu_seqlens,
    chunk_indices,
    T,
    B,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    NT: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    core_id = tl.program_id(0)
    total_cores = tl.num_programs(0)
    # Keep the launch-time T: in the varlen branch T is rebound to the current sequence length.
    T_max = T

    base_chunks_per_pid = NT // total_cores
    remainder_chunks = NT % total_cores

    # Contiguous chunk range [start_idx, start_idx + chunks_this_pid) owned by this program.
    if core_id < remainder_chunks:
        chunks_this_pid = base_chunks_per_pid + 1
        start_idx = core_id * chunks_this_pid
    else:
        chunks_this_pid = base_chunks_per_pid
        start_idx = core_id * chunks_this_pid + remainder_chunks

    for idx in range(start_idx, start_idx + chunks_this_pid):
        for i_b in range(B):
            if IS_VARLEN:
                # chunk_indices stores (sequence index, chunk index within sequence) pairs.
                i_n, i_t = (
                    tl.load(chunk_indices + idx * 2).to(tl.int32),
                    tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
                )
                bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
                T = eos - bos
            else:
                i_t = idx
                bos, eos = i_b * T, i_b * T + T

            o_t = i_t * BT + tl.arange(0, BT)
            m_t = o_t < T
            # Strictly-lower-triangular mask restricted to valid (in-sequence) positions.
            m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
            for i_h in range(0, H):
                if IS_VARLEN:
                    offset = bos + i_h * T_max
                else:
                    offset = bos * H + i_h * T_max

                p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                # A is loaded through a transposed view: shape (BT, T) with strides (1, H * BT).
                p_A = tl.make_block_ptr(
                    A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)
                )

                b_A = tl.load(p_A, boundary_check=(0, 1))
                b_beta = tl.load(p_beta, boundary_check=(0,))
                b_g = tl.load(p_g, boundary_check=(0,))
                b_g_exp = tl.exp(b_g)

                # Float32 accumulators for this (chunk, head).
                b_dbeta = tl.zeros([BT], dtype=tl.float32)
                b_dA = tl.zeros([BT, BT], dtype=tl.float32)
                b_dg = tl.zeros([BT], dtype=tl.float32)

                # Pass 1 over K blocks: contribution of dw to dA, dk, dbeta and dg.
                for i_k in range(tl.cdiv(K, BK)):
                    p_k = tl.make_block_ptr(
                        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_dk = tl.make_block_ptr(
                        dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_dw = tl.make_block_ptr(
                        dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    b_k = tl.load(p_k, boundary_check=(0, 1))
                    b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype)
                    b_dw = tl.load(p_dw, boundary_check=(0, 1))
                    b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
                    b_dk_beta_g = tl.dot(b_A, b_dw)
                    b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None]
                    b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1)
                    b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1)
                    # Partial dk; it is read back and completed in the second K pass below.
                    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

                # Pass over V blocks: contribution of du to dA, dv and dbeta.
                for i_v in range(tl.cdiv(V, BV)):
                    p_v = tl.make_block_ptr(
                        v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    p_dv = tl.make_block_ptr(
                        dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    p_du = tl.make_block_ptr(
                        du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    b_v = tl.load(p_v, boundary_check=(0, 1))
                    b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
                    b_du = tl.load(p_du, boundary_check=(0, 1))
                    b_dA += tl.dot(b_du, tl.trans(b_v_beta))
                    b_dv_beta = tl.dot(b_A, b_du)
                    b_dv = b_dv_beta * b_beta[:, None]
                    b_dbeta += tl.sum(b_dv_beta * b_v, 1)
                    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

                # Mask dA, multiply by b_A on both sides, then negate and scale by
                # exp(g_i - g_j) on the strictly-lower-triangular valid entries.
                b_dA = tl.where(m_A, b_dA, 0)
                b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
                b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
                b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
                b_dA = b_dA.to(k.dtype.element_ty)
                # b_A is reused from here on as an accumulator for (k * beta) @ k^T.
                b_A = tl.zeros([BT, BT], dtype=tl.float32)

                # Pass 2 over K blocks: add the dA contribution to dk and dbeta.
                for i_k in range(tl.cdiv(K, BK)):
                    p_k = tl.make_block_ptr(
                        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_dk = tl.make_block_ptr(
                        dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    b_k = tl.load(p_k, boundary_check=(0, 1))
                    b_dk = tl.load(p_dk, boundary_check=(0, 1))
                    b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
                    b_A += tl.dot(b_k_beta, tl.trans(b_k))
                    b_dk_beta = tl.dot(b_dA, b_k)
                    b_dbeta += tl.sum(b_dk_beta * b_k, 1)
                    b_dk += tl.dot(tl.trans(b_dA), b_k_beta)
                    b_dk += b_dk_beta * b_beta[:, None]
                    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

                # Gate gradient: row sums minus column sums of dA * A.
                b_dA_A = b_dA * b_A
                b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0)
                p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
                tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
|
||||||
|
|
||||||
|
|
||||||
|
# Triton kernel: recompute w and u for every (chunk, batch, head):
#   u = A @ (v * beta)
#   w = A @ (k * beta [* exp(g)] [* exp(gk)])
# where the bracketed factors are applied only when g / gk are provided.
#
# Work distribution: the grid is 1-D; the NT chunks are split into contiguous ranges over
# tl.num_programs(0) programs, the first `NT % total_cores` programs taking one extra chunk.
# beta and g are addressed with unit stride along time and a per-head stride of T_max
# (presumably a [B, H, T] layout -- confirm against the caller).
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    gk,
    cu_seqlens,
    chunk_indices,
    T_tmp,
    B,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    NT: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    core_id = tl.program_id(0)
    total_cores = tl.num_programs(0)
    # Launch-time sequence length; used as the per-head stride for beta and g.
    T_max = T_tmp

    base_chunks_per_pid = NT // total_cores
    remainder_chunks = NT % total_cores

    # Contiguous chunk range [start_idx, start_idx + chunks_this_pid) owned by this program.
    if core_id < remainder_chunks:
        chunks_this_pid = base_chunks_per_pid + 1
        start_idx = core_id * chunks_this_pid
    else:
        chunks_this_pid = base_chunks_per_pid
        start_idx = core_id * chunks_this_pid + remainder_chunks

    for idx in range(start_idx, start_idx + chunks_this_pid):
        for i_b in range(B):
            for i_h in range(0, H):
                if IS_VARLEN:
                    # chunk_indices stores (sequence index, chunk index within sequence) pairs.
                    i_n, i_t = (
                        tl.load(chunk_indices + idx * 2).to(tl.int32),
                        tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
                    )
                    bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
                    offset = bos + i_h * T_max
                    T = eos - bos
                else:
                    T = T_tmp
                    i_t = idx
                    bos, eos = i_b * T, i_b * T + T
                    offset = bos * H + i_h * T_max

                p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                b_beta = tl.load(p_beta, boundary_check=(0,))

                p_A = tl.make_block_ptr(
                    A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
                )
                b_A = tl.load(p_A, boundary_check=(0, 1))

                # u = A @ (v * beta), one BV-wide column block at a time.
                for i_v in range(tl.cdiv(V, BV)):
                    p_v = tl.make_block_ptr(
                        v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    p_u = tl.make_block_ptr(
                        u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    b_v = tl.load(p_v, boundary_check=(0, 1))
                    b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
                    b_u = tl.dot(b_A, b_vb, allow_tf32=False)
                    tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

                if USE_G:
                    # Per-position gate, exponentiated once and reused for every K block.
                    p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

                # w = A @ (k * beta * optional gates), one BK-wide column block at a time.
                for i_k in range(tl.cdiv(K, BK)):
                    p_k = tl.make_block_ptr(
                        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_w = tl.make_block_ptr(
                        w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    b_k = tl.load(p_k, boundary_check=(0, 1))
                    b_kb = b_k * b_beta[:, None]
                    if USE_G:
                        b_kb *= b_g[:, None]
                    if USE_GK:
                        # Element-wise gate with the same addressing as k.
                        p_gk = tl.make_block_ptr(
                            gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                        )
                        b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1)))
                    b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
                    tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch `recompute_w_u_fwd_kernel` and return the recomputed `(w, u)`.

    Args:
        k: 4-D tensor unpacked as (B, T, H, K).
        v: Tensor whose last dimension is V.
        beta: Transposed on dims 1 and 2 and made contiguous before the launch.
        A: Tensor whose last dimension is the chunk size BT.
        g: Optional gate; transposed on dims 1 and 2 when given.
        gk: Optional gate passed to the kernel unchanged.
        cu_seqlens: Optional cumulative sequence lengths for variable-length input.

    Returns:
        `w` (same shape/dtype as `k`) and `u` (same shape/dtype as `v`).
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    BT = A.shape[-1]
    BK = BV = 128

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    if g is not None:
        g = g.transpose(1, 2).contiguous()
    beta = beta.transpose(1, 2).contiguous()

    w = torch.empty_like(k)
    u = torch.empty_like(v)

    # Fixed 1-D grid; each program iterates over its share of the NT chunks.
    cv_kernel_num = 24
    grid = (cv_kernel_num,)
    recompute_w_u_fwd_kernel[grid](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T_tmp=T,
        B=B,
        H=H,
        K=K,
        V=V,
        NT=NT,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_wy_repr_bwd(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch `prepare_wy_repr_bwd_kernel` and return `(dk, dv, dbeta, dg)`.

    `beta` and `g` are transposed on dims 1 and 2 (and made contiguous) before the
    launch; `dbeta` and `dg` are allocated in that transposed layout and transposed
    back before being returned.

    Args:
        k: 4-D tensor unpacked as (B, T, H, K).
        v: Tensor whose last dimension is V.
        g: Gate tensor.
        beta: Beta tensor.
        A: Passed to the kernel unchanged.
        dw: Upstream gradient passed to the kernel as `dw`.
        du: Upstream gradient passed to the kernel as `du`.
        cu_seqlens: Optional cumulative sequence lengths for variable-length input.
        chunk_size: Chunk size BT used by the kernel. Default: 64.
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    BT = chunk_size

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    BK = BV = 128
    beta = beta.transpose(1, 2).contiguous()
    g = g.transpose(1, 2).contiguous()

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    dg = torch.empty_like(g)

    # Fixed 1-D grid; each program iterates over its share of the NT chunks.
    cv_kernel_num = 24
    grid = (cv_kernel_num,)
    prepare_wy_repr_bwd_kernel[grid](
        k=k,
        v=v,
        beta=beta,
        g=g,
        A=A,
        dw=dw,
        du=du,
        dk=dk,
        dv=dv,
        dbeta=dbeta,
        dg=dg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        NT=NT,
        BT=BT,
        BK=BK,
        BV=BV,
    )

    # Restore the caller's layout for the per-position gradients.
    dbeta = dbeta.transpose(1, 2).contiguous()
    dg = dg.transpose(1, 2).contiguous()
    return dk, dv, dbeta, dg
|
||||||
|
|
||||||
|
|
||||||
|
# Alternative names bound to the same functions.
bwd_prepare_wy_repr = prepare_wy_repr_bwd

fwd_recompute_w_u = recompute_w_u_fwd
|
||||||
Reference in New Issue
Block a user