19 Commits

Author SHA1 Message Date
jiaqiw09
8669a22e9c [fix] fix liger kernel patch for npu (#10583) 2026-06-16 18:21:52 +08:00
Hao Liang
897a44386c [docs] add DataFlow and DataFlex blog tutorials (#10582)
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-16 14:20:36 +08:00
jiaqiw09
7a1e9630f2 [fix] update ascend doc link (#10572) 2026-06-15 13:55:53 +08:00
souljoy
cabe59a343 [model] add MiniCPM5-1B-Chat (#10558) 2026-06-10 16:18:27 +08:00
Co-Cl2
9ca4026efe [model] handle unsloth model loading fallback during checkpoint resume (#7156) (#10551) 2026-06-09 01:01:01 +08:00
Ximing Xing
0b7aaf8f6a [fix] correctly place new token embeddings when embedding is padded (#10547) 2026-06-05 10:47:51 +08:00
codingma
8a4f6a3da5 [model] add gemma-4-12B-it (#10549) 2026-06-04 23:43:20 +08:00
A1waysBeenHere
409e8a477f [model] Patch GDN for NPU (#10504)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
2026-06-04 16:39:02 +08:00
Cui-yshoho
053d43c0ac [feat] support HyperParallel PT training and activation optimization (#10370) 2026-06-02 22:39:32 +08:00
Zhao73
a98a1ef101 [docs] fix README citation typo (#10540) 2026-06-01 21:04:53 +08:00
Yaowei Zheng
8ef7335b6a [misc] set dev version (#10533) 2026-05-31 00:16:07 +08:00
Yaowei Zheng
7af909522a [version] release v0.9.5 (#10532) 2026-05-30 23:57:09 +08:00
xvxuopop
e016d2480e [fix] Fix NPU FusedMoE and RMSNorm (#10512) 2026-05-30 21:42:54 +08:00
jiaqiw09
7d719182c9 [model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529) 2026-05-30 21:41:41 +08:00
jiaqiw09
01398eb18d [v1] fix padding free with sp (#10513) 2026-05-26 23:49:21 +08:00
cxy
8e68764b65 [v1] Implement dynamic padding-free strategy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
2026-05-25 20:40:21 +08:00
Copilot
16ff5a23cb [fix] use getattr for profiler attrs to support MCA TrainingArguments (#10506)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2026-05-21 17:26:29 +08:00
jiaqiw09
bdcb92d035 [v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469) 2026-05-21 17:14:19 +08:00
sunyi0505
7e20db5735 [v1] support liger_kernel (#10493) 2026-05-21 11:44:56 +08:00
62 changed files with 5069 additions and 292 deletions

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
-👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
+👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
\[ English | [中文](README_zh.md) \]
@@ -52,14 +50,12 @@ Start local training:
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Documentation (ASCEND NPU)**: https://llamafactory.readthedocs.io/en/latest/multibackend/npu/index.html
- **Official Blog**: https://blog.llamafactory.net/en/
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -78,7 +74,6 @@ Read technical notes:
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [LLaMA Factory Online](#llama-factory-online)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
@@ -117,15 +112,13 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- 💡 [DataFlow × LLaMA Factory: Producing High-Quality Data for LLM Training with a Data Preparation Pipeline](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)
- 💡 [DataFlex × LLaMA Factory: A Data-Centric Dynamic Training System Built on LLaMA-Factory](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -661,10 +654,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
llamafactory-cli webui
```
### LLaMA Factory Online
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:
@@ -838,7 +827,7 @@ If you have a project that should be incorporated, please contact via email or c
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
-1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -52,16 +50,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
开始云端训练:
- **Colab免费**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW免费试用**https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online在线微调**https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **九章智算云(算力优惠活动)**https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
阅读技术文档:
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
- **微调视频教程**https://www.bilibili.com/video/BV1djgRzxEts/
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
- **框架文档(昇腾 NPU**https://ascend.github.io/docs/sources/llamafactory/
- **框架文档(昇腾 NPU**https://llamafactory.readthedocs.io/zh-cn/latest/multibackend/npu/index.html
- **官方博客**https://blog.llamafactory.net/
- **官方课程**https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [数据准备](#数据准备)
- [快速开始](#快速开始)
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
- [构建 Docker](#构建-docker)
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
- [从魔搭社区下载](#从魔搭社区下载)
@@ -119,15 +113,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
- 💡 [DataFlow × LLaMA Factory: 利用数据准备流水线产出高质量数据训练 LLM](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)(中文)| [English](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW)
- 💡 [DataFlex × LLaMA Factory: 构建在 LLaMA-Factory 之上的以数据为中心的动态训练系统](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)(中文)| [English](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv)
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
<details><summary>全部博客</summary>
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
- [LLaMA Factory微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
@@ -662,10 +654,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
llamafactory-cli webui
```
### LLaMA Factory Online 在线微调
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
### 构建 Docker
CUDA 用户:
@@ -842,7 +830,7 @@ swanlab_run_name: test_run # 可选
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
-1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)

View File

@@ -36,6 +36,7 @@ COPY . /app
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
RUN pip uninstall -y torch torchvision torchaudio
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/triton_ascend.txt
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

View File

@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,51 @@
# Start FSDP2 LoRA fine-tuning on Ascend NPU
# Usage:
#   accelerate launch \
#     --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \
#     src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count
### model
model_name_or_path: Qwen/Qwen3.5-35B-A3B
trust_remote_code: true
use_v1_kernels: false
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
packing: false
### output
output_dir: saves/Qwen3.5-35B/lora/sft
logging_steps: 1
save_steps: 2000
max_steps: 2000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
disable_gradient_checkpointing: true

View File

@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
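
The example configs above differ mainly in `batching_strategy`. As a rough illustration of what separates ordinary padded batching from a padding-free layout, the sketch below builds both from the same toy sequences. It is a generic illustration under stated assumptions, not the repository's implementation: `padded_batch`, `padding_free_batch`, and the token IDs are hypothetical names and values chosen here, and `cu_seqlens` follows the cumulative-length convention commonly used by variable-length attention kernels.

```python
from itertools import accumulate


def padded_batch(seqs: list[list[int]], pad_id: int = 0) -> list[list[int]]:
    """Pad every sequence to the longest one (the usual rectangular batch)."""
    width = max(len(s) for s in seqs)
    return [s + [pad_id] * (width - len(s)) for s in seqs]


def padding_free_batch(seqs: list[list[int]]) -> tuple[list[int], list[int]]:
    """Concatenate sequences into one row and record their boundaries."""
    flat = [tok for s in seqs for tok in s]
    # cu_seqlens[i]..cu_seqlens[i + 1] delimits sequence i inside `flat`.
    cu_seqlens = [0, *accumulate(len(s) for s in seqs)]
    return flat, cu_seqlens


seqs = [[11, 12, 13], [21], [31, 32]]  # hypothetical token IDs
padded = padded_batch(seqs)            # 3 x 3 slots, 3 of them padding
flat, cu_seqlens = padding_free_batch(seqs)  # 6 real tokens, no padding
print(padded, flat, cu_seqlens)
```

The padded batch spends 9 slots on 6 real tokens, while the padding-free layout stores exactly 6 tokens plus the boundary list. That is presumably why the padding-free example configs also set `flash_attn`: an attention kernel has to respect the sequence boundaries.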

View File

@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: liger_kernel
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,2 @@
--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple
triton-ascend==3.2.1

View File

@@ -886,6 +886,9 @@ register_model_group(
         "Gemma-4-E4B-Thinking": {
             DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
         },
+        "Gemma-4-12B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-12B-it",
+        },
     },
     template="gemma4n",
     multimodal=True,
@@ -1912,6 +1915,17 @@ register_model_group(
 )
+register_model_group(
+    models={
+        "MiniCPM5-1B-Chat": {
+            DownloadSource.DEFAULT: "openbmb/MiniCPM5-1B",
+            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM5-1B",
+        },
+    },
+    template="empty",
+)
 register_model_group(
     models={
         "MiniCPM-o-2.6": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
-VERSION = "0.9.5.dev0"
+VERSION = "0.9.6.dev0"
def print_env() -> None:

View File

@@ -487,7 +487,7 @@ class FinetuningArguments(
         metadata={
             "help": (
                 "Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
-                "Only supported for the 'sft' stage with full fine-tuning."
+                "Only supported for the 'pt' and 'sft' stages with full fine-tuning."
             )
         },
     )

View File

@@ -194,10 +194,16 @@ def _setup_lora_tuning(
             logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
         if adapter_to_resume is not None:  # resume lora training
-            if model_args.use_unsloth:
-                model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
+            if isinstance(model, PeftModel):
+                pass  # already loaded via load_unsloth_peft_model in loader.py
             else:
-                model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
+                if model_args.use_unsloth:
+                    peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
+                    if peft_model is not None:
+                        model = peft_model
+                if not model_args.use_unsloth:  # unsloth was disabled or fell back
+                    model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
         logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
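
The resume logic in this hunk can be read as a three-way decision: keep a model that is already a PEFT model, try the unsloth loader, and otherwise fall back to the plain PEFT loader. The sketch below mirrors that control flow with stand-in objects. It assumes the unsloth loader clears `use_unsloth` when it falls back, which the inline comment suggests but this diff does not show; `Args`, `resume_adapter`, and the loader callables are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Args:
    use_unsloth: bool


def resume_adapter(
    model: str,
    args: Args,
    unsloth_loader: Callable[[Args], Optional[str]],
    peft_loader: Callable[[str], str],
    already_peft: bool = False,
) -> str:
    """Mirror the hunk's branching with strings standing in for models."""
    if already_peft:
        return model  # adapter was attached earlier, nothing to do
    if args.use_unsloth:
        loaded = unsloth_loader(args)
        if loaded is not None:
            model = loaded
    if not args.use_unsloth:  # never enabled, or disabled by the fallback
        model = peft_loader(model)
    return model


def failing_unsloth(args: Args) -> Optional[str]:
    args.use_unsloth = False  # assumed fallback behaviour: clear the flag
    return None


def working_unsloth(args: Args) -> Optional[str]:
    return "unsloth-peft"


def plain_peft(model: str) -> str:
    return model + "+peft"


fell_back = resume_adapter("base", Args(True), failing_unsloth, plain_peft)
worked = resume_adapter("base", Args(True), working_unsloth, plain_peft)
preloaded = resume_adapter("preloaded", Args(True), failing_unsloth, plain_peft, already_peft=True)
print(fell_back, worked, preloaded)
```

Checking the flag a second time, rather than using `else`, is what lets a failed unsloth load continue into the regular PEFT path within the same call.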

View File

@@ -34,7 +34,7 @@ from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
-from .model_utils.unsloth import load_unsloth_pretrained_model
+from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
@@ -142,14 +142,13 @@ def load_model(
     apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
     model = None
-    lazy_load = False
     if model_args.use_unsloth:
         if model_args.adapter_name_or_path is not None:
-            lazy_load = True
+            model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
         elif is_trainable:
             model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
-    if model is None and not lazy_load:
+    if model is None:
         init_kwargs["config"] = config
         init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
         init_kwargs["torch_dtype"] = "auto"
@@ -176,9 +175,8 @@ def load_model(
     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
-    if not lazy_load:
-        patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
-        register_autoclass(config, model, tokenizer)
+    patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
+    register_autoclass(config, model, tokenizer)
     model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
+from collections.abc import Iterable
from contextlib import nullcontext
from typing import TYPE_CHECKING, Optional
@@ -29,7 +30,81 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
-def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
+def get_embedding_vocab_size(model: "PreTrainedModel") -> int:
+    r"""Get the vocab size from the input embedding layer.
+    Handles DeepSpeed ZeRO-3 parameter sharding by gathering the embedding weight
+    before reading its size.
+    """
+    embedding = model.get_input_embeddings()
+    if is_deepspeed_zero3_enabled():
+        import deepspeed  # type: ignore
+        with deepspeed.zero.GatheredParameters([embedding.weight]):
+            return embedding.weight.size(0)
+    return embedding.weight.size(0)
+def _resolve_new_token_ids(
+    new_tokens: Optional[Iterable[str]],
+    tokenizer: "PreTrainedTokenizer",
+    embed_size: int,
+) -> Optional[list[int]]:
+    r"""Resolve the explicit embedding-row IDs of the newly added tokens.
+    Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when
+    the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL
+    has vocab 151665 but embedding 151936). In that case the appended tokens land
+    inside the original padding zone and the tail slice points at the wrong rows.
+    Args:
+        new_tokens: Iterable of the newly added token strings.
+        tokenizer: The tokenizer instance.
+        embed_size: Current embedding size (upper bound for valid token IDs).
+    Returns:
+        A sorted list of unique, in-range token IDs, or ``None`` when no tokens are
+        given so that callers can fall back to the tail-slice behaviour.
+    """
+    if not new_tokens:
+        return None
+    unk_token_id = getattr(tokenizer, "unk_token_id", None)
+    token_ids: set[int] = set()
+    for token_str in new_tokens:
+        token_id = tokenizer.convert_tokens_to_ids(token_str)
+        if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size):
+            logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.")
+            continue
+        token_ids.add(token_id)
+    return sorted(token_ids) or None
+def _existing_embeddings(
+    embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]]
+) -> "torch.Tensor":
+    """Return the rows treated as 'existing' embeddings used as the init baseline.
+    Prefers excluding the explicit new-token rows (robust to padding). Falls back to
+    dropping the last ``num_new_tokens`` rows when no explicit IDs are available.
+    """
+    if new_token_ids:
+        mask = torch.ones(embed_weight.size(0), dtype=torch.bool, device=embed_weight.device)
+        mask[torch.as_tensor(new_token_ids, device=embed_weight.device, dtype=torch.long)] = False
+        return embed_weight[mask]
+    if num_new_tokens > 0:
+        return embed_weight[:-num_new_tokens]
+    return embed_weight
+def _noisy_mean_initialization(
+    embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None
+) -> None:
     """Initialize new token embeddings with mean + Gaussian noise.
     This is the default initialization method used by LlamaFactory.
@@ -37,12 +112,23 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
     Args:
         embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
         num_new_tokens: Number of new tokens added at the end of the embedding matrix
+        token_ids: Explicit token IDs to initialize. When provided, these exact rows are
+            written (robust to padding). When ``None``, falls back to the last
+            ``num_new_tokens`` rows.
     """
     embedding_dim = embed_weight.size(1)
-    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
-    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
-    noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
-    embed_weight[-num_new_tokens:] = avg_weight + noise_weight
+    avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True)
+    if token_ids:
+        noise_weight = torch.empty(
+            len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype
+        )
+        noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
+        embed_weight[token_ids] = avg_weight + noise_weight
+    else:
+        noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
+        noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
+        embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _description_based_initialization(
@@ -51,6 +137,7 @@ def _description_based_initialization(
     descriptions: dict[str, str],
     tokenizer: "PreTrainedTokenizer",
     model: "PreTrainedModel",
+    new_token_ids: Optional[list[int]] = None,
     add_noise: bool = False,
 ) -> None:
     """Initialize new token embeddings based on textual descriptions.
"""Initialize new token embeddings based on textual descriptions.
@@ -61,6 +148,9 @@ def _description_based_initialization(
     3. Averages them to initialize the new token's embedding
     4. Optionally adds Gaussian noise
+    New tokens are placed by their resolved token ID rather than by tail slicing,
+    so the initialization is correct even when the embedding matrix was padded.
     Args:
         embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
         num_new_tokens: Number of new tokens added
@@ -68,6 +158,8 @@ def _description_based_initialization(
             e.g., {"<think>": "A token representing reasoning process"}
         tokenizer: The tokenizer instance
         model: The model instance (used to get input embeddings)
+        new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized
+            rows when averaging description-token embeddings (robust to embedding padding).
         add_noise: Whether to add Gaussian noise to the initialization
     Example:
Example:
@@ -77,38 +169,55 @@ def _description_based_initialization(
     }
     """
     embedding_dim = embed_weight.size(1)
+    vocab_size = embed_weight.size(0)
+    unk_token_id = getattr(tokenizer, "unk_token_id", None)
+    device = embed_weight.device
+    # The set of rows that are NOT yet initialized (the newly added tokens). Description
+    # tokens that fall into this set must be excluded, otherwise we would average garbage.
+    # `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on
+    # the explicit IDs, falling back to resolving them from the description keys.
+    if new_token_ids is None:
+        new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size)
+    new_id_set = set(new_token_ids or [])
+    fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0)
+    for token_str, desc in descriptions.items():
+        # Resolve token ID for correct placement (robust to embedding padding)
+        token_id = tokenizer.convert_tokens_to_ids(token_str)
+        if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size):
+            logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.")
+            continue
-    for i, desc in enumerate(descriptions.values()):
         # Tokenize description text
         tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
         with torch.no_grad():
-            token_ids = tokens["input_ids"][0]
-            # Move to the same device as embed_weight
-            device = embed_weight.device
-            token_ids = token_ids.to(device)
+            token_ids = tokens["input_ids"][0].tolist()
-            # Filter out new tokens (they don't have valid embeddings yet)
-            valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
+            # Keep only description tokens that already have a meaningful embedding.
+            valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size]
             if len(valid_token_ids) == 0:
                 # Fallback: use mean of all existing embeddings
                 logger.warning_rank0(
-                    f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
+                    f"Description for token '{token_str}' contains no valid tokens. "
                     "Using mean of existing embeddings."
                 )
-                base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
+                base_embedding = fallback_embedding
             else:
                 # Get embeddings of description tokens and average them
-                token_embeds = model.get_input_embeddings()(valid_token_ids)
+                valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long)
+                token_embeds = model.get_input_embeddings()(valid_ids_tensor)
                 base_embedding = token_embeds.mean(dim=0)
             # Add noise if requested (ensure correct device and dtype)
             if add_noise:
                 noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
-                embed_weight[-num_new_tokens + i] = base_embedding + noise
+                embed_weight[token_id] = base_embedding + noise
             else:
-                embed_weight[-num_new_tokens + i] = base_embedding
+                embed_weight[token_id] = base_embedding
def _initialize_embeddings(
@@ -118,6 +227,7 @@ def _initialize_embeddings(
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
) -> None:
"""Single source of truth for embedding initialization.
@@ -130,16 +240,18 @@ def _initialize_embeddings(
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding).
When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows.
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True
)
else:
if init_method != "noise_init":
@@ -147,20 +259,28 @@ def _initialize_embeddings(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
_noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_tokens: Optional[Iterable[str]] = None,
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
r"""Resize token embeddings (when needed) and initialize the newly added tokens.
Resizing and initialization are decoupled: even when the tokenizer vocab fits inside
the model's existing (padded) embedding matrix and no resize is triggered, the newly
added tokens still occupy uninitialized rows and must be initialized. We therefore
resolve the explicit row IDs of ``new_tokens`` and always initialize those rows.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_tokens: Iterable of the newly added token strings. Used to locate the exact
embedding rows to initialize, which is robust to pre-existing embedding padding.
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
@@ -175,44 +295,70 @@ def resize_embedding_layer(
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
current_embedding_size = get_embedding_vocab_size(model)
needs_resize = len(tokenizer) > current_embedding_size
if len(tokenizer) > current_embedding_size:
if needs_resize:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# mean_resizing=False preserves the original embedding distribution exactly.
# HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance
# of existing embeddings, which conflicts with our explicit initialization below.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# Resolve the exact rows of the new tokens. This works whether or not a resize was
# triggered (e.g. tokens added into a model's pre-existing padding zone).
new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size)
if num_new_tokens <= 0 and not new_token_ids:
return
if needs_resize:
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
else:
logger.info_rank0(
f"No resize needed (vocab fits in padded embedding {new_embedding_size}); "
f"initializing {len(new_token_ids or [])} new token(s) in place."
)
# Initialize input embeddings
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_input_embeddings().weight.data,
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
if needs_resize:
model.config.vocab_size = new_embedding_size
# Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA),
# otherwise config.vocab_size and config.text_config.vocab_size become inconsistent.
if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"):
model.config.text_config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
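
To make the padding problem concrete, the sketch below compares the rows the old tail-slicing code would write to with the rows the tokenizer actually assigns. The vocabulary sizes and the helper names `pad_to_multiple` and `tail_rows` are assumptions chosen for illustration. Only the `pad_to_multiple_of=64` rounding and the `-num_new_tokens + i` indexing are taken from the diff above.

```python
def pad_to_multiple(n: int, multiple: int = 64) -> int:
    """Round n up to the next multiple, as pad_to_multiple_of=64 does."""
    return -(-n // multiple) * multiple


def tail_rows(embedding_rows: int, resize_delta: int, count: int) -> list:
    """Rows addressed by embed_weight[-resize_delta + i] for i in range(count)."""
    return [embedding_rows - resize_delta + i for i in range(count)]


# Hypothetical model: 126 real tokens stored in a 128-row (already padded) matrix.
old_vocab, old_rows = 126, 128
added = 4

# The tokenizer appends new tokens directly after the old vocabulary.
true_ids = list(range(old_vocab, old_vocab + added))

# 130 tokens no longer fit in 128 rows, so the matrix is resized and padded again.
new_rows = pad_to_multiple(old_vocab + added)
resize_delta = new_rows - old_rows  # this is what the code calls num_new_tokens

# Tail slicing starts at the old matrix boundary, not at the old vocabulary boundary.
old_targets = tail_rows(new_rows, resize_delta, added)

print("tokenizer IDs:      ", true_ids)
print("tail-slicing rows:  ", old_targets)
```

In this setup the two lists differ by the two rows of pre-existing padding, which is why the new code indexes `embed_weight[token_id]` with IDs resolved from the tokenizer instead.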


@@ -16,6 +16,7 @@ import inspect
from typing import TYPE_CHECKING
from ...extras import logging
from ...extras.misc import get_device_name
if TYPE_CHECKING:
@@ -81,6 +82,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
elif model_type == "qwen3_5":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
elif model_type == "qwen3_5_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
@@ -97,5 +100,12 @@ def apply_liger_kernel(
else:
kwargs = {}
if get_device_name() == "npu":
import torch
if "Ascend910" not in torch.npu.get_device_name(0):
kwargs["swiglu"] = False
kwargs["fused_linear_cross_entropy"] = False
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")
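
The device gate added above can be read as a pure function of the device type and device name. The following is a minimal sketch under that reading. `liger_kwargs_for` is a hypothetical helper rather than part of the patch, and the device-name strings in the usage lines are illustrative values, not an authoritative list of what `torch.npu.get_device_name` returns.

```python
from typing import Optional


def liger_kwargs_for(device_type: str, device_name: str, base_kwargs: Optional[dict] = None) -> dict:
    """Return Liger kwargs, disabling two fused kernels on NPUs that are not Ascend910."""
    kwargs = dict(base_kwargs or {})  # copy so the caller's dict is not mutated
    if device_type == "npu" and "Ascend910" not in device_name:
        kwargs["swiglu"] = False
        kwargs["fused_linear_cross_entropy"] = False
    return kwargs


# Illustrative device names (assumptions).
print(liger_kwargs_for("npu", "Ascend910B2"))
print(liger_kwargs_for("npu", "Ascend310P3"))
print(liger_kwargs_for("cuda", "NVIDIA A100"))
```

Keeping the decision in a function with no hardware dependency makes the gate easy to unit-test without an NPU present.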


@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""Load peft model with unsloth. Used in both training and inference."""
) -> Optional["PreTrainedModel"]:
r"""Load peft model with unsloth. Used in both training and inference.
Returns None if unsloth does not support the model type, and sets
model_args.use_unsloth = False so callers can fall back to standard loading.
"""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
return None
if not is_trainable:
FastLanguageModel.for_inference(model)
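
Because the loader now returns `None` instead of raising, callers have to check the result and fall back to the standard path. The sketch below illustrates that contract with stand-in objects. `Args`, `load_with_unsloth`, and `load_model` are hypothetical names, and the string return values stand in for real model objects.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    use_unsloth: bool = True


def load_with_unsloth(args: Args, supported: bool) -> Optional[str]:
    """Stand-in loader: returns None and clears the flag when the model type is unsupported."""
    if not supported:
        args.use_unsloth = False  # signal to later code that unsloth is off
        return None
    return "unsloth-model"


def load_model(args: Args, supported: bool) -> str:
    """Caller side: try unsloth first, then fall back to standard loading."""
    model = None
    if args.use_unsloth:
        model = load_with_unsloth(args, supported)
    if model is None:
        model = "standard-model"
    return model


args = Args()
print(load_model(args, supported=False), args.use_unsloth)
```

Clearing `use_unsloth` on the shared arguments object matters because later steps branch on the same flag.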


@@ -20,6 +20,7 @@ from peft import PeftModel
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
from ..extras import logging
from ..extras.misc import infer_optim_dtype
@@ -84,7 +85,60 @@ def _check_fla_dependencies() -> None:
) from exc
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None:
"""Patch for Qwen3.5 models on NPU by importing torch_npu to enable torch.cuda compatibility.
On NPU, torch.cuda operations will fail unless torch_npu is imported.
torch_npu provides a compatibility layer that maps torch.cuda calls to NPU operations.
Also replaces chunk_gated_delta_rule with an NPU-compatible implementation.
"""
import importlib.metadata
if "Ascend910" not in torch.npu.get_device_name(0):
logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.")
return
try:
importlib.metadata.version("triton_ascend")
except importlib.metadata.PackageNotFoundError:
logger.warning_rank0(
"triton_ascend not installed, skipping NPU GDN patch. "
"To enable it on NPU, reinstall Triton with the Ascend build: "
"`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. "
"Note: triton and triton_ascend cannot coexist — triton must be uninstalled first."
)
return
logger.info_rank0("triton_ascend detected for NPU compatibility.")
from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule
if model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
try:
# Qwen3.5-MoE structure: model.model.language_model.layers
for layer in model.model.language_model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0(
"Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5-MoE model."
)
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
elif model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
try:
# Qwen3.5 structure: model.model.layers
for layer in model.model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0("Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5 model.")
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None:
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input. Only applied during training.
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
@@ -162,8 +216,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.
@@ -397,9 +457,14 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
# Pass the explicit list of newly added tokens so their exact embedding rows can be
# located and initialized, even when they land in a model's pre-existing padding zone.
new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or [])
resize_embedding_layer(
model,
tokenizer,
new_tokens=new_tokens or None,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)
@@ -415,8 +480,12 @@ def patch_model(
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
patch_qwen3_5_forward(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]:
if is_torch_npu_available():
patch_qwen3_5_forward_npu(model)
elif is_torch_cuda_available() and model_args.flash_attn == "fa2":
# Patch for packing/neat_packing with GPU GDN. Packing requires flash_attn to be fa2.
patch_qwen3_5_forward_gpu(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)
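
The `batch_size == 1` rule for `cu_seqlens` is easier to follow with a small example. The sketch below derives cumulative sequence boundaries from packed position IDs in plain Python. Both functions are hypothetical stand-ins written for illustration and do not reproduce the actual `prepare_fa_kwargs_from_position_ids` helper. The sketch assumes each packed segment restarts its positions at 0.

```python
from typing import Optional


def cu_seqlens_from_position_ids(position_ids: list) -> list:
    """Boundaries of packed segments: every position 0 starts a new segment."""
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    return starts + [len(position_ids)]


def select_cu_seqlens(position_ids: Optional[list], batch_size: int) -> Optional[list]:
    """Use the varlen path only for a single (possibly packed) row, else return None."""
    if position_ids is None or batch_size != 1:
        return None
    return cu_seqlens_from_position_ids(position_ids[0])


# Three sequences of lengths 3, 2 and 4 packed into one row.
packed = [[0, 1, 2, 0, 1, 0, 1, 2, 3]]
print(select_cu_seqlens(packed, batch_size=1))

# Two unpacked rows: the standard batched kernels are used instead.
batched = [[0, 1, 2, 3], [0, 1, 2, 3]]
print(select_cu_seqlens(batched, batch_size=2))
```

With a batch of more than one row, boundaries taken from the first row alone would not describe the other rows, which matches the diff's choice to fall back to `cu_seqlens=None` in that case.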


@@ -0,0 +1,594 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets
CUBE_CORE_NUM = get_npu_properties()["num_aicore"]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(False,)),
key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
w,
v_new,
g,
gk,
h,
h0,
ht,
cu_seqlens,
chunk_offsets,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
NT_all = NT
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
# Initialize hidden states
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
if IS_VARLEN:
v = v + (i_h * T_all + bos) * V
k = k + (i_h * T_all + bos) * K
w = w + (i_h * T_all + bos) * K
g = g + i_h * T_all + bos
h = h + (i_h * NT_all + boh) * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_h * T_all + bos) * V
else:
v = v + (i_n * H + i_h) * T * V
k = k + (i_n * H + i_h) * T * K
w = w + (i_n * H + i_h) * T * K
g = g + (i_n * H + i_h) * T
h = h + (i_n * H + i_h) * NT * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_n * H + i_h) * T * V
if USE_INITIAL_STATE:
h0_ptr = h0 + i_nh * K * V
if STORE_FINAL_STATE:
ht_ptr = ht + i_nh * K * V
# Load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# Main recurrence over chunks
for i_t in range(NT):
# Store current hidden state h_t
p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
# Compute v_residual = v - w @ h
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
# Apply output gate g
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_g_last = tl.load(g + last_idx)
p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None]
b_g_last_exp = tl.exp(b_g_last)
b_h1 *= b_g_last_exp
if K > 64:
b_h2 *= b_g_last_exp
if K > 128:
b_h3 *= b_g_last_exp
if K > 192:
b_h4 *= b_g_last_exp
# Apply key gate gk
if USE_GK:
o_k1 = tl.arange(0, 64).to(tl.float32)
gk_base_ptr = gk + (i_n * H + i_h) * T * K
b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0)
b_h1 *= tl.exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0)
b_h2 *= tl.exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0)
b_h3 *= tl.exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0)
b_h4 *= tl.exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
# Update hidden state: h += k @ v
p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h4 += tl.dot(b_k, b_v)
# Store final state
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # default:64
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
B, T, H, K, V = *k.shape, u.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V).permute(0, 2, 1, 3, 4).contiguous()
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
BV = 128
v_new = torch.empty_like(u).permute(0, 2, 1, 3).contiguous() if save_new_value else None
k = k.permute(0, 2, 1, 3).contiguous()
w = w.permute(0, 2, 1, 3).contiguous()
u = u.permute(0, 2, 1, 3).contiguous()
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
NT=NT,
)
h = h.permute(0, 2, 1, 3, 4).contiguous()
v_new = v_new.permute(0, 2, 1, 3).contiguous()
return h, v_new, final_state
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
"USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(True, False)),
key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
q,
k,
w,
g,
gk,
dht,
dh0,
do,
dh,
dv,
dv2,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_FINAL_STATE_GRADIENT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
w += (bos * H + i_h) * K
do += (bos * H + i_h) * V
dv += (bos * H + i_h) * V
dv2 += (bos * H + i_h) * V
dh += (boh * H + i_h) * K * V
if USE_GK:
gk += (bos * H + i_h) * K
if USE_INITIAL_STATE:
dh0 += i_nh * K * V
if USE_FINAL_STATE_GRADIENT:
dht += i_nh * K * V
stride_v = H * V
stride_h = H * K * V
stride_k = H * K
if USE_FINAL_STATE_GRADIENT:
p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
if K > 64:
p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
if K > 128:
p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
if K > 192:
p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))
for i_t in range(NT - 1, -1, -1):
p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
if IS_VARLEN:
bos_g = i_h * T_all + bos
else:
bos_g = (i_n * H + i_h) * T_all
bg_last = tl.load(g + bos_g + last_idx)
bg_last_exp = tl.exp(bg_last)
p_g = tl.make_block_ptr(
base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
# Update dv
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0)
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
if K > 64:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
if K > 128:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
if K > 192:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None]
b_dv += tl.load(p_dv, boundary_check=(0, 1))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# Update dh
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
if USE_G:
b_dh1 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh1 *= tl.exp(b_gk_last1[:, None])
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 64:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh2 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh2 *= tl.exp(b_gk_last2[:, None])
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 128:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh3 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh3 *= tl.exp(b_gk_last3[:, None])
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 192:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh4 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh4 *= tl.exp(b_gk_last4[:, None])
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_bwd_dhu(
q: torch.Tensor,
k: torch.Tensor,
w: torch.Tensor,
do: torch.Tensor,
dv: torch.Tensor,
g: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
h0: torch.Tensor | None = None,
dht: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
B, T, H, K, V = *q.shape, do.shape[-1]
# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
dh = q.new_empty(B, NT, H, K, V)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
dv2 = torch.empty_like(dv)
BV = 128
g = g.permute(0, 2, 1).contiguous() if g is not None else None
chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)](
q=q,
k=k,
w=w,
g=g,
gk=gk,
dht=dht,
dh0=dh0,
do=do,
dh=dh,
dv=dv,
dv2=dv2,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
)
return dh, dh0, dv2
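
The wrapper above depends on `prepare_chunk_indices` and `prepare_chunk_offsets`, whose bodies are not part of this hunk. Judging only from how the kernels consume their results (pairs read at `chunk_indices + i_t * 2`, and one starting chunk offset per sequence read from `chunk_offsets + i_n`), they presumably behave like the pure-Python sketch below. The `_sketch` names are hypothetical, and the real helpers return tensors rather than lists.

```python
def chunk_indices_sketch(cu_seqlens, chunk_size):
    """One (sequence index, chunk index inside that sequence) pair per chunk."""
    pairs = []
    for i_n in range(len(cu_seqlens) - 1):
        seq_len = cu_seqlens[i_n + 1] - cu_seqlens[i_n]
        num_chunks = (seq_len + chunk_size - 1) // chunk_size  # ceiling division
        for i_t in range(num_chunks):
            pairs.append((i_n, i_t))
    return pairs


def chunk_offsets_sketch(cu_seqlens, chunk_size):
    """Global index of the first chunk of each sequence, followed by the total."""
    offsets = [0]
    for i_n in range(len(cu_seqlens) - 1):
        seq_len = cu_seqlens[i_n + 1] - cu_seqlens[i_n]
        offsets.append(offsets[-1] + (seq_len + chunk_size - 1) // chunk_size)
    return offsets


# Three packed sequences of lengths 100, 64 and 36 with 64-token chunks.
cu_seqlens = [0, 100, 164, 200]
pairs = chunk_indices_sketch(cu_seqlens, 64)
offsets = chunk_offsets_sketch(cu_seqlens, 64)
```

Under this reading, `len(pairs)` plays the role of `NT` in the variable-length branch, and `offsets[i_n]` is the value a kernel would add to a local chunk index to address the per-chunk state buffer.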

@@ -0,0 +1,347 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional
import torch
from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
return g, o, A, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=False,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dv = chunk_bwd_dv_local(
q=q,
k=k,
g=g,
do=do,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
q=q,
k=k,
w=w,
g=g,
h0=initial_state,
dht=dht,
do=do,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dq, dk, dw, dg = chunk_bwd_dqkwg(
q=q,
k=k,
v=v_new,
w=w,
g=g,
h=h,
dv=dv,
do=do,
dh=dh,
chunk_size=chunk_size,
scale=scale,
cu_seqlens=cu_seqlens,
)
dk2, dv, db, dg2 = prepare_wy_repr_bwd(
k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size
)
dk.add_(dk2)
dg.add_(dg2)
if dg.dtype != torch.float32:
raise ValueError(f"dg has dtype {dg.dtype}, but float32 is expected.")
dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False)
return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
chunk_size: int = 64,
):
q_rstd, k_rstd = None, None
g, o, A, final_state = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
ctx.chunk_size = chunk_size
return o.to(q.dtype), final_state
@staticmethod
@input_guard
@autocast_custom_bwd
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
A=A,
scale=ctx.scale,
initial_state=initial_state,
do=do,
dht=dht,
cu_seqlens=cu_seqlens,
chunk_size=ctx.chunk_size,
)
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: Optional[float] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
head_first: bool = False,
):
r"""Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
betas of shape `[B, T, H]`.
scale (Optional[float]):
Scale factor for the attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
use_qk_l2norm_in_kernel (bool):
Whether to apply L2norm to the q/k tensor internally. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
This argument has been deprecated.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 128, 128
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
""" # noqa: D205
if q.dtype != k.dtype or k.dtype != v.dtype:
raise ValueError(
f"q, k and v must share the same dtype, but got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}."
)
if q.dtype == torch.float32:
raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.")
if len(beta.shape) != 3:
raise ValueError(
f"beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if head_first:
warnings.warn(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
original_dtype = x.dtype
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
# Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype
return (x * inv_norm).to(original_dtype)
if use_qk_l2norm_in_kernel:
q = l2norm(q, dim=-1, eps=1e-6)
k = l2norm(k, dim=-1, eps=1e-6)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size
)
return o, final_state
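
For checking the chunked kernels on tiny inputs, a token-by-token reference is handy. The sketch below is not taken from this diff. It assumes the commonly stated gated delta rule recurrence for a single head with a zero initial state: decay the state by `exp(g_t)`, subtract what the state already predicts for `k_t` from `v_t`, scale the difference by `beta_t`, and write it back as a rank-1 update. The function name and the plain-list layout are hypothetical.

```python
import math


def gated_delta_rule_recurrent(q, k, v, g, beta, scale=None):
    """Naive single-head reference.

    q, k: T vectors of length K; v: T vectors of length V;
    g: T log-space gates; beta: T scalars. Returns T outputs of length V.
    """
    K, V = len(k[0]), len(v[0])
    if scale is None:
        scale = K ** -0.5
    state = [[0.0] * V for _ in range(K)]  # K x V memory, starts at zero
    outputs = []
    for q_t, k_t, v_t, g_t, beta_t in zip(q, k, v, g, beta):
        decay = math.exp(g_t)  # g is in log space, so exp(g) lies in (0, 1]
        state = [[decay * s for s in row] for row in state]
        # Value the decayed state currently associates with k_t.
        pred = [sum(k_t[i] * state[i][j] for i in range(K)) for j in range(V)]
        # Delta rule: only the prediction error is written, weighted by beta.
        delta = [beta_t * (v_t[j] - pred[j]) for j in range(V)]
        state = [[state[i][j] + k_t[i] * delta[j] for j in range(V)] for i in range(K)]
        outputs.append([scale * sum(q_t[i] * state[i][j] for i in range(K)) for j in range(V)])
    return outputs


# With a unit-norm key, beta = 1 and no decay, writing a second value under
# the same key overwrites the first one.
out = gated_delta_rule_recurrent(
    q=[[1.0, 0.0], [1.0, 0.0]],
    k=[[1.0, 0.0], [1.0, 0.0]],
    v=[[2.0, 3.0], [5.0, 7.0]],
    g=[0.0, 0.0],
    beta=[1.0, 1.0],
    scale=1.0,
)
```

The loop is quadratic in the head dimensions and sequential in `T`, so it is only suitable for small test shapes. Comparing it against the chunked path would also need a tolerance suited to bfloat16.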

@@ -0,0 +1,617 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"USE_DW": lambda args: args["dw"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dqkwg(
q,
k,
v,
h,
g,
g_gamma,
do,
dh,
dq,
dk,
dg,
w,
dv,
dw,
cu_seqlens,
chunk_indices,
scale,
B: tl.constexpr,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
USE_DW: tl.constexpr,
IS_VARLEN: tl.constexpr,
gdiff,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
total = B * T_max
T = eos - bos
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
total = B * T_max
NK = tl.cdiv(K, BK)
for i_k in range(NK):
if USE_G:
dg_k = dg + i_k * total * H
for i_h in range(H):
v_h = v + (bos * H + i_h) * V
do_h = do + (bos * H + i_h) * V
h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V
dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V
q_h = q + (bos * H + i_h) * K
k_h = k + (bos * H + i_h) * K
dq_h = dq + (bos * H + i_h) * K
dk_h = dk + (bos * H + i_h) * K
if USE_DW:
w_h = w + (bos * H + i_h) * K # noqa: F841
dw_h = dw + (bos * H + i_h) * K
dv_h = dv + (bos * H + i_h) * V
if USE_G:
if IS_VARLEN:
dg_h = dg_k + i_h * T_max + bos
g_h = g + i_h * T_max + bos
else:
dg_h = dg_k + (i_b * H + i_h) * T_max
g_h = g + (i_b * H + i_h) * T_max
b_dg_last = tl.zeros(
[
1,
],
dtype=tl.float32,
)
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
b_g_last = b_gamma * min(BT, T - i_t * BT)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
if USE_G:
b_dg_last += tl.sum(b_h * b_dh)
b_ds += tl.dot(b_do, tl.trans(b_v))
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
if USE_DW:
p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
if USE_DW:
p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_dg = tl.zeros(
[
BT,
],
dtype=tl.float32,
)
p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1)
b_dg_last *= tl.exp(b_g_last)
b_dq = b_dq * tl.exp(b_g)[:, None] * scale
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None]
b_dg -= tl.sum(b_k * b_dk, axis=1)
b_dg_last += tl.sum(b_dk * b_k)
if IS_VARLEN:
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
else:
p_gdiff = tl.make_block_ptr(
gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT,
(BT, BT),
(BT, 1),
(0, 0),
(BT, BT),
(1, 0),
)
gdiff_ = tl.load(p_gdiff)
b_ds = b_ds * gdiff_ * scale
b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
b_dg += tl.sum(b_ds2, axis=1)
b_dg -= tl.sum(b_ds2, axis=0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
last_index_local = min(BT, T - i_t * BT) - 1
if last_index_local >= 0:
is_last_mask = tl.arange(0, BT) == last_index_local
b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg)
else:
pass
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
elif USE_G_GAMMA:
b_dq = b_dq * exp(b_g)[:, None] * scale
b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
else:
b_ds = tl.where(m_A, b_ds, 0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dv_local(
q,
k,
g,
g_gamma,
do,
dv,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
for i_h in range(H):
offset_kh = (bos * H + i_h) * K
offset_vh = (bos * H + i_h) * V
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, b_q)
if USE_G:
if IS_VARLEN:
offset_g = i_h * T_max + bos
else:
offset_g = i_b * H * T_max + i_h * T_max
p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
else:
b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
N: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_max = T
for i_v in range(tl.cdiv(V, BV)):
for i_n in range(N):
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
base_chunks_per_pid = NT // total_cores
remainder = NT % total_cores
if core_id < remainder:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * base_chunks_per_pid + remainder
# offset calculation
for i_h in range(0, H):
q_offset = (bos * Hg + i_h // (H // Hg)) * K
k_offset = (bos * Hg + i_h // (H // Hg)) * K
v_offset = (bos * H + i_h) * V
o_offset = (bos * H + i_h) * V
for i_t in range(start_idx, start_idx + chunks_this_pid):
i_tg = boh + i_t
h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, b_h)
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k)
if USE_G:
if IS_VARLEN:
p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
else:
p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_bwd_dqkwg(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
do: torch.Tensor,
h: torch.Tensor,
dh: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
w: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128 if cu_seqlens is None else 64
BV = 64
NK = triton.cdiv(K, BK)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
g = g.transpose(1, 2).contiguous()
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
dw = torch.empty_like(w) if w is not None else None
grid = (NT, B)
if cu_seqlens is None:
if NT * BT == T:
g_ = g.reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
else:
diff = NT * BT - T
g_ = torch.cat((g, torch.zeros(B, H, diff).to(g.device)), dim=-1).reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
bias = torch.arange(0, BT).to(g.device)
o_t = (NT - 1) * BT + bias
m_t = o_t < T
m_A = m_t[:, None] & m_t
g_diff[:, :, -1] *= m_A
else:
g_diff = None
chunk_bwd_kernel_dqkwg[grid](
q=q,
k=k,
v=v,
h=h,
g=g,
g_gamma=g_gamma,
do=do,
dh=dh,
dv=dv,
w=w,
dw=dw,
dq=dq,
dk=dk,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
B=B,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
gdiff=g_diff,
)
if dg is not None:
dg = dg.sum(0)
dg = dg.transpose(1, 2).contiguous()
return dq, dk, dw, dg
def chunk_bwd_dv_local(
q: torch.Tensor,
k: torch.Tensor,
do: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, H, K, V = *k.shape, do.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
BK = 128
BV = 128
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous() if g is not None else None
dv = torch.empty_like(do)
grid = (NT, B)
chunk_bwd_kernel_dv_local[grid](
q=q,
k=k,
g=g,
g_gamma=g_gamma,
do=do,
dv=dv,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)
return dv
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # noqa: F841
if scale is None:
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
if cu_seqlens is None:
N, chunk_offsets = B, None
else:
N, chunk_offsets = (
len(cu_seqlens) - 1,
prepare_chunk_offsets(cu_seqlens, BT),
)
def grid(meta):
return (triton.cdiv(V, meta["BV"]), N * H)
g = g.transpose(1, 2).contiguous() if g is not None else None
h = h.contiguous()
CV_kernel_num = 24
chunk_fwd_kernel_o[(CV_kernel_num,)](
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T=T,
H=H,
N=N,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=128,
BV=128,
)
return o
bwd_chunk_dqkwg = chunk_bwd_dqkwg
bwd_chunk_dv_local = chunk_bwd_dv_local
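
`chunk_fwd_kernel_o` is launched on a fixed one-dimensional grid (`CV_kernel_num = 24`) and splits the `NT` chunks of each sequence across programs with a base-plus-remainder rule. The plain-Python sketch below mirrors that arithmetic so the partition can be checked in isolation. The `split_work` name is hypothetical, and the sketch models only the index computation, not the kernel.

```python
def split_work(total, num_programs, program_id):
    """Return (start, count) of the contiguous slice owned by one program."""
    base, remainder = divmod(total, num_programs)
    if program_id < remainder:
        # The first `remainder` programs each take one extra item.
        count = base + 1
        start = program_id * count
    else:
        count = base
        start = program_id * base + remainder
    return start, count


# 10 chunks over 4 programs: slices of size 3, 3, 2, 2.
slices = [split_work(10, 4, pid) for pid in range(4)]
covered = [i for start, count in slices for i in range(start, start + count)]
```

When there are fewer items than programs, the trailing programs get a count of zero and do nothing, which is the case for short sequences on a 24-program grid.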

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
USE_G: tl.constexpr,
NT,
B,
TOTAL_TASKS,
):
core_id = tl.program_id(0)
num_blocks = tl.num_programs(0)
T_max = T
base_tasks_per_block = TOTAL_TASKS // num_blocks
remainder_tasks = TOTAL_TASKS % num_blocks
if core_id < remainder_tasks:
tasks_this_core = base_tasks_per_block + 1
start_idx = core_id * tasks_this_core
else:
tasks_this_core = base_tasks_per_block
start_idx = core_id * base_tasks_per_block + remainder_tasks
for idx in range(start_idx, start_idx + tasks_this_core):
i_b = idx // NT
local_idx = idx % NT
if IS_VARLEN:
i_n = tl.load(chunk_indices + local_idx * 2).to(tl.int32)
i_t = tl.load(chunk_indices + local_idx * 2 + 1).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_local = eos - bos
else:
bos, eos = 0, T
i_t = local_idx
T_local = T
for i_h in range(H):
k_batch_off = i_b * T_max * H * K
beta_batch_off = i_b * H * T_max
g_batch_off = i_b * H * T_max
A_batch_off = i_b * T_max * H * BT
p_beta = tl.make_block_ptr(
beta + beta_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + k_batch_off + (bos * H + i_h) * K,
(T_local, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
dot_product = tl.dot(b_k, tl.trans(b_k))
o_t = i_t * BT + tl.arange(0, BT)
o_t = o_t.to(tl.float32)
T_mask = (o_t < T_local).to(tl.float32)
row_indices = tl.arange(0, BT)[:, None]
col_indices = tl.arange(0, BT)[None, :]
tril_mask = (row_indices > col_indices).to(tl.float32)
tril_mask = tril_mask * T_mask[:, None]
masked_dot = dot_product * tril_mask
b_A += masked_dot
if USE_G:
p_g = tl.make_block_ptr(
g + g_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0)
b_A *= tl.exp(b_g_diff)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(
A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"])
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_i, i_j = i_c // NC, i_c % NC
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j)
if should_compute:
k_ptr = k + (bos * H + i_h) * K
g_ptr = g + (bos * H + i_h) * K
A_ptr = A + (bos * H + i_h) * BT
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
b_kt = tl.make_block_ptr(
k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
p_gk = tl.make_block_ptr(
g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
b_g = tl.load(p_g, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :])
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kt = tl.load(b_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk)
b_A += tl.dot(b_k, b_kt)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = i_t * BT + i_i * BC < T_val
if should_compute:
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + o_i) < T_val
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None]
b_g = tl.load(p_g, boundary_check=(0, 1))
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)):
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1)
# Convert to f32
o_i_tmp = o_i.to(tl.float32)
b_A = tl.where(o_i_tmp > j, b_A, 0.0)
tl.store(A + o_A + j, b_A, mask=m_A)
p_kt += H * K
p_gk += H * K
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
beta: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: `None`.
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous() if g is not None else None
BK = 128
kernel_num = 24
if gk is None:
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)](
k=k,
g=g,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BK=BK,
NT=NT,
B=B,
TOTAL_TASKS=B * NT,
)
return A
BC = min(16, BT)
NC = triton.cdiv(BT, BC)
BK = max(triton.next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A
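
For readers checking the scalar-gate path above, the following is a minimal pure-Python sketch of what the docstring calls `beta * K * K^T`, restricted to a single sequence and a single head. It is not the kernel itself: `scaled_dot_kkt_reference` is a hypothetical name, the inputs are plain lists rather than tensors, and the strictly lower-triangular masking and the clamp of the gate difference to [-50, 50] are read off the first kernel in this file.

```python
import math


def scaled_dot_kkt_reference(k, beta, g=None, chunk_size=64):
    """Reference for one sequence and one head (hypothetical helper, not part of the patch).

    k:    list of T key vectors (each a list of floats)
    beta: list of T floats
    g:    optional list of T cumulative gate values
    Returns a T x chunk_size list of lists; row i holds the entries of
    chunk-local columns 0..chunk_size-1.
    """
    T = len(k)
    A = [[0.0] * chunk_size for _ in range(T)]
    for i in range(T):
        chunk_start = (i // chunk_size) * chunk_size
        # Only positions strictly before i inside the same chunk contribute.
        for j in range(chunk_start, i):
            value = sum(a * b for a, b in zip(k[i], k[j]))
            if g is not None:
                # Same clamp as the kernel, to keep exp() in a safe range.
                diff = min(max(g[i] - g[j], -50.0), 50.0)
                value *= math.exp(diff)
            A[i][j - chunk_start] = beta[i] * value
    return A


if __name__ == "__main__":
    keys = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
    print(scaled_dot_kkt_reference(keys, [0.5, 2.0, 3.0], chunk_size=2))
```

A sketch like this is mainly useful as an oracle in a unit test: run the Triton path on small inputs and compare each `[b, :, h, :]` slice against the list-based result.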


@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
scale,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BLOCK_T: tl.constexpr,
REVERSE: tl.constexpr,
HAS_SCALE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr,
CHUNK_SIZE: tl.constexpr = 64,
):
i_block, i_b = tl.program_id(0), tl.program_id(1)
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
if IS_VARLEN:
i_s, i_block = (
tl.load(chunk_indices + i_block * 2).to(tl.int32),
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
b_s = tl.trans(b_s, (1, 0, 2))
b_o = tl.cumsum(b_s, axis=0)
if REVERSE:
b_z = tl.sum(b_s, axis=0)
b_o = -b_o + b_z[None] + b_s
if HAS_SCALE:
b_o *= scale
b_o = tl.trans(b_o, (1, 0, 2))
b_o = tl.reshape(b_o, (BLOCK_T, H))
tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,))
return
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
B, T, H = g.shape
if chunk_size != 2 ** (chunk_size.bit_length() - 1):
raise ValueError(f"chunk_size must be a power of 2, chunk_size is {chunk_size}")
# We adjust the tiling strategy to prevent overflow in backward passes and context parallel scenarios
# while maximizing UB utilization where possible.
# The tiling strategy is as follows:
# 1. BT must be greater than or equal to chunk_size.
# 2. UB estimation varies directly with H.
# 3. BT in reverse mode is smaller than in forward mode.
BT = max(chunk_size, triton.next_power_of_2((1 << 11 if reverse else 1 << 12) // H))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B)
chunk_local_cumsum_scalar_kernel[grid](
s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
BLOCK_T=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
CHUNK_SIZE=chunk_size,
)
return g
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is not None:
if g.shape[0] != 1:
raise ValueError(
f"Only batch size 1 is supported when cu_seqlens are provided, current size is {g.shape[0]}"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype,
)
else:
raise ValueError(
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H); "
f"only 3-dimensional scalar gates are supported"
)
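
The kernel above computes a cumulative sum that restarts at every chunk boundary, with an optional reversed (suffix) variant and an optional scale. A minimal pure-Python sketch of that behaviour for one sequence and one head follows. `chunk_local_cumsum_reference` is a hypothetical name, and the reverse formula `total - running + x` mirrors `-b_o + b_z + b_s` in the kernel.

```python
def chunk_local_cumsum_reference(s, chunk_size, reverse=False, scale=None):
    """Chunk-local cumulative sum over a flat list (hypothetical helper).

    Forward mode returns the inclusive prefix sum inside each chunk.
    Reverse mode returns the inclusive suffix sum inside each chunk.
    """
    out = []
    for start in range(0, len(s), chunk_size):
        chunk = s[start:start + chunk_size]
        total = sum(chunk)
        running = 0.0
        for x in chunk:
            running += x
            # Suffix sum = chunk total - prefix sum + current element.
            value = total - running + x if reverse else running
            out.append(value * scale if scale is not None else value)
    return out


if __name__ == "__main__":
    print(chunk_local_cumsum_reference([1, 2, 3, 4, 5], chunk_size=2))
    print(chunk_local_cumsum_reference([1, 2, 3, 4, 5], chunk_size=2, reverse=True))
```

The last chunk may be shorter than `chunk_size`; the sketch handles that by slicing, whereas the kernel relies on `boundary_check` when loading and storing.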


@@ -0,0 +1,272 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = pid_bh // H, pid_bh % H
base_t = pid_t * TPP
if IS_VARLEN:
i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_eff = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_eff = T
o_i = tl.arange(0, 16) # noqa: F841
o_i_fp32 = tl.arange(0, 16).to(tl.float32)
m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
m_I = o_i_fp32[:, None] == o_i_fp32[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * BT
for tpp in tl.static_range(0, TPP):
tile_t = base_t + tpp
tile_row = tile_t * 16
offset = (tile_t * 16) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0))
b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
b_A_raw = desc.load([tile_row, offset]).to(tl.float32)
b_A_neg = -b_A_raw
b_A = b_A_neg * m_A
for i in range(2, min(16, T_eff - tile_row)):
slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
b_a_update = b_a_val + dot_prod
b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
b_Ai_11 += m_I
b_Ai_22 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_64x64_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 64)
m_I = o_i[:, None] == o_i[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * 64
offset = (i_t * 64) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0))
b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64])
desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64])
b_A = -desc.load([i_t * 64, offset]).to(tl.float32)
for i in range(2, min(64, T - i_t * 64)):
b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
b_A = tl.where((o_i == i)[:, None], b_a, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@input_guard
def solve_tril(
A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
"""Compute the inverse of the matrix I + A
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Returns:
(I + A)^-1 with the same shape as A
""" # noqa: D205
if A.shape[-1] not in [16, 32, 64]:
raise ValueError(f"The last dimension of A (BT) must be one of [16, 32, 64], but got {A.shape[-1]}")
output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype)
if BT == 16:
merge_fn = solve_tril_16x16_kernel
elif BT == 32:
merge_fn = merge_16x16_to_32x32_inverse_kernel
elif BT == 64:
merge_fn = solve_tril_64x64_kernel
merge_fn[NT, B * H](
A=A,
Ai=Ai,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
USE_TMA=False,
DOT_PRECISION=FLA_TRIL_PRECISION,
)
return Ai
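
The kernels in this file invert `I + A` for a strictly lower-triangular `A` by building the result row by row. As a rough illustration, here is a pure-Python sketch of the same forward substitution on a single small block. `invert_unit_lower_triangular` is a hypothetical name, and the recurrence assumed here is `X[i] = -A[i] - sum_j A[i][j] * X[j]` with `(I + A)^-1 = I + X`, which matches the loop that starts at row 2 in the kernels.

```python
def invert_unit_lower_triangular(A):
    """Return (I + A)^-1 for a strictly lower-triangular square matrix A.

    A is a list of n rows, each a list of n numbers (hypothetical helper).
    """
    n = len(A)
    # Start from -A on the strict lower triangle. Rows 0 and 1 are already final:
    # row 0 is zero, and row 1 only depends on row 0.
    X = [[-A[i][j] if i > j else 0.0 for j in range(n)] for i in range(n)]
    for i in range(2, n):
        a = [-A[i][j] if j < i else 0.0 for j in range(n)]
        # a[j] is zero for j >= i, so only already-finished rows contribute.
        X[i] = [a[c] + sum(a[j] * X[j][c] for j in range(n)) for c in range(n)]
    for i in range(n):
        X[i][i] += 1.0  # add the identity
    return X


if __name__ == "__main__":
    print(invert_unit_lower_triangular([[0, 0, 0], [2, 0, 0], [3, 4, 0]]))
```

Multiplying the result by `I + A` and comparing against the identity is a simple way to validate the Triton output for one chunk.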


@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import itertools
import logging
import os
import warnings
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional
import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
import triton.runtime.driver as driver
from packaging import version
logger = logging.getLogger(__name__)
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any:
"""A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the outputs of the decorated function for the most recent
set of input tensors, up to `maxsize` entries. If the function is called again with
the same input tensors, it will return the cached result.
When maxsize=1 (default), the behavior is identical to caching only the most recent result.
Can be used as @tensor_cache or @tensor_cache(maxsize=n).
Args:
fn (Callable[..., torch.Tensor], optional):
The function to be decorated when used without parentheses.
maxsize (int):
Maximum number of input combinations to cache. Default is 1.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with caching.
"""
if maxsize < 1:
raise ValueError("maxsize must be at least 1")
def _is_match(a: Any, b: Any) -> bool:
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return a is b
try:
return a == b
except Exception:
return a is b
def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
cache: list = []
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache):
if len(args) == len(cached_args) and len(kwargs) == len(cached_kwargs):
if all(_is_match(a, b) for a, b in zip(args, cached_args)) and all(
k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items()
):
if i != 0:
cache.insert(0, cache.pop(i))
return cached_result
result = fn(*args, **kwargs)
cache.insert(0, (args, kwargs, result))
if len(cache) > maxsize:
cache.pop()
return result
return wrapper
if fn is not None:
return _make_wrapper(fn)
return _make_wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache(maxsize=3)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
def get_abs_err(x, y):
return (x.detach() - y.detach()).flatten().abs().max().item()
def get_err_ratio(x, y):
err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
base = (x.detach()).flatten().square().mean().sqrt().item()
return err / (base + 1e-8)
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
abs_atol = get_abs_err(ref, tri)
msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
logger.info(msg)
error_rate = get_err_ratio(ref, tri)
if abs_atol <= err_atol:
return
if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
if error_rate > ratio:
warnings.warn(msg)
else:
assert error_rate < ratio, msg
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
# For Triton 3.3.x
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
# For Triton 3.4.x and later
make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
It exists only to keep the Triton compiler happy.
"""
@triton.jit
def make_tensor_descriptor(
base,
shape,
strides,
block_shape,
_builder=None,
):
return None
@functools.cache
def get_available_device() -> str:
try:
return triton.runtime.driver.active.get_current_target().backend
except BaseException:
_cpu_device_warning()
return "cpu"
def map_triton_backend_to_torch_device() -> str:
backend = get_available_device() # 'cuda' | 'hip' | 'xpu' | 'cpu' | ...
return {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(backend, backend)
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = get_available_device()
is_amd = device_platform == "hip"
is_nvidia = device_platform == "cuda"
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
)
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_tma_supported = (
(is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9)
and os.environ.get("FLA_NO_USE_TMA", "0") != "1"
and (
hasattr(triton.language, "_experimental_make_tensor_descriptor")
or hasattr(triton.language, "make_tensor_descriptor")
)
)
if is_nvidia and not is_tf32_supported:
# Triton uses TF32 by default, which pre-Ampere NVIDIA cards do not support.
# Work around this by forcing IEEE fp32 precision on those cards.
os.environ["TRITON_F32_DEFAULT"] = "ieee"
@functools.cache
def check_pytorch_version(version_s: str = "2.4") -> bool:
return version.parse(torch.__version__) >= version.parse(version_s)
if check_pytorch_version("2.4"):
device = "cuda" if device == "cpu" else device
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
def custom_device_ctx(index: int):
return device_torch_lib.device(index)
else:
assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0."
autocast_custom_fwd = device_torch_lib.amp.custom_fwd
autocast_custom_bwd = device_torch_lib.amp.custom_bwd
def custom_device_ctx(index: int):
return torch.cuda.device(index)
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""A decorator to make sure all input tensors are contiguous and set the device based on input tensors."""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = custom_device_ctx(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
def _cpu_device_warning():
warnings.warn("Triton is not supported on the current platform; falling back to CPU.", stacklevel=1)
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
exp = tldevice.fast_expf
exp2 = tldevice.exp2
log = tldevice.fast_logf
log2 = tldevice.fast_log2f
else:
exp = tl.exp
exp2 = tl.math.exp2
log = tl.log
log2 = tl.log2
def get_all_max_shared_mem():
try:
return [
triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
for i in range(device_torch_lib.device_count())
]
except BaseException:
_cpu_device_warning()
return [-1]
class Backend(Enum):
ADA = 101376 # RTX 4090
AMPERE = 166912 # A100
HOPPER = 232448 # H100
DEFAULT = 102400 # Default
@classmethod
def get_shared_memory(cls, arch: str) -> int:
try:
return cls[arch.upper()].value
except KeyError:
return cls.DEFAULT.value
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
try:
device_shared_mem_list = get_all_max_shared_mem()
max_shared_memory = device_shared_mem_list[tensor_idx]
return max_shared_memory >= Backend.get_shared_memory(arch)
except Exception:
return False
def get_autotune_config(
multibuffer_list: tuple = (False,),
unit_flag_list: tuple = (False,),
limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,),
limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",),
set_workspace_multibuffer_list: tuple = (2, 4),
enable_hivm_auto_cv_balance_list: tuple = (True,),
tile_mix_vector_loop_num_list: tuple = (2, 4),
tile_mix_cube_loop_num_list: tuple = (2, 4),
):
configs = []
for (
multibuffer,
unit_flag,
limit_auto_multi_buffer_only_for_local_buffer,
limit_auto_multi_buffer_of_local_buffer,
) in itertools.product(
list(multibuffer_list),
list(unit_flag_list),
list(limit_auto_multi_buffer_only_for_local_buffer_list),
list(limit_auto_multi_buffer_of_local_buffer_list),
):
base_config_dict = {
"multibuffer": multibuffer,
"unit_flag": unit_flag,
"limit_auto_multi_buffer_only_for_local_buffer": limit_auto_multi_buffer_only_for_local_buffer,
"limit_auto_multi_buffer_of_local_buffer": limit_auto_multi_buffer_of_local_buffer,
}
if limit_auto_multi_buffer_only_for_local_buffer:
configs.append(triton.Config(base_config_dict))
else:
for (
set_workspace_multibuffer,
enable_hivm_auto_cv_balance,
tile_mix_vector_loop,
tile_mix_cube_loop,
) in itertools.product(
list(set_workspace_multibuffer_list),
list(enable_hivm_auto_cv_balance_list),
list(tile_mix_vector_loop_num_list),
list(tile_mix_cube_loop_num_list),
):
full_config_dict = base_config_dict.copy()
full_config_dict.update(
{
"set_workspace_multibuffer": set_workspace_multibuffer,
"enable_hivm_auto_cv_balance": enable_hivm_auto_cv_balance,
"tile_mix_vector_loop": tile_mix_vector_loop,
"tile_mix_cube_loop": tile_mix_cube_loop,
}
)
configs.append(triton.Config(full_config_dict))
return configs
def get_npu_properties():
return driver.active.utils.get_device_properties(torch.npu.current_device())
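
The varlen kernels look up a `(sequence index, chunk index)` pair per program from the table built by `prepare_chunk_indices`. The sketch below reproduces that table in plain Python for illustration. `chunk_indices_reference` is a hypothetical name, and it assumes every sequence is non-empty (the tensor version derives the sequence index from a running count of chunk-0 entries, so an empty sequence would shift later indices there).

```python
def chunk_indices_reference(cu_seqlens, chunk_size):
    """List the (sequence index, chunk index) pairs for packed sequences.

    cu_seqlens: cumulative sequence lengths, e.g. [0, 5, 12] for lengths 5 and 7.
    Assumes all sequences are non-empty (hypothetical helper).
    """
    pairs = []
    for seq_id in range(len(cu_seqlens) - 1):
        length = cu_seqlens[seq_id + 1] - cu_seqlens[seq_id]
        n_chunks = -(-length // chunk_size)  # ceiling division
        pairs.extend((seq_id, chunk_id) for chunk_id in range(n_chunks))
    return pairs


if __name__ == "__main__":
    print(chunk_indices_reference([0, 5, 12], chunk_size=4))
```

The length of this list is what the host-side wrappers use as `NT` when `cu_seqlens` is provided.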


@@ -0,0 +1,387 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
k,
v,
beta,
g,
A,
dw,
du,
dk,
dv,
dbeta,
dg,
cu_seqlens,
chunk_indices,
T,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
i_t = idx
bos, eos = i_b * T, i_b * T + T
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
for i_h in range(0, H):
if IS_VARLEN:
offset = bos + i_h * T_max
else:
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dg = tl.zeros([BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dw = tl.make_block_ptr(
dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
b_dk_beta_g = tl.dot(b_A, b_dw)
b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None]
b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1)
b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_dv = tl.make_block_ptr(
dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_du = tl.make_block_ptr(
du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta))
b_dv_beta = tl.dot(b_A, b_du)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
b_dA = b_dA.to(k.dtype.element_ty)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_k_beta, tl.trans(b_k))
b_dk_beta = tl.dot(b_dA, b_k)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_dA_A = b_dA * b_A
b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0)
p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
gk,
cu_seqlens,
chunk_indices,
T_tmp,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T_tmp
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
for i_h in range(0, H):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
offset = bos + i_h * T_max
T = eos - bos
else:
T = T_tmp
i_t = idx
bos, eos = i_b * T, i_b * T + T
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
if USE_G:
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
if USE_G:
b_kb *= b_g[:, None]
if USE_GK:
p_gk = tl.make_block_ptr(
gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1)))
b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 128
BV = 128
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous() if g is not None else None
beta = beta.transpose(1, 2).contiguous()
w = torch.empty_like(k)
u = torch.empty_like(v)
cv_kernel_num = 24
recompute_w_u_fwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T_tmp=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
return w, u
def prepare_wy_repr_bwd(
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
dw: torch.Tensor,
du: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
chunk_size: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128
BV = 128
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.empty_like(beta)
dg = torch.empty_like(g)
cv_kernel_num = 24
prepare_wy_repr_bwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
g=g,
A=A,
dw=dw,
du=du,
dk=dk,
dv=dv,
dbeta=dbeta,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
dbeta = dbeta.transpose(1, 2).contiguous()
dg = dg.transpose(1, 2).contiguous()
return dk, dv, dbeta, dg
bwd_prepare_wy_repr = prepare_wy_repr_bwd
fwd_recompute_w_u = recompute_w_u_fwd
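Review note: both kernels in this file split `NT` chunks across a fixed number of programs using the same base/remainder arithmetic. The pure-Python sketch below mirrors that index computation so it can be checked on the host; `partition_chunks` is a hypothetical name, not part of the patch.

```python
def partition_chunks(num_chunks: int, total_cores: int) -> list[tuple[int, int]]:
    """Return (start_idx, chunk_count) per core, mirroring the kernel's arithmetic."""
    base, remainder = divmod(num_chunks, total_cores)
    assignments = []
    for core_id in range(total_cores):
        if core_id < remainder:
            # The first `remainder` cores each take one extra chunk.
            count = base + 1
            start = core_id * count
        else:
            count = base
            start = core_id * count + remainder
        assignments.append((start, count))
    return assignments
```

For example, `partition_chunks(10, 4)` yields `[(0, 3), (3, 3), (6, 2), (8, 2)]`. Because the wrappers launch a fixed grid of 24 programs (`cv_kernel_num = 24`), any launch with fewer than 24 chunks leaves the trailing programs with a count of zero.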


@@ -20,7 +20,6 @@ import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
@@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback):
if matched:
logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
- + (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
+ + (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
)
else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback):
bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
- lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
+ lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines))
self._forward_times.clear()


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
- from .workflow import run_sft
+ from .workflow import run_pt, run_sft
- __all__ = ["run_sft"]
+ __all__ = ["run_pt", "run_sft"]


@@ -0,0 +1,222 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperParallel distributed trainer for LlamaFactory."""
import logging
import os
import types
from contextlib import nullcontext
from typing import Any, Optional
import torch
from hyper_parallel.integration.llamafactory import (
HSDPModule,
HyperParallelArguments,
export_to_hf_format,
fsdp2_prepare_model,
hsdp_sync_stream,
load_hsdp_model,
load_hsdp_optimizer_and_scheduler,
save_hsdp_checkpoint,
wrap_optimizer_with_skip_dtensor_dispatch,
)
from hyper_parallel.integration.llamafactory import (
clip_grad_norm_ as hp_clip_grad_norm_,
)
from torch import nn
from ..sft.trainer import CustomSeq2SeqTrainer
logger = logging.getLogger(__name__)
class HyperParallelTrainer(CustomSeq2SeqTrainer):
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
prediction, sampler, etc.) and only overrides HSDP-specific behavior.
"""
def __init__(
self,
hp_args: HyperParallelArguments,
finetuning_args=None,
processor=None,
ref_model: Optional[nn.Module] = None,
**kwargs,
):
self._hp_args = hp_args
# Let CustomSeq2SeqTrainer handle everything except ref_model —
# Custom would prepare it with accelerate's fsdp2_prepare_model,
# but we need HP's version instead.
super().__init__(
finetuning_args=finetuning_args,
processor=processor,
ref_model=None,
**kwargs,
)
if not getattr(self.accelerator, "is_fsdp2", False):
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
# Prepare ref_model with HP's fsdp2_prepare_model
self.ref_model = ref_model
if self.ref_model is not None:
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
self._orig_fsdp2_prepare_model = None
self._accelerator_patches_active = False
def _activate_accelerator_patches(self) -> None:
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
if self._accelerator_patches_active:
return
import accelerate.accelerator as acc_module # pylint: disable=C0415
hp_args = self._hp_args
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
def _hp_fsdp2_prepare_model(accelerator, model):
return fsdp2_prepare_model(accelerator, model, hp_args)
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
if getattr(accelerator, "is_fsdp2", False):
accelerator.unscale_gradients()
parameter_list = list(parameters)
parameter_ids = {id(param) for param in parameter_list}
for model in accelerator._models: # pylint: disable=protected-access
if not isinstance(model, HSDPModule):
continue
model_param_ids = {id(param) for param in model.parameters()}
if parameter_ids and parameter_ids.issubset(model_param_ids):
return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)
return self._orig_accelerator_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
self._accelerator_patches_active = True
def _restore_accelerator_patches(self) -> None:
"""Restore original Accelerate methods."""
if not self._accelerator_patches_active:
return
import accelerate.accelerator as acc_module # pylint: disable=C0415
if self._orig_fsdp2_prepare_model is not None:
acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
self._accelerator_patches_active = False
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
"""Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct."""
del dataloader
if isinstance(model, HSDPModule):
return model
if training and getattr(self.accelerator, "is_fsdp2", False):
return model
return super()._wrap_model(model, training=training)
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
"""Skip redundant device moves for HSDP-wrapped models."""
if isinstance(model, HSDPModule):
return model
if device is None:
return model
return model.to(device)
def train(self, *args, **kwargs):
"""Activate HP patches during training and restore afterwards."""
self._activate_accelerator_patches()
try:
return super().train(*args, **kwargs)
finally:
self._restore_accelerator_patches()
def training_step(
self,
model: nn.Module,
inputs: dict[str, Any],
num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
"""Standard training step with HSDP gradient synchronization."""
model.train()
inputs = self._prepare_inputs(inputs)
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
if isinstance(model, HSDPModule):
model.set_is_last_backward(sync_gradients)
model.set_requires_gradient_sync(sync_gradients)
compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
with compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
if self.args.n_gpu > 1:
loss = loss.mean()
if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
loss = loss / self.args.gradient_accumulation_steps
self.accelerator.backward(loss)
if isinstance(model, HSDPModule) and sync_gradients:
hsdp_sync_stream()
return loss.detach()
def create_optimizer(self):
"""Create optimizer and wrap step with SkipDTensorDispatch."""
optimizer = super().create_optimizer()
wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
return optimizer
def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
"""Save model/optimizer shards per-rank and scheduler."""
save_hsdp_checkpoint(
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
output_dir=output_dir,
should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
)
def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
"""Load model from HSDP sharded checkpoint."""
target = model if model is not None else self.model
loaded = load_hsdp_model(target, resume_from_checkpoint)
if not loaded:
return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
self._pending_hsdp_checkpoint = resume_from_checkpoint
return None
def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
"""Load optimizer/scheduler from per-rank checkpoint files."""
ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
if ckpt_dir is None:
return
load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
"""Save model weights in HuggingFace-compatible format."""
save_dir = output_dir or self.args.output_dir
os.makedirs(save_dir, exist_ok=True)
export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)
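Review note: `HyperParallelTrainer.train` installs its Accelerate patches before training and removes them in a `finally` block, so the originals come back even if training raises. The sketch below isolates that activate/restore pattern on a stand-in module object. `PatchedRunner` and `fake_module` are hypothetical names used only for illustration; nothing here touches Accelerate or HyperParallel.

```python
import types


class PatchedRunner:
    """Swap one attribute of a module-like object for the duration of run()."""

    def __init__(self, module, attr_name, replacement):
        self._module = module
        self._attr_name = attr_name
        self._replacement = replacement
        self._original = None
        self._active = False

    def _activate(self):
        if self._active:  # idempotent, like _activate_accelerator_patches
            return
        self._original = getattr(self._module, self._attr_name)
        setattr(self._module, self._attr_name, self._replacement)
        self._active = True

    def _restore(self):
        if not self._active:
            return
        setattr(self._module, self._attr_name, self._original)
        self._active = False

    def run(self, work):
        self._activate()
        try:
            return work()
        finally:
            # Runs on success and on exceptions alike.
            self._restore()


fake_module = types.SimpleNamespace(prepare=lambda model: f"default:{model}")
runner = PatchedRunner(fake_module, "prepare", lambda model: f"patched:{model}")
during = runner.run(lambda: fake_module.prepare("m"))  # patched version is used
after = fake_module.prepare("m")  # original is back in place
```

One design consequence worth keeping in mind: the patch on a module-level function is process-wide while it is active, so any other code calling it during `run()` also sees the replacement.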


@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
- from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
- from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
+ from ..trainer_utils import create_modelcard_and_push, create_ref_model
+ from .trainer import HyperParallelTrainer
if TYPE_CHECKING:
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
r"""Load HyperParallel arguments and apply LlamaFactory-side overrides.
When activation optimization is enabled, skip native gradient checkpointing
so HP can install its own via ``setup_activation_optimization``.
"""
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
if hp_args.activation_mode != "none":
model_args.disable_gradient_checkpointing = True
return hp_args
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**tokenizer_module,
)
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if isinstance(dataset_module.get("eval_dataset"), dict):
for key in dataset_module["eval_dataset"].keys():
try:
perplexity = math.exp(metrics[f"eval_{key}_loss"])
except OverflowError:
perplexity = float("inf")
metrics[f"eval_{key}_perplexity"] = perplexity
else:
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["eval_perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -43,13 +130,7 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
- if not is_hyper_parallel_available():
- raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
- from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
- HyperParallelArguments,
- HyperParallelTrainer,
- )
+ hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@@ -94,25 +175,6 @@ def run_sft(
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
- hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
- callbacks = list(callbacks or [])
- processor = tokenizer_module.get("processor")
- if processor is not None:
- callbacks.append(SaveProcessorCallback(processor))
- compute_loss_func = None
- if finetuning_args.use_dft_loss:
- compute_loss_func = dft_loss_func
- elif finetuning_args.use_eaft_loss:
- compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
- outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
- )
- elif finetuning_args.use_asft_loss:
- from functools import partial
- compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
@@ -122,20 +184,11 @@ def run_sft(
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
- compute_loss_func=compute_loss_func,
**dataset_module,
**tokenizer_module,
**metric_module,
)
- if finetuning_args.use_badam:
- from types import MethodType
- from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
- trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
- trainer.add_callback(BAdamCallback)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
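Review note: the evaluation branch of `run_pt` in this file derives perplexity as `exp(eval_loss)` and falls back to infinity on overflow, once per dataset when `eval_dataset` is a dict. A minimal standalone sketch of that logic follows; `perplexity_from_loss` and `add_perplexity` are hypothetical helper names, and the metric key layout is taken from the diff.

```python
import math
from typing import Iterable, Optional


def perplexity_from_loss(loss: float) -> float:
    """exp(loss), with overflow mapped to infinity as in run_pt."""
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")


def add_perplexity(metrics: dict, eval_names: Optional[Iterable[str]] = None) -> dict:
    """Add perplexity entries next to the eval loss entries (modifies and returns metrics)."""
    if eval_names is not None:
        # One loss per named eval dataset: eval_<name>_loss -> eval_<name>_perplexity.
        for name in eval_names:
            metrics[f"eval_{name}_perplexity"] = perplexity_from_loss(metrics[f"eval_{name}_loss"])
    else:
        metrics["eval_perplexity"] = perplexity_from_loss(metrics["eval_loss"])
    return metrics
```

Factoring the `try`/`except` into one helper would also remove the duplicated block between the two branches, should the authors want to tighten `run_pt`.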


@@ -80,20 +80,27 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
- if training_args.enable_torch_profiler:
+ if getattr(training_args, "enable_torch_profiler", False):
callbacks.append(TorchProfilerCallback(training_args))
- if training_args.profile_modules:
+ if getattr(training_args, "profile_modules", None):
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
- if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+ if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
if not is_hyper_parallel_available():
- raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
- from .hyper_parallel import run_sft as run_sft_hp
+ raise ImportError(
+ "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
+ )
+ if finetuning_args.stage == "pt":
+ from .hyper_parallel import run_pt as run_pt_hp
- run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+ run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
+ else:
+ from .hyper_parallel import run_sft as run_sft_hp
+ run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():
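Review note: the profiler checks in this hunk switch from direct attribute access to `getattr` with a default, so an arguments object that lacks the profiler fields no longer raises `AttributeError`. The sketch below shows the behavior with `types.SimpleNamespace` standing in for the arguments classes; `select_profiler_callbacks` and the returned labels are hypothetical.

```python
from types import SimpleNamespace


def select_profiler_callbacks(training_args) -> list[str]:
    """Return labels for the profiler callbacks that would be registered."""
    selected = []
    # A missing attribute is treated the same as a disabled feature.
    if getattr(training_args, "enable_torch_profiler", False):
        selected.append("torch_profiler")
    if getattr(training_args, "profile_modules", None):
        selected.append("module_profiler")
    return selected


args_without_profiler_fields = SimpleNamespace(seed=42)
args_with_profiler_fields = SimpleNamespace(enable_torch_profiler=True, profile_modules=["attn"])
```

With this shape, `select_profiler_callbacks(args_without_profiler_fields)` returns an empty list instead of failing.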


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [
"AttentionFunction",
"BatchingStrategy",
"DataArguments",
"InputArgument",


@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+ model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
- for arg in parsed_args:
- seed = getattr(arg, "seed", None)
- if seed is not None:
- set_seed(seed)
- break
+ set_seed(training_args.seed, full_determinism=training_args.full_determinism)
- return tuple(parsed_args)
+ return model_args, data_args, training_args, sample_args
if __name__ == "__main__":


@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.SDPA,
metadata={
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
)
def __post_init__(self) -> None:
supported_flash_attn = [item.value for item in AttentionFunction]
if self.flash_attn not in supported_flash_attn:
raise ValueError(
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
)
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
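Review note: the new `flash_attn` field is checked in `__post_init__` against the values of `AttentionFunction`. The sketch below shows that validation in isolation. The enum definition is an assumption reconstructed from the help text (eager, sdpa, flash_attention_2); the real `AttentionFunction` lives in `..utils.types` and is not shown in this diff. `ModelArgumentsSketch` is a hypothetical stand-in for `ModelArguments`.

```python
from dataclasses import dataclass
from enum import Enum


class AttentionFunction(str, Enum):
    # Assumed members, based on the field's help text.
    EAGER = "eager"
    SDPA = "sdpa"
    FLASH_ATTENTION_2 = "flash_attention_2"


@dataclass
class ModelArgumentsSketch:
    flash_attn: str = AttentionFunction.SDPA.value

    def __post_init__(self) -> None:
        supported = [item.value for item in AttentionFunction]
        # Reject unknown values at parse time rather than at model load.
        if self.flash_attn not in supported:
            raise ValueError(f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported}.")
```

Failing in `__post_init__` surfaces a bad config value immediately after argument parsing, before any weights are loaded.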


@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
@@ -116,3 +120,9 @@ class TrainingArguments:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
if self.max_steps is None or self.max_steps <= 0:
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
if self.save_epochs is not None:
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")


@@ -34,7 +34,7 @@ import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..config import BatchingStrategy, TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
@@ -147,13 +147,19 @@ class BaseTrainer:
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
- logger.warning_rank0(
- "Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
+ raise ValueError(
+ "Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
)
- model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None:
+ if (
+ self.args.batching_strategy == BatchingStrategy.PADDING_FREE
+ and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
+ ):
+ raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
@@ -237,6 +243,7 @@ class BaseTrainer:
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
for micro_batches in self.train_batch_generator:
self.global_step += 1


@@ -120,6 +120,7 @@ class ModelEngine:
init_device = DistributedInterface().current_device
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
@@ -164,6 +165,7 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
attn_implementation=self.args.flash_attn,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
@@ -188,9 +190,12 @@ class ModelEngine:
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
- model = KernelPlugin(self.args.kernel_config.name)(
- model, include_kernels=self.args.kernel_config.get("include_kernels")
- )
+ kernel_config = self.args.kernel_config
+ kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
+ if kernel_config.name == "liger_kernel":
+ # Fused linear CE omits logits; SFT stage needs logits for loss_weights.
+ kernel_kwargs["require_logits"] = self.is_train
+ model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
return model


@@ -42,6 +42,8 @@ from .rendering import Renderer
logger = logging.get_logger(__name__)
__all__ = ["BatchGenerator"]
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
}
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
if self.batching_strategy == BatchingStrategy.NORMAL:
batch_size = self.micro_batch_size * self.num_micro_batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
batch_size=batch_size,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
def __len__(self) -> int:
return self._length
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
try:
return next(self._data_iter)
except StopIteration:
if not restart:
return None
# Dynamic batching may restart the provider to fill one token-budgeted batch.
self._data_iter = iter(self._data_provider)
try:
return next(self._data_iter)
except StopIteration:
return None
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer.state_dict(),


@@ -34,7 +34,7 @@ class BaseKernel(ABC):
"""
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
_device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
@classmethod
def get_kernel_id(cls) -> str:
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
def get_device(cls) -> list[DeviceType]:
"""Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
return cls._device
@classmethod
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks.
"""
if cls._device != get_current_accelerator().type:
if get_current_accelerator().type not in cls._device:
return False
return True
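
The hunk above changes `_device` from a single value to a list, so the dependency check becomes a membership test. A minimal sketch of that idea follows; the `DeviceType` enum, the `supports` method, and the `LigerLikeKernel` class here are hypothetical stand-ins, and the current device is passed in as an argument instead of being read from `get_current_accelerator()`.

```python
from enum import Enum


class DeviceType(str, Enum):
    # Hypothetical stand-in for the project's DeviceType enum.
    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"


class BaseKernel:
    # A kernel now declares every device it can run on.
    _device: list[DeviceType] = [DeviceType.CPU]

    @classmethod
    def supports(cls, current_device: str) -> bool:
        # Membership test instead of equality against one device.
        return current_device in cls._device


class LigerLikeKernel(BaseKernel):
    # Illustrative subclass that targets two accelerators at once.
    _device = [DeviceType.CUDA, DeviceType.NPU]


print(LigerLikeKernel.supports("npu"))
print(LigerLikeKernel.supports("cpu"))
```

Because the enum mixes in `str`, a plain string such as `"npu"` compares equal to `DeviceType.NPU`, which is what lets the `in` check work on either form.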


@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
apply_kernel(kernel, model=model)
return model
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
model: HFModel,
include_kernels: str = None,
require_logits: bool = False,
) -> HFModel:
"""Applies Liger kernel to the model.
Args:
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with
library defaults. If a comma-separated list (e.g.
``rope,rms_norm``), enable only those ops; names match
``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``,
``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``.
If ``None`` or ``False``, do nothing. Defaults to ``None``.
require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor
of non-fused CE so the forward pass returns ``logits``. Needed
for trainers that compute weighted loss from logits (e.g. v1
SFT with ``loss_weights``). Defaults to ``False`` (fused CE
when supported). The v1 ``run_sft`` entrypoint sets
``require_logits`` to true for ``liger_kernel`` when the key
is omitted so SFT weighted loss keeps working.
Returns:
HFModel: The model with Liger kernel applied.
"""
if not include_kernels:
return model
if include_kernels == "auto" or include_kernels is True:
use_kernels = "auto"
else:
use_kernels = [k.strip() for k in include_kernels.split(",") if k.strip()]
if not use_kernels:
return model
try:
from .liger_kernel_ops import LigerKernel
except ImportError as e:
logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
return model
return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)


@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of Liger Kernel.
Init Phase:
1. Define LigerKernel class.
2. Register Liger kernel.
"""
import inspect
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.logging import get_logger
from ....utils.types import HFModel
from .base import BaseKernel
logger = get_logger(__name__)
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
"qwen3": "apply_liger_kernel_to_qwen3",
"qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
"qwen3_next": "apply_liger_kernel_to_qwen3_next",
"qwen3_5": "apply_liger_kernel_to_qwen3_5",
"qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
"qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
"qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
}
class LigerKernel(BaseKernel):
"""Liger Kernel for optimized model training."""
_device = [DeviceType.CUDA, DeviceType.NPU]
@classmethod
def check_deps(cls) -> bool:
"""Checks if the required dependencies for the kernel are available."""
try:
import liger_kernel # noqa: F401
return super().check_deps()
except ImportError:
logger.warning_rank0(
"Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
)
return False
@classmethod
def apply(cls, **kwargs) -> "HFModel":
"""Applies the Liger kernel to the model.
Args:
**kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
names to enable exclusively, or the string ``"auto"`` to use each
``apply_liger_kernel_to_*`` function's signature defaults (same as calling
upstream with only ``model``). Optional ``require_logits`` forces non-fused
cross entropy when supported.
Returns:
HFModel: The model with Liger kernel applied.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model")
use_kernels = kwargs.get("use_kernels", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(
f"liger_kernel is unavailable: it is not installed or does not support the current device. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}."
)
require_logits = kwargs.get("require_logits", False)
model_type = getattr(model.config, "model_type", None)
if model_type not in _LIGER_FN_BY_MODEL_TYPE:
logger.warning_rank0("Current model does not support liger kernel.")
return model
import liger_kernel.transformers as liger_transformers
apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])
sig = inspect.signature(apply_liger_kernel).parameters
togglable = [name for name in sig if name != "model"]
def _normalize_op_name(raw: str) -> str:
key = raw.strip().lower().replace("-", "_")
aliases = {
"rmsnorm": "rms_norm",
"flce": "fused_linear_cross_entropy",
"lce": "fused_linear_cross_entropy",
"fused_ce": "fused_linear_cross_entropy",
}
return aliases.get(key, key)
if use_kernels is not None and len(use_kernels) == 0:
return model
if use_kernels != "auto":
selected = {_normalize_op_name(k) for k in use_kernels}
ops = selected - set(togglable)
if ops:
raise ValueError(
f"Unknown Liger op(s) {sorted(ops)} for model_type={model_type}. Valid: {sorted(togglable)}"
)
if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")
call_kwargs = {name: (name in selected) for name in togglable}
call_kwargs["model"] = model
else:
# Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
# and logging reflects enabled ops (omitted kwargs only live in the callee).
call_kwargs = {"model": model}
for name in togglable:
param = sig[name]
if param.default is not inspect.Parameter.empty:
call_kwargs[name] = param.default
if require_logits and "fused_linear_cross_entropy" in sig:
logger.warning_rank0("Current training stage does not support chunked cross entropy.")
call_kwargs["fused_linear_cross_entropy"] = False
call_kwargs["cross_entropy"] = True
apply_liger_kernel(**call_kwargs)
applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
logger.info_rank0(f"These Liger ops are applied to the model: {applied}")
return model
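
The op-selection logic in `LigerKernel.apply` can be tried without installing `liger_kernel`. The sketch below is an assumption-laden illustration: `fake_apply_liger` is a hypothetical stand-in for an upstream `apply_liger_kernel_to_*` function (its parameter defaults are invented for the example), and `build_call_kwargs` is a hypothetical helper that only mirrors how the keyword arguments are derived from the callee's signature.

```python
import inspect


def fake_apply_liger(model, rope=True, rms_norm=True, swiglu=True,
                     cross_entropy=False, fused_linear_cross_entropy=True):
    """Hypothetical stand-in for an upstream apply_liger_kernel_to_* function."""
    return None


def build_call_kwargs(apply_fn, model, use_kernels, require_logits=False):
    """Derive keyword arguments for apply_fn from its signature."""
    sig = inspect.signature(apply_fn).parameters
    togglable = [name for name in sig if name != "model"]

    if use_kernels == "auto":
        # Copy the callee's own defaults so the caller can log what is enabled.
        call_kwargs = {
            name: sig[name].default
            for name in togglable
            if sig[name].default is not inspect.Parameter.empty
        }
    else:
        selected = set(use_kernels)
        unknown = selected - set(togglable)
        if unknown:
            raise ValueError(f"Unknown op(s) {sorted(unknown)}. Valid: {sorted(togglable)}")
        # Listed ops are switched on, every other togglable op is switched off.
        call_kwargs = {name: (name in selected) for name in togglable}

    if require_logits and "fused_linear_cross_entropy" in sig:
        # Fused linear CE does not return logits, so fall back to plain CE.
        call_kwargs["fused_linear_cross_entropy"] = False
        call_kwargs["cross_entropy"] = True

    call_kwargs["model"] = model
    return call_kwargs


kwargs = build_call_kwargs(fake_apply_liger, "model", ["rope", "rms_norm"])
print(sorted(name for name, on in kwargs.items() if name != "model" and on))
```

Reading the defaults from the signature, rather than calling the function with only `model`, is what allows the applied ops to be reported in the log line.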


@@ -228,6 +228,30 @@ class NpuMoeFused:
routed_out = self.experts(hidden_states, routing_weights, router_indices)
return routed_out
@staticmethod
def npu_moe_experts_v5_forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
Transformers v5 stores expert weights in F.linear layout:
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
down_proj: [num_experts, hidden_dim, intermediate_dim]
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
"""
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
hidden_states, top_k_index.to(torch.int32)
)
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
gate_up_proj = self.gate_up_proj.transpose(1, 2)
down_proj = self.down_proj.transpose(1, 2)
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
class Qwen3NpuMoeFused:
"""Container for Qwen3 NPU fused MoE forward functions."""
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
# moe patch config mapping
kernel_moe_mapping = {
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
if is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
}
}
if not is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
else:
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
},
}


@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
This must follow the module's forward semantics. Do not infer it from trained
weight values because standard RMSNorm weights can also be close to zero.
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
residual_rmsnorm_classes = {
"Qwen3_5RMSNorm",
"Qwen3_5MoeRMSNorm",
"Qwen3NextRMSNorm",
}
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
return class_name in residual_rmsnorm_classes
def npu_rms_norm_forward(self, hidden_states):
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
if hasattr(self, "weight") and self.weight is not None:
if _should_use_residual_rmsnorm(self):
if getattr(self, "_npu_use_residual_rmsnorm", False):
effective_weight = 1.0 + self.weight.float()
else:
effective_weight = self.weight.float()
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
if "Gated" in module.__class__.__name__:
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model
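
The note in `_should_use_residual_rmsnorm` warns against inferring the parameterization from weight values. A small pure-Python sketch (not the NPU kernel; `rms_norm` here is a hypothetical helper on plain lists) shows why the flag matters: the same weights give very different outputs depending on whether the scale is `weight` or `1 + weight`.

```python
import math


def rms_norm(x, weight, eps=1e-6, residual=False):
    """Reference RMSNorm on plain lists, for illustration only."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Residual parameterization stores (scale - 1), so 1.0 is added back.
    scale = [(1.0 + w) if residual else w for w in weight]
    return [v / rms * s for v, s in zip(x, scale)]


x = [3.0, 4.0]
zero_weight = [0.0, 0.0]

# Standard form: zero weights wipe out the activation.
print(rms_norm(x, zero_weight, residual=False))
# Residual form: zero weights mean an identity scale.
print(rms_norm(x, zero_weight, residual=True))
```

Choosing the wrong branch silently rescales every normalized activation, which is consistent with the diff caching the decision per module in `_npu_use_residual_rmsnorm` based on the class name.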


@@ -58,7 +58,7 @@ class Registry:
device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type:
if get_current_accelerator().type not in device:
return
if not kernel_id:


@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64)
else:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
if position_ids is not None:
global_position_ids = [
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_position_ids, position_ids, group=self.spg)
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
attention_mask = None
else:
attention_mask = attention_mask.to(torch.int64)
if attention_mask is None:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else:
attention_mask = attention_mask.to(torch.int64)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
context_layer = self.attn_fn(
q,


@@ -12,23 +12,272 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from math import ceil
from typing import Any
import torch
from torch.utils.data import default_collate
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.plugin import BasePlugin
from ...utils.types import BatchInfo, BatchInput, DataLoader
from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
class BatchingPlugin(BasePlugin):
def compute_length(self, data_provider: DataLoader) -> int:
def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
"""Return the raw data provider batch size for this batching strategy."""
return self["get_data_provider_batch_size"](batch_info)
def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
"""Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule.
"""
raise NotImplementedError()
return self["compute_length"](data_provider, batch_info)
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
def fill_buffer(
self,
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
"""Fill the buffer with data."""
raise NotImplementedError()
return self["fill_buffer"](buffer, batch_info, next_samples)
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
"""Generate a batch from the buffer."""
raise NotImplementedError()
return self["generate_batch"](buffer, batch_info)
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
"""Return sample counts for micro batches formed by one padded-token budget."""
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
max_sample_len = 0
used = 0
is_complete = False
while index + used < len(samples):
sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
padded_tokens = max(max_sample_len, sample_len) * (used + 1)
if used > 0 and padded_tokens > budget:
is_complete = True
break
max_sample_len = max(max_sample_len, sample_len)
used += 1
if max_sample_len * used >= budget:
is_complete = True
break
if used == 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
"""Return sample counts for micro batches formed by one unpadded-token budget."""
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
current_tokens = 0
used = 0
is_complete = False
while index + used < len(samples):
sample = samples[index + used]
sample_len = min(len(sample["input_ids"]), cutoff_len)
if current_tokens + sample_len > budget:
is_complete = True
break
current_tokens += sample_len
used += 1
if used <= 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {}
position_ids: list[int] = []
for sample_index, sample in enumerate(samples):
# Padding-free still truncates each sample by cutoff_len before packing
# all samples into one contiguous sequence.
sample_len = min(len(sample["input_ids"]), cutoff_len)
if sample_len <= 0:
continue
for key, value in sample.items():
if key in ("attention_mask", "position_ids") or isinstance(value, str):
continue
if key not in packed:
packed[key] = []
sliced_value = list(value[:sample_len])
if sample_index > 0 and sliced_value:
if key == "labels":
sliced_value[0] = IGNORE_INDEX
elif key == "loss_weights":
sliced_value[0] = 0.0
packed[key].extend(sliced_value)
position_ids.extend(range(sample_len))
if not position_ids:
return None
packed["position_ids"] = position_ids
packed["attention_mask"] = [1] * len(position_ids)
return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
@BatchingPlugin("padding_free").register("compute_length")
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
return len(data_provider)
@BatchingPlugin("padding_free").register("fill_buffer")
def fill_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]:
samples = next_samples(False)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("padding_free").register("generate_batch")
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
num_micro_batch = batch_info["num_micro_batch"]
cutoff_len = batch_info["cutoff_len"]
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return None
samples = buffer.get(batch_size)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
if packed_micro_batch is None:
return None
batch.append(packed_micro_batch)
return batch
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_batching").register("compute_length")
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_batching").register("fill_buffer")
def fill_dynamic_batching_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_batching").register("generate_batch")
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
if packed_batch is None:
return None
batch.append(packed_batch)
return batch
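
The token-budget grouping used by the `dynamic_padding_free` strategy can be traced on plain sequence lengths. The sketch below is a simplified illustration: `padding_free_group_sizes` is a hypothetical helper that takes a list of integer lengths instead of `ModelInput` samples, while the budget formula and the "emit a group only once the next sample would overflow" rule follow the diff above.

```python
def padding_free_group_sizes(lengths, cutoff_len, micro_batch_size, num_micro_batch):
    """Greedily group sample lengths under a per-micro-batch token budget."""
    budget = cutoff_len * micro_batch_size
    sizes = []
    index = 0
    while index < len(lengths) and len(sizes) < num_micro_batch:
        tokens = 0
        used = 0
        complete = False
        while index + used < len(lengths):
            # Each sample is truncated to cutoff_len before it is counted.
            sample_len = min(lengths[index + used], cutoff_len)
            if tokens + sample_len > budget:
                complete = True
                break
            tokens += sample_len
            used += 1
        # A trailing group that never hit the budget is left for a later refill.
        if used == 0 or not complete:
            break
        sizes.append(used)
        index += used
    return sizes


# Budget is 4 * 2 = 8 tokens, so two samples of length 3 fit per micro batch.
print(padding_free_group_sizes([3, 3, 3, 3, 3], cutoff_len=4, micro_batch_size=2, num_micro_batch=3))
```

In this example the fifth sample is not emitted as a third group, because nothing after it proves the group is full. That appears to be why `fill_buffer` keeps pulling samples until enough complete groups exist.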


@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
if dist_config.get("cp_size", 1) > 1:
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(


@@ -13,22 +13,46 @@
# limitations under the License.
import random
import numpy as np
import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.helper import is_torch_npu_available
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
def enable_full_determinism(seed: int) -> None:
"""Enable full deterministic mode for reproducible distributed training."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
if is_torch_npu_available():
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
def set_seed(seed: int, full_determinism: bool = False) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
full_determinism: Whether to enable full deterministic mode.
"""
hf_set_seed(seed)
if full_determinism:
enable_full_determinism(seed)
else:
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool:


@@ -33,6 +33,10 @@ class StatefulBuffer:
def size(self) -> int:
return self._buffer_size
@property
def samples(self) -> list[ModelInput]:
return self._buffer
def put(self, samples: list[ModelInput]) -> None:
"""Add samples to the buffer."""
num_tokens = sum(len(sample["input_ids"]) for sample in samples)


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from enum import StrEnum, unique
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
@@ -54,6 +54,13 @@ else:
ProcessGroup = None
@unique
class AttentionFunction(StrEnum):
EAGER = "eager"
SDPA = "sdpa"
FLASH_ATTENTION_2 = "flash_attention_2"
class DatasetInfo(TypedDict, total=False):
path: str
"""Local file path."""
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
"""Number of micro batches."""
cutoff_len: int
"""Cutoff length."""
data_iter: Iterator[list[ModelInput]]
"""Data iterator."""
class ModelOutput(NamedTuple):


@@ -0,0 +1,149 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.model.model_utils.embedding import (
_description_based_initialization,
_existing_embeddings,
_noisy_mean_initialization,
_resolve_new_token_ids,
)
class _StubTokenizer:
"""Minimal tokenizer stub mapping token strings to fixed IDs."""
unk_token_id = 0
def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None):
self._mapping = mapping
self._desc_ids = desc_ids or []
def convert_tokens_to_ids(self, token: str) -> int:
return self._mapping.get(token, self.unk_token_id)
def __call__(self, desc, return_tensors=None, add_special_tokens=False):
return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)}
class _StubModel:
"""Wraps an embedding matrix so ``get_input_embeddings()`` is a usable lookup."""
def __init__(self, embed_weight: "torch.Tensor"):
self._emb = torch.nn.Embedding.from_pretrained(embed_weight.clone(), freeze=True)
def get_input_embeddings(self):
return self._emb
def test_resolve_new_token_ids_returns_none_without_config():
tokenizer = _StubTokenizer({})
assert _resolve_new_token_ids(None, tokenizer, embed_size=100) is None
assert _resolve_new_token_ids([], tokenizer, embed_size=100) is None
def test_resolve_new_token_ids_filters_invalid_and_dedups():
# "<a>" valid, "<unk_like>" maps to unk_token_id (skipped), "<oob>" out of range (skipped)
tokenizer = _StubTokenizer({"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5})
# duplicates and unsorted input -> sorted unique in-range IDs
tokens = ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"]
assert _resolve_new_token_ids(tokens, tokenizer, embed_size=100) == [5, 10]
# passing a dict iterates its keys (config compatibility)
assert _resolve_new_token_ids({"<a>": "desc"}, tokenizer, embed_size=100) == [10]
def test_existing_embeddings_excludes_new_token_ids():
embed_weight = torch.arange(10 * 2, dtype=torch.float32).reshape(10, 2)
# explicit ids take precedence and drop exactly those rows
existing = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5])
assert existing.size(0) == 8
# tail fallback when no explicit ids
tail = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None)
assert torch.allclose(tail, embed_weight[:-3])
# no resize and no ids -> use everything
everything = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None)
assert torch.allclose(everything, embed_weight)
def test_noisy_mean_initialization_with_token_ids_targets_exact_rows():
"""New tokens placed by explicit IDs must hit those rows, even inside the padding zone."""
torch.manual_seed(0)
vocab_size, embedding_dim = 20, 8
embed_weight = torch.zeros(vocab_size, embedding_dim)
# existing rows carry a constant so the mean is well-defined and non-zero
embed_weight[:16] = 1.0
# num_new_tokens reflects the embedding resize delta (4 rows, including padding),
# but the real new tokens sit only at IDs 16 and 17; a tail slice of 4 rows would also overwrite padding rows 18 and 19.
target_ids = [16, 17]
_noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids)
# targeted rows are initialized around the mean (~1.0) and not left at zero
for tid in target_ids:
assert not torch.allclose(embed_weight[tid], torch.zeros(embedding_dim))
assert abs(embed_weight[tid].mean().item() - 1.0) < 0.5
# untouched padding rows (18, 19) must remain zero
assert torch.allclose(embed_weight[18], torch.zeros(embedding_dim))
assert torch.allclose(embed_weight[19], torch.zeros(embedding_dim))
def test_noisy_mean_initialization_tail_fallback():
"""Without token_ids, falls back to the last num_new_tokens rows."""
torch.manual_seed(0)
vocab_size, embedding_dim = 12, 8
embed_weight = torch.zeros(vocab_size, embedding_dim)
embed_weight[:10] = 1.0
_noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None)
# last two rows initialized, earlier rows untouched
assert not torch.allclose(embed_weight[-1], torch.zeros(embedding_dim))
assert not torch.allclose(embed_weight[-2], torch.zeros(embedding_dim))
assert torch.allclose(embed_weight[0], torch.ones(embedding_dim))
def test_description_init_excludes_new_token_ids_from_average():
"""Description tokens that are themselves new (uninitialized) must be excluded.
Reproduces the padding-zone bug: id 17 is a new token and must not pollute the
semantic average for id 16; only the valid existing token (id 5) should be used.
"""
vocab_size, embedding_dim = 20, 4
embed_weight = torch.zeros(vocab_size, embedding_dim)
embed_weight[5] = 3.0 # the only valid description token
# description for "<x>" tokenizes to [5 (existing), 17 (new -> must be skipped)]
tokenizer = _StubTokenizer({"<x>": 16}, desc_ids=[5, 17])
model = _StubModel(embed_weight)
_description_based_initialization(
embed_weight,
num_new_tokens=4,
descriptions={"<x>": "ignored, ids come from the stub"},
tokenizer=tokenizer,
model=model,
new_token_ids=[16, 17],
add_noise=False,
)
# row 16 must equal embedding of id 5 only (3.0), not the (5,17) average (1.5)
assert torch.allclose(embed_weight[16], torch.full((embedding_dim,), 3.0))
if __name__ == "__main__":
import pytest
pytest.main([__file__])
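
The filtering behaviour these tests pin down for `_resolve_new_token_ids` can be sketched in plain Python. This is only an illustration under stated assumptions, not the project's implementation: `resolve_new_token_ids_sketch`, its `vocab` dict and its explicit `unk_token_id` argument are hypothetical stand-ins for the real function and tokenizer, and what the real function returns when every token is filtered out is not shown in the tests.

```python
def resolve_new_token_ids_sketch(tokens, vocab, unk_token_id, embed_size):
    """Hypothetical helper: map token strings to sorted, unique, usable row IDs."""
    if not tokens:
        # None or an empty container means no new tokens were configured.
        return None
    ids = set()
    for token in tokens:  # iterating a dict yields its keys
        token_id = vocab.get(token, unk_token_id)
        if token_id == unk_token_id:
            continue  # the tokenizer does not actually know this token
        if not 0 <= token_id < embed_size:
            continue  # the ID has no row in the embedding matrix
        ids.add(token_id)
    return sorted(ids)


# Same fixture as the test above: "<unk_like>" and "<oob>" are dropped.
vocab = {"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5}
new_ids = resolve_new_token_ids_sketch(
    ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"], vocab, unk_token_id=0, embed_size=100
)
```

Returning explicit IDs rather than a count is what lets the initializers write to the exact rows of the new tokens when the embedding matrix has been padded past the tokenizer's vocabulary.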


@@ -58,10 +58,3 @@ def test_multi_device():
master_port = find_available_port()
world_size = 2
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
if __name__ == "__main__":
"""
python tests_v1/accelerator/test_interface.py
"""
test_all_device()


@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
assert training_args.bf16 is False
assert training_args.dist_config is None
assert sample_args.sample_backend == "hf"
if __name__ == "__main__":
"""
python -m tests_v1.config.test_args_parser
"""
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
test_get_args_from_yaml(tmp_path=Path(tmp_dir))


@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
for index in indexes:
print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
"""
python -m tests_v1.core.test_data_engine
"""
test_map_dataset(1)


@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":
"""
python -m tests_v1.core.test_model_loader
"""
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()


@@ -16,6 +16,164 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import (
BatchingPlugin,
_get_dynamic_micro_batch_sizes,
_get_dynamic_padding_free_micro_batch_sizes,
)
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
def _make_model_input(length: int, start: int = 0):
input_ids = list(range(start, start + length))
return {
"input_ids": input_ids,
"attention_mask": [1] * length,
"labels": input_ids.copy(),
"loss_weights": [1.0] * length,
}
class _RestartableDataProvider:
def __init__(self, batches):
self.batches = batches
self.num_iters = 0
def __iter__(self):
self.num_iters += 1
return iter(self.batches)
def test_padding_free():
buffer = StatefulBuffer()
# Input samples:
# sample 0 input_ids: [0, 1]
# sample 1 input_ids: [10, 11, 12, 13]
buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}
batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)
# Output batch:
# sample 1 is truncated to [10, 11, 12]
# both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 5)
assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
assert len(buffer) == 0
def test_batching_plugin_data_provider_batch_sizes():
batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 3,
"cutoff_len": 10,
}
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
# Input samples:
# sample lengths: [3, 4, 6, 2, 8, 9]
# input_ids:
# [0, 1, 2]
# [10, 11, 12, 13]
# [20, 21, 22, 23, 24, 25]
# [30, 31]
# [40, 41, 42, 43, 44, 45, 46, 47]
# [50, 51, 52, 53, 54, 55, 56, 57, 58]
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Dynamic batching output plan:
# dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
# as the padded-token budget for one training micro batch.
# [3, 4, 6] fits within budget 20 as shape [3, 6] (18 padded tokens);
# adding [2] would give shape [4, 6] (24 padded tokens) and exceed it.
assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
assert len(buffer) == 3
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
# Input buffer:
# only one sample with length [6].
# cutoff_len * micro_batch_size gives a padded-token budget of 20.
# this buffer has not filled the budget and has no next sample to prove overflow,
# so dynamic batching cannot produce a batch yet.
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
# Input data provider:
# each iterator pass yields one sample with length [6].
# each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
# _fill_buffer keeps restarting the iterator until the next appended sample
# proves that the previous dynamic micro batch has reached its budget boundary.
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_batching"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
batch_generator._fill_buffer()
# Filled buffer after restart:
# existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
# because adding it to the first dynamic micro batch would exceed the budget.
assert data_provider.num_iters == 4
assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic batching returns [micro_batch_0]
# micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert len(batch_generator._buffer) == 1
def test_normal_batching():
@@ -45,8 +203,166 @@ def test_normal_batching():
assert batch[0]["input_ids"].shape == (4, 10)
def test_dynamic_padding_free():
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
# input_ids breakdown:
# sample 0: [0,1,2] (length=3)
# sample 1: [10,11,12,13] (length=4)
# sample 2: [20,21,22,23,24,25] (length=6)
# sample 3: [30,31] (length=2)
# sample 4: [40-47] (length=8)
# sample 5: [50-58] (length=9)
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1 # num_micro_batch=1
packed_batch = batch[0]
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
assert packed_batch["input_ids"].shape == (1, 15)
# Verify input_ids concatenation (input_ids are packed unchanged; only labels are masked)
assert packed_batch["input_ids"].tolist() == [
[
0,
1,
2, # Sample 0
10,
11,
12,
13, # Sample 1
20,
21,
22,
23,
24,
25, # Sample 2
30,
31,
] # Sample 3
]
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
assert packed_batch["labels"].tolist() == [
[
0,
1,
2, # Sample 0
IGNORE_INDEX,
11,
12,
13, # Sample 1
IGNORE_INDEX,
21,
22,
23,
24,
25, # Sample 2
IGNORE_INDEX,
31,
] # Sample 3
]
# Verify attention_mask
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
# Verify position_ids
assert packed_batch["position_ids"].tolist() == [
[
0,
1,
2, # Sample 0
0,
1,
2,
3, # Sample 1
0,
1,
2,
3,
4,
5, # Sample 2
0,
1,
] # Sample 3
]
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
assert len(buffer) == 2
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
Data provider yields one sample of length 6 per iteration.
_fill_buffer keeps restarting iterator until next sample exceeds budget.
Budget = 2 * 10 = 20 tokens.
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
So the buffer holds 4 samples after _fill_buffer.
"""
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_padding_free"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
batch_generator._fill_buffer()
# Buffer after restarts:
# 3 samples can fit (18 tokens)
# 4th sample is kept in buffer for next batch
# => num_iters = 4
assert data_provider.num_iters == 4
assert _get_dynamic_padding_free_micro_batch_sizes(
batch_generator._buffer.samples, batch_generator._batch_info
) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic_padding_free returns [micro_batch_0]
# 3 samples packed into shape [1, 18]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 18)
assert len(batch_generator._buffer) == 1
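
The budget arithmetic and packing rules exercised by the batching tests above can be sketched as follows. This is a simplified illustration, not the plugin's code: `plan_micro_batch` and `pack_samples` are hypothetical names, `IGNORE_INDEX = -100` is assumed (the tests import the real value from `llamafactory.v1.utils.constants`), and the sketch closes a micro batch only when the next sample overflows the budget.

```python
IGNORE_INDEX = -100  # assumption: the real constant is imported in the tests


def plan_micro_batch(lengths, budget, padded):
    """Return how many leading samples form one micro batch, or None if undecided.

    padded=True  : cost is count * max_len (every row is padded to the longest).
    padded=False : cost is the plain token sum (padding-free packing).
    """
    max_len = 0
    total = 0
    for count, length in enumerate(lengths):
        max_len = max(max_len, length)
        total += length
        cost = (count + 1) * max_len if padded else total
        if cost > budget:
            # This sample proves overflow, so the batch ends just before it.
            # max(..., 1) keeps progress if a single sample exceeds the budget.
            return max(count, 1)
    return None  # no overflow seen yet: wait for more samples


def pack_samples(samples):
    """Concatenate samples into one sequence without padding."""
    input_ids, labels, position_ids = [], [], []
    for index, ids in enumerate(samples):
        input_ids.extend(ids)
        sample_labels = list(ids)
        if index > 0 and sample_labels:
            # Do not train on predicting a sample's first token from the previous sample.
            sample_labels[0] = IGNORE_INDEX
        labels.extend(sample_labels)
        position_ids.extend(range(len(ids)))  # positions restart for each sample
    return {"input_ids": input_ids, "labels": labels, "position_ids": position_ids}


lengths = [3, 4, 6, 2, 8, 9]
budget = 10 * 2  # cutoff_len * micro_batch_size
padded_count = plan_micro_batch(lengths, budget, padded=True)
packed_count = plan_micro_batch(lengths, budget, padded=False)
packed = pack_samples([[0, 1], [10, 11, 12]])
```

With the same six sample lengths, the padded plan stops earlier than the packed plan because padding to the longest row consumes budget that packing spends on real tokens.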


@@ -227,17 +227,3 @@ def test_process_dpo_samples():
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_rendering
"""
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()


@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
],
}
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
if __name__ == "__main__":
"""
python -m tests_v1.plugins.data_plugins.test_converter
"""
test_alpaca_converter(1)
test_sharegpt_converter()
test_pair_converter(1)


@@ -52,12 +52,3 @@ def test_init_on_default():
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
"""
python tests_v1/plugins/model_plugins/test_init_plugin.py
"""
test_init_on_meta()
test_init_on_rank0()
test_init_on_default()


@@ -35,10 +35,3 @@ def test_sync_sampler():
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
"""
python tests_v1/sampler/test_cli_sampler.py
"""
test_sync_sampler()