11 Commits
v0.9.5 ... main

Author SHA1 Message Date
jiaqiw09
8669a22e9c [fix] fix liger kernel patch for npu (#10583) 2026-06-16 18:21:52 +08:00
Hao Liang
897a44386c [docs] add DataFlow and DataFlex blog tutorials (#10582)
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-16 14:20:36 +08:00
jiaqiw09
7a1e9630f2 [fix] update ascend doc link (#10572) 2026-06-15 13:55:53 +08:00
souljoy
cabe59a343 [model] add MiniCPM5-1B-Chat (#10558) 2026-06-10 16:18:27 +08:00
Co-Cl2
9ca4026efe [model] handle unsloth model loading fallback during checkpoint resume (#7156) (#10551) 2026-06-09 01:01:01 +08:00
Ximing Xing
0b7aaf8f6a [fix] correctly place new token embeddings when embedding is padded (#10547) 2026-06-05 10:47:51 +08:00
codingma
8a4f6a3da5 [model] add gemma-4-12B-it (#10549) 2026-06-04 23:43:20 +08:00
A1waysBeenHere
409e8a477f [model] Patch GDN for NPU (#10504)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
2026-06-04 16:39:02 +08:00
Cui-yshoho
053d43c0ac [feat] support HyperParallel PT training and activation optimization (#10370) 2026-06-02 22:39:32 +08:00
Zhao73
a98a1ef101 [docs] fix README citation typo (#10540) 2026-06-01 21:04:53 +08:00
Yaowei Zheng
8ef7335b6a [misc] set dev version (#10533) 2026-05-31 00:16:07 +08:00
28 changed files with 3940 additions and 105 deletions

View File

@@ -54,6 +54,7 @@ Start cloud training:
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Documentation (ASCEND NPU)**: https://llamafactory.readthedocs.io/en/latest/multibackend/npu/index.html
- **Official Blog**: https://blog.llamafactory.net/en/
> [!NOTE]
@@ -111,6 +112,8 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- 💡 [DataFlow × LLaMA Factory: Producing High-Quality Data for LLM Training with a Data Preparation Pipeline](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)
- 💡 [DataFlex × LLaMA Factory: A Data-Centric Dynamic Training System Built on LLaMA-Factory](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
@@ -824,7 +827,7 @@ If you have a project that should be incorporated, please contact via email or c
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)

View File

@@ -55,7 +55,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
- **微调视频教程**https://www.bilibili.com/video/BV1djgRzxEts/
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
- **框架文档(昇腾 NPU**https://ascend.github.io/docs/sources/llamafactory/
- **框架文档(昇腾 NPU**https://llamafactory.readthedocs.io/zh-cn/latest/multibackend/npu/index.html
- **官方博客**https://blog.llamafactory.net/
> [!NOTE]
@@ -113,6 +113,8 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- 💡 [DataFlow × LLaMA Factory: 利用数据准备流水线产出高质量数据训练 LLM](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)(中文)| [English](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW)
- 💡 [DataFlex × LLaMA Factory: 构建在 LLaMA-Factory 之上的以数据为中心的动态训练系统](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)(中文)| [English](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv)
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
@@ -828,7 +830,7 @@ swanlab_run_name: test_run # 可选
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)

View File

@@ -36,6 +36,7 @@ COPY . /app
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
RUN pip uninstall -y torch torchvision torchaudio
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/triton_ascend.txt
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

View File

@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,51 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \
# src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count
### model
model_name_or_path: Qwen/Qwen3.5-35B-A3B
trust_remote_code: true
use_v1_kernels: false
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
packing: false
### output
output_dir: saves/Qwen3.5-35B/lora/sft
logging_steps: 1
save_steps: 2000
max_steps: 2000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
disable_gradient_checkpointing: true

View File

@@ -0,0 +1,2 @@
--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple
triton-ascend==3.2.1

View File

@@ -886,6 +886,9 @@ register_model_group(
"Gemma-4-E4B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
},
"Gemma-4-12B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-12B-it",
},
},
template="gemma4n",
multimodal=True,
@@ -1912,6 +1915,17 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM5-1B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM5-1B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM5-1B",
},
},
template="empty",
)
register_model_group(
models={
"MiniCPM-o-2.6": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.5"
VERSION = "0.9.6.dev0"
def print_env() -> None:

View File

@@ -487,7 +487,7 @@ class FinetuningArguments(
metadata={
"help": (
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
"Only supported for the 'sft' stage with full fine-tuning."
"Only supported for the 'pt' and 'sft' stages with full fine-tuning."
)
},
)

View File

@@ -194,9 +194,15 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if isinstance(model, PeftModel):
pass # already loaded via load_unsloth_peft_model in loader.py
else:
if model_args.use_unsloth:
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if peft_model is not None:
model = peft_model
if not model_args.use_unsloth: # unsloth was disabled or fell back
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

View File

@@ -34,7 +34,7 @@ from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
@@ -142,14 +142,13 @@ def load_model(
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
if model is None and not lazy_load:
if model is None:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
@@ -176,7 +175,6 @@ def load_model(
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections.abc import Iterable
from contextlib import nullcontext
from typing import TYPE_CHECKING, Optional
@@ -29,7 +30,81 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
def get_embedding_vocab_size(model: "PreTrainedModel") -> int:
r"""Get the vocab size from the input embedding layer.
Handles DeepSpeed ZeRO-3 parameter sharding by gathering the embedding weight
before reading its size.
"""
embedding = model.get_input_embeddings()
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
with deepspeed.zero.GatheredParameters([embedding.weight]):
return embedding.weight.size(0)
return embedding.weight.size(0)
def _resolve_new_token_ids(
new_tokens: Optional[Iterable[str]],
tokenizer: "PreTrainedTokenizer",
embed_size: int,
) -> Optional[list[int]]:
r"""Resolve the explicit embedding-row IDs of the newly added tokens.
Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when
the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL
has vocab 151665 but embedding 151936). In that case the appended tokens land
inside the original padding zone and the tail slice points at the wrong rows.
Args:
new_tokens: Iterable of the newly added token strings.
tokenizer: The tokenizer instance.
embed_size: Current embedding size (upper bound for valid token IDs).
Returns:
A sorted list of unique, in-range token IDs, or ``None`` when no tokens are
given so that callers can fall back to the tail-slice behaviour.
"""
if not new_tokens:
return None
unk_token_id = getattr(tokenizer, "unk_token_id", None)
token_ids: set[int] = set()
for token_str in new_tokens:
token_id = tokenizer.convert_tokens_to_ids(token_str)
if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size):
logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.")
continue
token_ids.add(token_id)
return sorted(token_ids) or None
def _existing_embeddings(
embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]]
) -> "torch.Tensor":
"""Return the rows treated as 'existing' embeddings used as the init baseline.
Prefers excluding the explicit new-token rows (robust to padding). Falls back to
dropping the last ``num_new_tokens`` rows when no explicit IDs are available.
"""
if new_token_ids:
mask = torch.ones(embed_weight.size(0), dtype=torch.bool, device=embed_weight.device)
mask[torch.as_tensor(new_token_ids, device=embed_weight.device, dtype=torch.long)] = False
return embed_weight[mask]
if num_new_tokens > 0:
return embed_weight[:-num_new_tokens]
return embed_weight
def _noisy_mean_initialization(
embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None
) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
@@ -37,9 +112,20 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
token_ids: Explicit token IDs to initialize. When provided, these exact rows are
written (robust to padding). When ``None``, falls back to the last
``num_new_tokens`` rows.
"""
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True)
if token_ids:
noise_weight = torch.empty(
len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype
)
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[token_ids] = avg_weight + noise_weight
else:
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
@@ -51,6 +137,7 @@ def _description_based_initialization(
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
@@ -61,6 +148,9 @@ def _description_based_initialization(
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
New tokens are placed by their resolved token ID rather than by tail slicing,
so the initialization is correct even when the embedding matrix was padded.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
@@ -68,6 +158,8 @@ def _description_based_initialization(
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized
rows when averaging description-token embeddings (robust to embedding padding).
add_noise: Whether to add Gaussian noise to the initialization
Example:
@@ -77,38 +169,55 @@ def _description_based_initialization(
}
"""
embedding_dim = embed_weight.size(1)
vocab_size = embed_weight.size(0)
unk_token_id = getattr(tokenizer, "unk_token_id", None)
device = embed_weight.device
# The set of rows that are NOT yet initialized (the newly added tokens). Description
# tokens that fall into this set must be excluded, otherwise we would average garbage.
# `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on
# the explicit IDs, falling back to resolving them from the description keys.
if new_token_ids is None:
new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size)
new_id_set = set(new_token_ids or [])
fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0)
for token_str, desc in descriptions.items():
# Resolve token ID for correct placement (robust to embedding padding)
token_id = tokenizer.convert_tokens_to_ids(token_str)
if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size):
logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.")
continue
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
token_ids = tokens["input_ids"][0].tolist()
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
# Keep only description tokens that already have a meaningful embedding.
valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
f"Description for token '{token_str}' contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
base_embedding = fallback_embedding
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long)
token_embeds = model.get_input_embeddings()(valid_ids_tensor)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
embed_weight[token_id] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
embed_weight[token_id] = base_embedding
def _initialize_embeddings(
@@ -118,6 +227,7 @@ def _initialize_embeddings(
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
) -> None:
"""Single source of truth for embedding initialization.
@@ -130,16 +240,18 @@ def _initialize_embeddings(
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding).
When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows.
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True
)
else:
if init_method != "noise_init":
@@ -147,20 +259,28 @@ def _initialize_embeddings(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
_noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_tokens: Optional[Iterable[str]] = None,
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
r"""Resize token embeddings (when needed) and initialize the newly added tokens.
Resizing and initialization are decoupled: even when the tokenizer vocab fits inside
the model's existing (padded) embedding matrix and no resize is triggered, the newly
added tokens still occupy uninitialized rows and must be initialized. We therefore
resolve the explicit row IDs of ``new_tokens`` and always initialize those rows.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_tokens: Iterable of the newly added token strings. Used to locate the exact
embedding rows to initialize, which is robust to pre-existing embedding padding.
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
@@ -175,23 +295,41 @@ def resize_embedding_layer(
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
current_embedding_size = get_embedding_vocab_size(model)
needs_resize = len(tokenizer) > current_embedding_size
if len(tokenizer) > current_embedding_size:
if needs_resize:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
# mean_resizing=False preserves the original embedding distribution exactly.
# HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance
# of existing embeddings, which conflicts with our explicit initialization below.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# Resolve the exact rows of the new tokens. This works whether or not a resize was
# triggered (e.g. tokens added into a model's pre-existing padding zone).
new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size)
if num_new_tokens <= 0 and not new_token_ids:
return
if needs_resize:
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
else:
logger.info_rank0(
f"No resize needed (vocab fits in padded embedding {new_embedding_size}); "
f"initializing {len(new_token_ids or [])} new token(s) in place."
)
# Initialize input embeddings
_initialize_embeddings(
@@ -201,6 +339,7 @@ def resize_embedding_layer(
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
@@ -212,7 +351,14 @@ def resize_embedding_layer(
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
if needs_resize:
model.config.vocab_size = new_embedding_size
# Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA),
# otherwise config.vocab_size and config.text_config.vocab_size become inconsistent.
if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"):
model.config.text_config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -16,6 +16,7 @@ import inspect
from typing import TYPE_CHECKING
from ...extras import logging
from ...extras.misc import get_device_name
if TYPE_CHECKING:
@@ -81,6 +82,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
elif model_type == "qwen3_5":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
elif model_type == "qwen3_5_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
@@ -97,5 +100,12 @@ def apply_liger_kernel(
else:
kwargs = {}
if get_device_name() == "npu":
import torch
if "Ascend910" not in torch.npu.get_device_name(0):
kwargs["swiglu"] = False
kwargs["fused_linear_cross_entropy"] = False
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")

View File

@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""Load peft model with unsloth. Used in both training and inference."""
) -> Optional["PreTrainedModel"]:
r"""Load peft model with unsloth. Used in both training and inference.
Returns None if unsloth does not support the model type, and sets
model_args.use_unsloth = False so callers can fall back to standard loading.
"""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
return None
if not is_trainable:
FastLanguageModel.for_inference(model)

View File

@@ -20,6 +20,7 @@ from peft import PeftModel
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
from ..extras import logging
from ..extras.misc import infer_optim_dtype
@@ -84,7 +85,60 @@ def _check_fla_dependencies() -> None:
) from exc
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None:
"""Patch for Qwen3.5 models on NPU by importing torch_npu to enable torch.cuda compatibility.
On NPU, torch.cuda operations will fail unless torch_npu is imported.
torch_npu provides compatibility layer that maps torch.cuda calls to NPU operations.
Also replaces chunk_gated_delta_rule with NPU-compatible implementation.
"""
import importlib.metadata
if "Ascend910" not in torch.npu.get_device_name(0):
logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.")
return
try:
importlib.metadata.version("triton_ascend")
except importlib.metadata.PackageNotFoundError:
logger.warning_rank0(
"triton_ascend not installed, skipping NPU GDN patch. "
"To enable it on NPU, reinstall Triton with the Ascend build: "
"`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. "
"Note: triton and triton_ascend cannot coexist — triton must be uninstalled first."
)
return
logger.info_rank0("triton_ascend detected for NPU compatibility.")
from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule
if model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
try:
# Qwen3.5-MoE structure: model.model.language_model.layers
for layer in model.model.language_model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0(
"Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5-MoE model."
)
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
elif model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
try:
# Qwen3.5 structure: model.model.layers
for layer in model.model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0("Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5 model.")
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None:
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training.
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
@@ -403,9 +457,14 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
# Pass the explicit list of newly added tokens so their exact embedding rows can be
# located and initialized, even when they land in a model's pre-existing padding zone.
new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or [])
resize_embedding_layer(
model,
tokenizer,
new_tokens=new_tokens or None,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)
@@ -421,8 +480,12 @@ def patch_model(
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
patch_qwen3_5_forward(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]:
if is_torch_npu_available():
patch_qwen3_5_forward_npu(model)
elif is_torch_cuda_available() and model_args.flash_attn == "fa2":
# this is the patch for packing/neat_packing for GPU GDN. And when setting packing, flash_attn must be fa2.
patch_qwen3_5_forward_gpu(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)

View File

@@ -0,0 +1,594 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets
CUBE_CORE_NUM = get_npu_properties()["num_aicore"]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(False,)),
key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
w,
v_new,
g,
gk,
h,
h0,
ht,
cu_seqlens,
chunk_offsets,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
NT_all = NT
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
# Initialize hidden states
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
if IS_VARLEN:
v = v + (i_h * T_all + bos) * V
k = k + (i_h * T_all + bos) * K
w = w + (i_h * T_all + bos) * K
g = g + i_h * T_all + bos
h = h + (i_h * NT_all + boh) * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_h * T_all + bos) * V
else:
v = v + (i_n * H + i_h) * T * V
k = k + (i_n * H + i_h) * T * K
w = w + (i_n * H + i_h) * T * K
g = g + (i_n * H + i_h) * T
h = h + (i_n * H + i_h) * NT * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_n * H + i_h) * T * V
if USE_INITIAL_STATE:
h0_ptr = h0 + i_nh * K * V
if STORE_FINAL_STATE:
ht_ptr = ht + i_nh * K * V
# Load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# Main recurrence over chunks
for i_t in range(NT):
# Store current hidden state h_t
p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
# Compute v_residual = v - w @ h
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
# Apply output gate g
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_g_last = tl.load(g + last_idx)
p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None]
b_g_last_exp = tl.exp(b_g_last)
b_h1 *= b_g_last_exp
if K > 64:
b_h2 *= b_g_last_exp
if K > 128:
b_h3 *= b_g_last_exp
if K > 192:
b_h4 *= b_g_last_exp
# Apply key gate gk
if USE_GK:
o_k1 = tl.arange(0, 64).to(tl.float32)
gk_base_ptr = gk + (i_n * H + i_h) * T * K
b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0)
b_h1 *= tl.exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0)
b_h2 *= tl.exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0)
b_h3 *= tl.exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0)
b_h4 *= tl.exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
# Update hidden state: h += k @ v
p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h4 += tl.dot(b_k, b_v)
# Store final state
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # default:64
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, u.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V).permute(0, 2, 1, 3, 4).contiguous()
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
BV = 128
v_new = torch.empty_like(u).permute(0, 2, 1, 3).contiguous() if save_new_value else None
k = k.permute(0, 2, 1, 3).contiguous()
w = w.permute(0, 2, 1, 3).contiguous()
u = u.permute(0, 2, 1, 3).contiguous()
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
NT=NT,
)
h = h.permute(0, 2, 1, 3, 4).contiguous()
v_new = v_new.permute(0, 2, 1, 3).contiguous()
return h, v_new, final_state
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
"USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(True, False)),
key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
q,
k,
w,
g,
gk,
dht,
dh0,
do,
dh,
dv,
dv2,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_FINAL_STATE_GRADIENT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
w += (bos * H + i_h) * K
do += (bos * H + i_h) * V
dv += (bos * H + i_h) * V
dv2 += (bos * H + i_h) * V
dh += (boh * H + i_h) * K * V
if USE_GK:
gk += (bos * H + i_h) * K
if USE_INITIAL_STATE:
dh0 += i_nh * K * V
if USE_FINAL_STATE_GRADIENT:
dht += i_nh * K * V
stride_v = H * V
stride_h = H * K * V
stride_k = H * K
if USE_FINAL_STATE_GRADIENT:
p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
if K > 64:
p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
if K > 128:
p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
if K > 192:
p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))
for i_t in range(NT - 1, -1, -1):
p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
if IS_VARLEN:
bos_g = i_h * T_all + bos
else:
bos_g = (i_n * H + i_h) * T_all
bg_last = tl.load(g + bos_g + last_idx)
bg_last_exp = tl.exp(bg_last)
p_g = tl.make_block_ptr(
base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
# Update dv
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0)
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
if K > 64:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
if K > 128:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
if K > 192:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None]
b_dv += tl.load(p_dv, boundary_check=(0, 1))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# Update dh
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
if USE_G:
b_dh1 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh1 *= tl.exp(b_gk_last1[:, None])
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 64:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh2 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh2 *= tl.exp(b_gk_last2[:, None])
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 128:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh3 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh3 *= tl.exp(b_gk_last3[:, None])
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 192:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh4 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh4 *= tl.exp(b_gk_last4[:, None])
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_bwd_dhu(
q: torch.Tensor,
k: torch.Tensor,
w: torch.Tensor,
do: torch.Tensor,
dv: torch.Tensor,
g: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
h0: torch.Tensor | None = None,
dht: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *q.shape, do.shape[-1]
# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
dh = q.new_empty(B, NT, H, K, V)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
dv2 = torch.empty_like(dv)
BV = 128
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)](
q=q,
k=k,
w=w,
g=g,
gk=gk,
dht=dht,
dh0=dh0,
do=do,
dh=dh,
dv=dv,
dv2=dv2,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
)
return dh, dh0, dv2

View File

@@ -0,0 +1,347 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional
import torch
from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
return g, o, A, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=False,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dv = chunk_bwd_dv_local(
q=q,
k=k,
g=g,
do=do,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
q=q,
k=k,
w=w,
g=g,
h0=initial_state,
dht=dht,
do=do,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dq, dk, dw, dg = chunk_bwd_dqkwg(
q=q,
k=k,
v=v_new,
w=w,
g=g,
h=h,
dv=dv,
do=do,
dh=dh,
chunk_size=chunk_size,
scale=scale,
cu_seqlens=cu_seqlens,
)
dk2, dv, db, dg2 = prepare_wy_repr_bwd(
k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size
)
dk.add_(dk2)
dg.add_(dg2)
if dg.dtype != torch.float32:
raise ValueError(f"dg current type is {dg.dtype} , should be float32")
dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False)
return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
chunk_size: int = 64,
):
q_rstd, k_rstd = None, None
g, o, A, final_state = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
ctx.chunk_size = chunk_size
return o.to(q.dtype), final_state
@staticmethod
@input_guard
@autocast_custom_bwd
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
A=A,
scale=ctx.scale,
initial_state=initial_state,
do=do,
dht=dht,
cu_seqlens=cu_seqlens,
chunk_size=ctx.chunk_size,
)
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
head_first: bool = False,
):
r"""Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
betas of shape `[B, T, H]`.
scale (Optional[float]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
use_qk_l2norm_in_kernel (bool):
Whether to apply L2norm to the q/k tensor internally. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
This argument has been deprecated.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
""" # noqa: D205
if q.dtype != k.dtype or k.dtype != v.dtype:
raise ValueError(
f"q current type is {q.dtype} , k current type is {k.dtype} ,v current type is {v.dtype} , they should are equal"
)
if q.dtype == torch.float32:
raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.")
if len(beta.shape) != 3:
raise ValueError(
f"beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if head_first:
warnings.warn(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
original_dtype = x.dtype
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
# Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype
return (x * inv_norm).to(original_dtype)
if use_qk_l2norm_in_kernel:
q = l2norm(q, dim=-1, eps=1e-6)
k = l2norm(k, dim=-1, eps=1e-6)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size
)
return o, final_state

View File

@@ -0,0 +1,617 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"USE_DW": lambda args: args["dw"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dqkwg(
q,
k,
v,
h,
g,
g_gamma,
do,
dh,
dq,
dk,
dg,
w,
dv,
dw,
cu_seqlens,
chunk_indices,
scale,
B: tl.constexpr,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
USE_DW: tl.constexpr,
IS_VARLEN: tl.constexpr,
gdiff,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
total = B * T_max
T = eos - bos
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
total = B * T_max
NK = tl.cdiv(K, BK)
for i_k in range(NK):
if USE_G:
dg_k = dg + i_k * total * H
for i_h in range(H):
v_h = v + (bos * H + i_h) * V
do_h = do + (bos * H + i_h) * V
h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V
dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V
q_h = q + (bos * H + i_h) * K
k_h = k + (bos * H + i_h) * K
dq_h = dq + (bos * H + i_h) * K
dk_h = dk + (bos * H + i_h) * K
if USE_DW:
w_h = w + (bos * H + i_h) * K # noqa: F841
dw_h = dw + (bos * H + i_h) * K
dv_h = dv + (bos * H + i_h) * V
if USE_G:
if IS_VARLEN:
dg_h = dg_k + i_h * T_max + bos
g_h = g + i_h * T_max + bos
else:
dg_h = dg_k + (i_b * H + i_h) * T_max
g_h = g + (i_b * H + i_h) * T_max
b_dg_last = tl.zeros(
[
1,
],
dtype=tl.float32,
)
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
b_g_last = b_gamma * min(BT, T - i_t * BT)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
if USE_G:
b_dg_last += tl.sum(b_h * b_dh)
b_ds += tl.dot(b_do, tl.trans(b_v))
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
if USE_DW:
p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
if USE_DW:
p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_dg = tl.zeros(
[
BT,
],
dtype=tl.float32,
)
p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1)
b_dg_last *= tl.exp(b_g_last)
b_dq = b_dq * tl.exp(b_g)[:, None] * scale
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None]
b_dg -= tl.sum(b_k * b_dk, axis=1)
b_dg_last += tl.sum(b_dk * b_k)
if IS_VARLEN:
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
else:
p_gdiff = tl.make_block_ptr(
gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT,
(BT, BT),
(BT, 1),
(0, 0),
(BT, BT),
(1, 0),
)
gdiff_ = tl.load(p_gdiff)
b_ds = b_ds * gdiff_ * scale
b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
b_dg += tl.sum(b_ds2, axis=1)
b_dg -= tl.sum(b_ds2, axis=0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
last_index_local = min(BT, T - i_t * BT) - 1
if last_index_local >= 0:
is_last_mask = tl.arange(0, BT) == last_index_local
b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg)
else:
pass
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
elif USE_G_GAMMA:
b_dq = b_dq * exp(b_g)[:, None] * scale
b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
else:
b_ds = tl.where(m_A, b_ds, 0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dv_local(
q,
k,
g,
g_gamma,
do,
dv,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
for i_h in range(H):
offset_kh = (bos * H + i_h) * K
offset_vh = (bos * H + i_h) * V
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, b_q)
if USE_G:
if IS_VARLEN:
offset_g = i_h * T_max + bos
else:
offset_g = i_b * H * T_max + i_h * T_max
p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
else:
b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
N: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_max = T
for i_v in range(tl.cdiv(V, BV)):
for i_n in range(N):
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
base_chunks_per_pid = NT // total_cores
remainder = NT % total_cores
if core_id < remainder:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * base_chunks_per_pid + remainder
# offset calculation
for i_h in range(0, H):
q_offset = (bos * Hg + i_h // (H // Hg)) * K
k_offset = (bos * Hg + i_h // (H // Hg)) * K
v_offset = (bos * H + i_h) * V
o_offset = (bos * H + i_h) * V
for i_t in range(start_idx, start_idx + chunks_this_pid):
i_tg = boh + i_t
h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, b_h)
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k)
if USE_G:
if IS_VARLEN:
p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
else:
p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_bwd_dqkwg(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
do: torch.Tensor,
h: torch.Tensor,
dh: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
w: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128 if cu_seqlens is None else 64
BV = 64
NK = triton.cdiv(K, BK)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
g = g.transpose(1, 2).contiguous()
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
dw = torch.empty_like(w) if w is not None else None
grid = (NT, B)
if cu_seqlens is None:
if NT * BT == T:
g_ = g.reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
else:
diff = NT * BT - T
g_ = torch.cat((g, torch.zeros(B, H, diff).to(g.device)), dim=-1).reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
bias = torch.arange(0, BT).to(g.device)
o_t = (NT - 1) * BT + bias
m_t = o_t < T
m_A = m_t[:, None] & m_t
g_diff[:, :, -1] *= m_A
else:
g_diff = None
chunk_bwd_kernel_dqkwg[grid](
q=q,
k=k,
v=v,
h=h,
g=g,
g_gamma=g_gamma,
do=do,
dh=dh,
dv=dv,
w=w,
dw=dw,
dq=dq,
dk=dk,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
B=B,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
gdiff=g_diff,
)
if dg is not None:
dg = dg.sum(0)
dg = dg.transpose(1, 2).contiguous()
return dq, dk, dw, dg
def chunk_bwd_dv_local(
q: torch.Tensor,
k: torch.Tensor,
do: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: float = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, H, K, V = *k.shape, do.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
BK = 128
BV = 128
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous()
dv = torch.empty_like(do)
grid = (NT, B)
chunk_bwd_kernel_dv_local[grid](
q=q,
k=k,
g=g,
g_gamma=g_gamma,
do=do,
dv=dv,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)
return dv
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # noqa: F841
if scale is None:
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
if cu_seqlens is None:
N, chunk_offsets = B, None
else:
N, chunk_offsets = (
len(cu_seqlens) - 1,
prepare_chunk_offsets(cu_seqlens, BT),
)
def grid(meta):
return (triton.cdiv(V, meta["BV"]), N * H)
g = g.transpose(1, 2).contiguous()
h = h.contiguous()
CV_kernel_num = 24
chunk_fwd_kernel_o[(CV_kernel_num,)](
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T=T,
H=H,
N=N,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=128,
BV=128,
)
return o
bwd_chunk_dqkwg = chunk_bwd_dqkwg
bwd_chunk_dv_local = chunk_bwd_dv_local

View File

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
USE_G: tl.constexpr,
NT,
B,
TOTAL_TASKS,
):
core_id = tl.program_id(0)
num_blocks = tl.num_programs(0)
T_max = T
base_tasks_per_block = TOTAL_TASKS // num_blocks
remainder_tasks = TOTAL_TASKS % num_blocks
if core_id < remainder_tasks:
tasks_this_core = base_tasks_per_block + 1
start_idx = core_id * tasks_this_core
else:
tasks_this_core = base_tasks_per_block
start_idx = core_id * base_tasks_per_block + remainder_tasks
for idx in range(start_idx, start_idx + tasks_this_core):
i_b = idx // NT
local_idx = idx % NT
if IS_VARLEN:
i_n = tl.load(chunk_indices + local_idx * 2).to(tl.int32)
i_t = tl.load(chunk_indices + local_idx * 2 + 1).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_local = eos - bos
else:
bos, eos = 0, T
i_t = local_idx
T_local = T
for i_h in range(H):
k_batch_off = i_b * T_max * H * K
beta_batch_off = i_b * H * T_max
g_batch_off = i_b * H * T_max
A_batch_off = i_b * T_max * H * BT
p_beta = tl.make_block_ptr(
beta + beta_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + k_batch_off + (bos * H + i_h) * K,
(T_local, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
dot_product = tl.dot(b_k, tl.trans(b_k))
o_t = i_t * BT + tl.arange(0, BT)
o_t = o_t.to(tl.float32)
T_mask = (o_t < T_local).to(tl.float32)
row_indices = tl.arange(0, BT)[:, None]
col_indices = tl.arange(0, BT)[None, :]
tril_mask = (row_indices > col_indices).to(tl.float32)
tril_mask = tril_mask * T_mask[:, None]
masked_dot = dot_product * tril_mask
b_A += masked_dot
if USE_G:
p_g = tl.make_block_ptr(
g + g_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0)
b_A *= tl.exp(b_g_diff)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(
A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"])
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_i, i_j = i_c // NC, i_c % NC
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j)
if should_compute:
k_ptr = k + (bos * H + i_h) * K
g_ptr = g + (bos * H + i_h) * K
A_ptr = A + (bos * H + i_h) * BT
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
b_kt = tl.make_block_ptr(
k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
p_gk = tl.make_block_ptr(
g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
b_g = tl.load(p_g, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :])
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kt = tl.load(b_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk)
b_A += tl.dot(b_k, b_kt)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = i_t * BT + i_i * BC < T_val
if should_compute:
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + o_i) < T_val
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None]
b_g = tl.load(p_g, boundary_check=(0, 1))
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)):
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1)
# 转化成f32
o_i_tmp = o_i.to(tl.float32)
b_A = tl.where(o_i_tmp > j, b_A, 0.0)
tl.store(A + o_A + j, b_A, mask=m_A)
p_kt += H * K
p_gk += H * K
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
beta: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
BK = 128
kernel_num = 24
if gk is None:
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)](
k=k,
g=g,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BK=BK,
NT=NT,
B=B,
TOTAL_TASKS=B * NT,
)
return A
BC = min(16, BT)
NC = triton.cdiv(BT, BC)
BK = max(triton.next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A

View File

@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
scale,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BLOCK_T: tl.constexpr,
REVERSE: tl.constexpr,
HAS_SCALE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr,
CHUNK_SIZE: tl.constexpr = 64,
):
i_block, i_b = tl.program_id(0), tl.program_id(1)
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
if IS_VARLEN:
i_s, i_block = (
tl.load(chunk_indices + i_block * 2).to(tl.int32),
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
b_s = tl.trans(b_s, (1, 0, 2))
b_o = tl.cumsum(b_s, axis=0)
if REVERSE:
b_z = tl.sum(b_s, axis=0)
b_o = -b_o + b_z[None] + b_s
if HAS_SCALE:
b_o *= scale
b_o = tl.trans(b_o, (1, 0, 2))
b_o = tl.reshape(b_o, (BLOCK_T, H))
tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,))
return
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
B, T, H = g.shape
if chunk_size != 2 ** (chunk_size.bit_length() - 1):
raise ValueError(f"chunk_size must be a power of 2, chunk_size is{chunk_size}")
# We adjust the tiling strategy to prevent overflow in in backward passes and context parallel scenarios
# while maximizing UB utilization where possible.
# The tiling strategy is as follows:
# 1. BT must be greater than or equal to chunk_size.
# 2. UB estimation varies directly with H.
# 3. BT in reverse mode is smaller than in forward mode.
BT = max(chunk_size, triton.next_power_of_2((1 << 11 if reverse else 1 << 12) // H))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B)
chunk_local_cumsum_scalar_kernel[grid](
s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
BLOCK_T=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
CHUNK_SIZE=chunk_size,
)
return g
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is not None:
if g.shape[0] != 1:
raise ValueError(
f"Only batch size 1 is supported when cu_seqlens are provided, current size is{g.shape[0]}"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype,
)
else:
raise ValueError(
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise"
)

View File

@@ -0,0 +1,272 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = pid_bh // H, pid_bh % H
base_t = pid_t * TPP
if IS_VARLEN:
i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_eff = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_eff = T
o_i = tl.arange(0, 16) # noqa: F841
o_i_fp32 = tl.arange(0, 16).to(tl.float32)
m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
m_I = o_i_fp32[:, None] == o_i_fp32[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * BT
for tpp in tl.static_range(0, TPP):
tile_t = base_t + tpp
tile_row = tile_t * 16
offset = (tile_t * 16) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0))
b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
b_A_raw = desc.load([tile_row, offset]).to(tl.float32)
b_A_neg = -b_A_raw
b_A = b_A_neg * m_A
for i in range(2, min(16, T_eff - tile_row)):
slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
b_a_update = b_a_val + dot_prod
b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
b_Ai_11 += m_I
b_Ai_22 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_64x64_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 64)
m_I = o_i[:, None] == o_i[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * 64
offset = (i_t * 64) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0))
b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64])
desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64])
b_A = -desc.load([i_t * 64, offset]).to(tl.float32)
for i in range(2, min(64, T - i_t * 64)):
b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
b_A = tl.where((o_i == i)[:, None], b_a, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@input_guard
def solve_tril(
A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
"""Compute the inverse of the matrix I + A
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Returns:
(I + A)^-1 with the same shape as A
""" # noqa: D205
if A.shape[-1] not in [16, 32, 64]:
raise ValueError(f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}")
output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype)
if BT == 16:
merge_fn = solve_tril_16x16_kernel
elif BT == 32:
merge_fn = merge_16x16_to_32x32_inverse_kernel
elif BT == 64:
merge_fn = solve_tril_64x64_kernel
merge_fn[NT, B * H](
A=A,
Ai=Ai,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
USE_TMA=False,
DOT_PRECISION=FLA_TRIL_PRECISION,
)
return Ai

View File

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import itertools
import logging
import os
import warnings
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional
import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
import triton.runtime.driver as driver
from packaging import version
logger = logging.getLogger(__name__)
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any:
"""A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the outputs of the decorated function for the most recent
set of input tensors, up to `maxsize` entries. If the function is called again with
the same input tensors, it will return the cached result.
When maxsize=1 (default), the behavior is identical to caching only the most recent result.
Can be used as @tensor_cache or @tensor_cache(maxsize=n).
Args:
fn (Callable[..., torch.Tensor], optional):
The function to be decorated when used without parentheses.
maxsize (int):
Maximum number of input combinations to cache. Default is 1.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with caching.
"""
if maxsize < 1:
raise ValueError("maxsize must be at least 1")
def _is_match(a: Any, b: Any) -> bool:
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return a is b
try:
return a == b
except Exception:
return a is b
def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
cache: list = []
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache):
if len(args) == len(cached_args) and len(kwargs) == len(cached_kwargs):
if all(_is_match(a, b) for a, b in zip(args, cached_args)) and all(
k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items()
):
if i != 0:
cache.insert(0, cache.pop(i))
return cached_result
result = fn(*args, **kwargs)
cache.insert(0, (args, kwargs, result))
if len(cache) > maxsize:
cache.pop()
return result
return wrapper
if fn is not None:
return _make_wrapper(fn)
return _make_wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache(maxsize=3)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
def get_abs_err(x, y):
return (x.detach() - y.detach()).flatten().abs().max().item()
def get_err_ratio(x, y):
err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
base = (x.detach()).flatten().square().mean().sqrt().item()
return err / (base + 1e-8)
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
abs_atol = get_abs_err(ref, tri)
msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
logger.info(msg)
error_rate = get_err_ratio(ref, tri)
if abs_atol <= err_atol:
return
if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
if error_rate > ratio:
warnings.warn(msg)
else:
assert error_rate < ratio, msg
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
# For Triton 3.3.x
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
# For Triton 3.4.x and later
make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
Just make triton compiler happy.
"""
@triton.jit
def make_tensor_descriptor(
base,
shape,
strides,
block_shape,
_builder=None,
):
return None
@functools.cache
def get_available_device() -> str:
try:
return triton.runtime.driver.active.get_current_target().backend
except BaseException:
_cpu_device_warning()
return "cpu"
def map_triton_backend_to_torch_device() -> str:
backend = get_available_device() # 'cuda' | 'hip' | 'xpu' | 'cpu' | ...
return {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(backend, backend)
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = get_available_device()
is_amd = device_platform == "hip"
is_nvidia = device_platform == "cuda"
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
)
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_tma_supported = (
(is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9)
and os.environ.get("FLA_NO_USE_TMA", "0") != "1"
and (
hasattr(triton.language, "_experimental_make_tensor_descriptor")
or hasattr(triton.language, "make_tensor_descriptor")
)
)
if is_nvidia and not is_tf32_supported:
# Make old card happy, since triton will use tf32 by default.
# This is a workaround for old nvidia card.
os.environ["TRITON_F32_DEFAULT"] = "ieee"
@functools.cache
def check_pytorch_version(version_s: str = "2.4") -> bool:
return version.parse(torch.__version__) >= version.parse(version_s)
if check_pytorch_version("2.4"):
device = "cuda" if device == "cpu" else device
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
def custom_device_ctx(index: int):
return device_torch_lib.device(index)
else:
assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0."
autocast_custom_fwd = device_torch_lib.amp.custom_fwd
autocast_custom_bwd = device_torch_lib.amp.custom_bwd
def custom_device_ctx(index: int):
return torch.cuda.device(index)
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""A decorator to make sure all input tensors are contiguous and set the device based on input tensors."""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = custom_device_ctx(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
def _cpu_device_warning():
warnings.warn(("Triton is not supported on current platform, roll back to CPU."), stacklevel=1)
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
exp = tldevice.fast_expf
exp2 = tldevice.exp2
log = tldevice.fast_logf
log2 = tldevice.fast_log2f
else:
exp = tl.exp
exp2 = tl.math.exp2
log = tl.log
log2 = tl.log2
def get_all_max_shared_mem():
try:
return [
triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
for i in range(device_torch_lib.device_count())
]
except BaseException:
_cpu_device_warning()
return [-1]
class Backend(Enum):
ADA = 101376 # RTX 4090
AMPERE = 166912 # A100
HOPPER = 232448 # H100
DEFAULT = 102400 # Default
@classmethod
def get_shared_memory(cls, arch: str) -> int:
try:
return cls[arch.upper()].value
except KeyError:
return cls.DEFAULT.value
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
try:
device_shared_mem_list = get_all_max_shared_mem()
max_shared_memory = device_shared_mem_list[tensor_idx]
return max_shared_memory >= Backend.get_shared_memory(arch)
except Exception:
return False
def get_autotune_config(
multibuffer_list: tuple = (False,),
unit_flag_list: tuple = (False,),
limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,),
limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",),
set_workspace_multibuffer_list: tuple = (2, 4),
enable_hivm_auto_cv_balance_list: tuple = (True,),
tile_mix_vector_loop_num_list: tuple = (2, 4),
tile_mix_cube_loop_num_list: tuple = (2, 4),
):
configs = []
for (
multibuffer,
unit_flag,
limit_auto_multi_buffer_only_for_local_buffer,
limit_auto_multi_buffer_of_local_buffer,
) in itertools.product(
list(multibuffer_list),
list(unit_flag_list),
list(limit_auto_multi_buffer_only_for_local_buffer_list),
list(limit_auto_multi_buffer_of_local_buffer_list),
):
base_config_dict = {
"multibuffer": multibuffer,
"unit_flag": unit_flag,
"limit_auto_multi_buffer_only_for_local_buffer": limit_auto_multi_buffer_only_for_local_buffer,
"limit_auto_multi_buffer_of_local_buffer": limit_auto_multi_buffer_of_local_buffer,
}
if limit_auto_multi_buffer_only_for_local_buffer:
configs.append(triton.Config(base_config_dict))
else:
for (
set_workspace_multibuffer,
enable_hivm_auto_cv_balance,
tile_mix_vector_loop,
tile_mix_cube_loop,
) in itertools.product(
list(set_workspace_multibuffer_list),
list(enable_hivm_auto_cv_balance_list),
list(tile_mix_vector_loop_num_list),
list(tile_mix_cube_loop_num_list),
):
full_config_dict = base_config_dict.copy()
full_config_dict.update(
{
"set_workspace_multibuffer": set_workspace_multibuffer,
"enable_hivm_auto_cv_balance": enable_hivm_auto_cv_balance,
"tile_mix_vector_loop": tile_mix_vector_loop,
"tile_mix_cube_loop": tile_mix_cube_loop,
}
)
configs.append(triton.Config(full_config_dict))
return configs
def get_npu_properties():
return driver.active.utils.get_device_properties(torch.npu.current_device())

View File

@@ -0,0 +1,387 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
k,
v,
beta,
g,
A,
dw,
du,
dk,
dv,
dbeta,
dg,
cu_seqlens,
chunk_indices,
T,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
i_t = idx
bos, eos = i_b * T, i_b * T + T
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
for i_h in range(0, H):
if IS_VARLEN:
offset = bos + i_h * T_max
else:
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dg = tl.zeros([BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dw = tl.make_block_ptr(
dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
b_dk_beta_g = tl.dot(b_A, b_dw)
b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None]
b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1)
b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_dv = tl.make_block_ptr(
dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_du = tl.make_block_ptr(
du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta))
b_dv_beta = tl.dot(b_A, b_du)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
b_dA = b_dA.to(k.dtype.element_ty)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_k_beta, tl.trans(b_k))
b_dk_beta = tl.dot(b_dA, b_k)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_dA_A = b_dA * b_A
b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0)
p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
gk,
cu_seqlens,
chunk_indices,
T_tmp,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T_tmp
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
for i_h in range(0, H):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
offset = bos + i_h * T_max
T = eos - bos
else:
T = T_tmp
i_t = idx
bos, eos = i_b * T, i_b * T + T
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
if USE_G:
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
if USE_G:
b_kb *= b_g[:, None]
if USE_GK:
p_gk = tl.make_block_ptr(
gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1)))
b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 128
BV = 128
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous() if g is not None else None
beta = beta.transpose(1, 2).contiguous()
w = torch.empty_like(k)
u = torch.empty_like(v)
cv_kernel_num = 24
recompute_w_u_fwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T_tmp=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
return w, u
def prepare_wy_repr_bwd(
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
dw: torch.Tensor,
du: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
chunk_size: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128
BV = 128
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.empty_like(beta)
dg = torch.empty_like(g)
cv_kernel_num = 24
prepare_wy_repr_bwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
g=g,
A=A,
dw=dw,
du=du,
dk=dk,
dv=dv,
dbeta=dbeta,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
dbeta = dbeta.transpose(1, 2).contiguous()
dg = dg.transpose(1, 2).contiguous()
return dk, dv, dbeta, dg
bwd_prepare_wy_repr = prepare_wy_repr_bwd
fwd_recompute_w_u = recompute_w_u_fwd

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft
from .workflow import run_pt, run_sft
__all__ = ["run_sft"]
__all__ = ["run_pt", "run_sft"]

View File

@@ -0,0 +1,222 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperParallel distributed trainer for LlamaFactory."""
import logging
import os
import types
from contextlib import nullcontext
from typing import Any, Optional
import torch
from hyper_parallel.integration.llamafactory import (
HSDPModule,
HyperParallelArguments,
export_to_hf_format,
fsdp2_prepare_model,
hsdp_sync_stream,
load_hsdp_model,
load_hsdp_optimizer_and_scheduler,
save_hsdp_checkpoint,
wrap_optimizer_with_skip_dtensor_dispatch,
)
from hyper_parallel.integration.llamafactory import (
clip_grad_norm_ as hp_clip_grad_norm_,
)
from torch import nn
from ..sft.trainer import CustomSeq2SeqTrainer
logger = logging.getLogger(__name__)
class HyperParallelTrainer(CustomSeq2SeqTrainer):
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
prediction, sampler, etc.) and only overrides HSDP-specific behavior.
"""
def __init__(
self,
hp_args: HyperParallelArguments,
finetuning_args=None,
processor=None,
ref_model: Optional[nn.Module] = None,
**kwargs,
):
self._hp_args = hp_args
# Let CustomSeq2SeqTrainer handle everything except ref_model —
# Custom would prepare it with accelerate's fsdp2_prepare_model,
# but we need HP's version instead.
super().__init__(
finetuning_args=finetuning_args,
processor=processor,
ref_model=None,
**kwargs,
)
if not getattr(self.accelerator, "is_fsdp2", False):
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
# Prepare ref_model with HP's fsdp2_prepare_model
self.ref_model = ref_model
if self.ref_model is not None:
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
self._orig_fsdp2_prepare_model = None
self._accelerator_patches_active = False
def _activate_accelerator_patches(self) -> None:
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
if self._accelerator_patches_active:
return
import accelerate.accelerator as acc_module # pylint: disable=C0415
hp_args = self._hp_args
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
def _hp_fsdp2_prepare_model(accelerator, model):
return fsdp2_prepare_model(accelerator, model, hp_args)
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
if getattr(accelerator, "is_fsdp2", False):
accelerator.unscale_gradients()
parameter_list = list(parameters)
parameter_ids = {id(param) for param in parameter_list}
for model in accelerator._models: # pylint: disable=protected-access
if not isinstance(model, HSDPModule):
continue
model_param_ids = {id(param) for param in model.parameters()}
if parameter_ids and parameter_ids.issubset(model_param_ids):
return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)
return self._orig_accelerator_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
self._accelerator_patches_active = True
def _restore_accelerator_patches(self) -> None:
"""Restore original Accelerate methods."""
if not self._accelerator_patches_active:
return
import accelerate.accelerator as acc_module # pylint: disable=C0415
if self._orig_fsdp2_prepare_model is not None:
acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
self._accelerator_patches_active = False
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
"""Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct."""
del dataloader
if isinstance(model, HSDPModule):
return model
if training and getattr(self.accelerator, "is_fsdp2", False):
return model
return super()._wrap_model(model, training=training)
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
"""Skip redundant device moves for HSDP-wrapped models."""
if isinstance(model, HSDPModule):
return model
if device is None:
return model
return model.to(device)
def train(self, *args, **kwargs):
"""Activate HP patches during training and restore afterwards."""
self._activate_accelerator_patches()
try:
return super().train(*args, **kwargs)
finally:
self._restore_accelerator_patches()
def training_step(
self,
model: nn.Module,
inputs: dict[str, Any],
num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
"""Standard training step with HSDP gradient synchronization."""
model.train()
inputs = self._prepare_inputs(inputs)
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
if isinstance(model, HSDPModule):
model.set_is_last_backward(sync_gradients)
model.set_requires_gradient_sync(sync_gradients)
compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
with compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
if self.args.n_gpu > 1:
loss = loss.mean()
if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
loss = loss / self.args.gradient_accumulation_steps
self.accelerator.backward(loss)
if isinstance(model, HSDPModule) and sync_gradients:
hsdp_sync_stream()
return loss.detach()
def create_optimizer(self):
"""Create optimizer and wrap step with SkipDTensorDispatch."""
optimizer = super().create_optimizer()
wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
return optimizer
def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
"""Save model/optimizer shards per-rank and scheduler."""
save_hsdp_checkpoint(
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
output_dir=output_dir,
should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
)
def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
"""Load model from HSDP sharded checkpoint."""
target = model if model is not None else self.model
loaded = load_hsdp_model(target, resume_from_checkpoint)
if not loaded:
return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
self._pending_hsdp_checkpoint = resume_from_checkpoint
return None
def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
"""Load optimizer/scheduler from per-rank checkpoint files."""
ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
if ckpt_dir is None:
return
load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
"""Save model weights in HuggingFace-compatible format."""
save_dir = output_dir or self.args.output_dir
os.makedirs(save_dir, exist_ok=True)
export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)

View File

@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import HyperParallelTrainer
if TYPE_CHECKING:
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
r"""Load HyperParallel arguments and apply LlamaFactory-side overrides.
When activation optimization is enabled, skip native gradient checkpointing
so HP can install its own via ``setup_activation_optimization``.
"""
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
if hp_args.activation_mode != "none":
model_args.disable_gradient_checkpointing = True
return hp_args
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**tokenizer_module,
)
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if isinstance(dataset_module.get("eval_dataset"), dict):
for key in dataset_module["eval_dataset"].keys():
try:
perplexity = math.exp(metrics[f"eval_{key}_loss"])
except OverflowError:
perplexity = float("inf")
metrics[f"eval_{key}_perplexity"] = perplexity
else:
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["eval_perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -43,13 +130,7 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
HyperParallelArguments,
HyperParallelTrainer,
)
hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@@ -94,25 +175,6 @@ def run_sft(
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
callbacks = list(callbacks or [])
processor = tokenizer_module.get("processor")
if processor is not None:
callbacks.append(SaveProcessorCallback(processor))
compute_loss_func = None
if finetuning_args.use_dft_loss:
compute_loss_func = dft_loss_func
elif finetuning_args.use_eaft_loss:
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
elif finetuning_args.use_asft_loss:
from functools import partial
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
@@ -122,20 +184,11 @@ def run_sft(
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
compute_loss_func=compute_loss_func,
**dataset_module,
**tokenizer_module,
**metric_module,
)
if finetuning_args.use_badam:
from types import MethodType
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
trainer.add_callback(BAdamCallback)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

View File

@@ -88,9 +88,16 @@ def _training_function(config: dict[str, Any]) -> None:
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
raise ImportError(
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
)
if finetuning_args.stage == "pt":
from .hyper_parallel import run_pt as run_pt_hp
run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
else:
from .hyper_parallel import run_sft as run_sft_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)

View File

@@ -0,0 +1,149 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.model.model_utils.embedding import (
_description_based_initialization,
_existing_embeddings,
_noisy_mean_initialization,
_resolve_new_token_ids,
)
class _StubTokenizer:
"""Minimal tokenizer stub mapping token strings to fixed IDs."""
unk_token_id = 0
def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None):
self._mapping = mapping
self._desc_ids = desc_ids or []
def convert_tokens_to_ids(self, token: str) -> int:
return self._mapping.get(token, self.unk_token_id)
def __call__(self, desc, return_tensors=None, add_special_tokens=False):
return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)}
class _StubModel:
"""Wraps an embedding matrix so ``get_input_embeddings()`` is a usable lookup."""
def __init__(self, embed_weight: "torch.Tensor"):
self._emb = torch.nn.Embedding.from_pretrained(embed_weight.clone(), freeze=True)
def get_input_embeddings(self):
return self._emb
def test_resolve_new_token_ids_returns_none_without_config():
tokenizer = _StubTokenizer({})
assert _resolve_new_token_ids(None, tokenizer, embed_size=100) is None
assert _resolve_new_token_ids([], tokenizer, embed_size=100) is None
def test_resolve_new_token_ids_filters_invalid_and_dedups():
# "<a>" valid, "<unk_like>" maps to unk_token_id (skipped), "<oob>" out of range (skipped)
tokenizer = _StubTokenizer({"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5})
# duplicates and unsorted input -> sorted unique in-range IDs
tokens = ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"]
assert _resolve_new_token_ids(tokens, tokenizer, embed_size=100) == [5, 10]
# passing a dict iterates its keys (config compatibility)
assert _resolve_new_token_ids({"<a>": "desc"}, tokenizer, embed_size=100) == [10]
def test_existing_embeddings_excludes_new_token_ids():
embed_weight = torch.arange(10 * 2, dtype=torch.float32).reshape(10, 2)
# explicit ids take precedence and drop exactly those rows
existing = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5])
assert existing.size(0) == 8
# tail fallback when no explicit ids
tail = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None)
assert torch.allclose(tail, embed_weight[:-3])
# no resize and no ids -> use everything
everything = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None)
assert torch.allclose(everything, embed_weight)
def test_noisy_mean_initialization_with_token_ids_targets_exact_rows():
"""New tokens placed by explicit IDs must hit those rows, even inside the padding zone."""
torch.manual_seed(0)
vocab_size, embedding_dim = 20, 8
embed_weight = torch.zeros(vocab_size, embedding_dim)
# existing rows carry a constant so the mean is well-defined and non-zero
embed_weight[:16] = 1.0
# num_new_tokens reflects the embedding resize delta (4 padded rows),
# but the real new tokens sit at IDs 16 and 17 (inside what the tail slice would miss/over-cover).
target_ids = [16, 17]
_noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids)
# targeted rows are initialized around the mean (~1.0) and not left at zero
for tid in target_ids:
assert not torch.allclose(embed_weight[tid], torch.zeros(embedding_dim))
assert abs(embed_weight[tid].mean().item() - 1.0) < 0.5
# untouched padding rows (18, 19) must remain zero
assert torch.allclose(embed_weight[18], torch.zeros(embedding_dim))
assert torch.allclose(embed_weight[19], torch.zeros(embedding_dim))
def test_noisy_mean_initialization_tail_fallback():
"""Without token_ids, falls back to the last num_new_tokens rows."""
torch.manual_seed(0)
vocab_size, embedding_dim = 12, 8
embed_weight = torch.zeros(vocab_size, embedding_dim)
embed_weight[:10] = 1.0
_noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None)
# last two rows initialized, earlier rows untouched
assert not torch.allclose(embed_weight[-1], torch.zeros(embedding_dim))
assert not torch.allclose(embed_weight[-2], torch.zeros(embedding_dim))
assert torch.allclose(embed_weight[0], torch.ones(embedding_dim))
def test_description_init_excludes_new_token_ids_from_average():
"""Description tokens that are themselves new (uninitialized) must be excluded.
Reproduces the padding-zone bug: id 17 is a new token and must not pollute the
semantic average for id 16; only the valid existing token (id 5) should be used.
"""
vocab_size, embedding_dim = 20, 4
embed_weight = torch.zeros(vocab_size, embedding_dim)
embed_weight[5] = 3.0 # the only valid description token
# description for "<x>" tokenizes to [5 (existing), 17 (new -> must be skipped)]
tokenizer = _StubTokenizer({"<x>": 16}, desc_ids=[5, 17])
model = _StubModel(embed_weight)
_description_based_initialization(
embed_weight,
num_new_tokens=4,
descriptions={"<x>": "ignored, ids come from the stub"},
tokenizer=tokenizer,
model=model,
new_token_ids=[16, 17],
add_noise=False,
)
# row 16 must equal embedding of id 5 only (3.0), not the (5,17) average (1.5)
assert torch.allclose(embed_weight[16], torch.full((embedding_dim,), 3.0))
if __name__ == "__main__":
import pytest
pytest.main([__file__])