42 Commits

Author SHA1 Message Date
jiaqiw09
8669a22e9c [fix] fix liger kernel patch for npu (#10583) 2026-06-16 18:21:52 +08:00
Hao Liang
897a44386c [docs] add DataFlow and DataFlex blog tutorials (#10582)
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-16 14:20:36 +08:00
jiaqiw09
7a1e9630f2 [fix] update ascend doc link (#10572) 2026-06-15 13:55:53 +08:00
souljoy
cabe59a343 [model] add MiniCPM5-1B-Chat (#10558) 2026-06-10 16:18:27 +08:00
Co-Cl2
9ca4026efe [model] handle unsloth model loading fallback during checkpoint resume (#7156) (#10551) 2026-06-09 01:01:01 +08:00
Ximing Xing
0b7aaf8f6a [fix] correctly place new token embeddings when embedding is padded (#10547) 2026-06-05 10:47:51 +08:00
codingma
8a4f6a3da5 [model] add gemma-4-12B-it (#10549) 2026-06-04 23:43:20 +08:00
A1waysBeenHere
409e8a477f [model] Patch GDN for NPU (#10504)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
2026-06-04 16:39:02 +08:00
Cui-yshoho
053d43c0ac [feat] support HyperParallel PT training and activation optimization (#10370) 2026-06-02 22:39:32 +08:00
Zhao73
a98a1ef101 [docs] fix README citation typo (#10540) 2026-06-01 21:04:53 +08:00
Yaowei Zheng
8ef7335b6a [misc] set dev version (#10533) 2026-05-31 00:16:07 +08:00
Yaowei Zheng
7af909522a [version] release v0.9.5 (#10532) 2026-05-30 23:57:09 +08:00
xvxuopop
e016d2480e [fix] Fix NPU FusedMoE and RMSNorm (#10512) 2026-05-30 21:42:54 +08:00
jiaqiw09
7d719182c9 [model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529) 2026-05-30 21:41:41 +08:00
jiaqiw09
01398eb18d [v1] fix padding free with sp (#10513) 2026-05-26 23:49:21 +08:00
cxy
8e68764b65 [v1] Implement dynamic padding-free stretrgy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
2026-05-25 20:40:21 +08:00
Copilot
16ff5a23cb [fix] use getattr for profiler attrs to support MCA TrainingArguments (#10506)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2026-05-21 17:26:29 +08:00
jiaqiw09
bdcb92d035 [v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469) 2026-05-21 17:14:19 +08:00
sunyi0505
7e20db5735 [v1] support liger_kernel (#10493) 2026-05-21 11:44:56 +08:00
浮梦
2322bf1cc2 [v1] add cuda fused moe kernel, implementing with triton (#10481) 2026-05-20 20:49:42 +08:00
浮梦
368c48968f [callback] add torch profiler callback (#10463) 2026-05-20 20:47:52 +08:00
浮梦
8b5ea65770 [v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-05-20 20:46:52 +08:00
Dennis Huang
40e786d016 [data] add missing return statement in MiniCPM V Plugin (#10500)
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-20 01:50:00 +08:00
xvxuopop
6b9df75ab9 [docker] update npu docker (#10479) 2026-05-13 20:56:43 +08:00
马境远
ca50f22c38 [fix] Fix MiniCPM-V-4.6 image preprocessing behavior (#10478) 2026-05-12 11:35:23 +08:00
马境远
53e77a9bfa [model] support MiniCPM-V-4.6 (#10472) 2026-05-08 18:14:34 +08:00
浮梦
55bd4944b6 [fix] fix qwen3_6 template doc (#10470) 2026-05-08 11:47:02 +08:00
Tai An
7e09152275 fix(data/converter): handle None tool_calls in OpenAI-style messages (#10455) 2026-05-07 17:44:41 +08:00
simulikeit
1e503a982d [assets] correct typo in examples/README_zh.md (#10462) 2026-05-07 00:42:01 +08:00
luca-888
8752280dd7 [data] Optimize QwenVL video dataset preprocessing (#10404)
Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
2026-05-03 18:36:56 +08:00
Kingsley
468723c5d9 [packing] fix GDN crash when meeting dummy image (#10453) 2026-05-01 12:10:13 +08:00
Peilin Li
887ee2b121 [refactor] Add KTransformers AMX MoE SFT support via Accelerate (#10430)
Co-authored-by: mrhaoxx <mr.haoxx@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-01 01:47:58 +08:00
Kingsley
6b08b948c9 [misc] bump transformers version upperbound (#10446) 2026-05-01 01:30:11 +08:00
Hertz
f7f3bfcbd7 [model] support Hy3-Preview (#10432) 2026-04-29 23:21:13 +08:00
Kingsley
3475198d1e [fa2] fix IMA when train qwen3_5 (#10448) 2026-04-29 20:20:55 +08:00
sunyi0505
50945ef850 [v1] fix device_mesh and sp for fsdp2 (#10429) 2026-04-28 11:20:11 +08:00
Octopus
2f0bef207a [export] handle NotImplementedError in export_model for transformers>=5.0 (fixes #10410) (#10438)
Co-authored-by: octo-patch <octo-patch@github.com>
2026-04-27 23:36:23 +08:00
curnane-lab
2092abc217 [npu] add Qwen3.5 support with Partial RoPE and Hybrid Attention (#10421)
Co-authored-by: Curnane <mingliangfu@users.noreply.github.com>
2026-04-27 23:36:07 +08:00
Kingsley
99464b3d03 [misc] code lint (#10439)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-27 14:07:31 +08:00
jiaqiw09
9a0cfdccfa [v1] fix init on meta in transformers v5 (#10414) 2026-04-27 00:37:09 +08:00
Kingsley
c8890c32db [data] support discard history cot for multiturn (#10435) 2026-04-27 00:32:44 +08:00
Kingsley
79c8332e4c [train] add qwen35 patch for neat_packing (#10436) 2026-04-27 00:31:49 +08:00
138 changed files with 8252 additions and 2473 deletions

View File

@@ -109,7 +109,7 @@ jobs:
platforms: linux/amd64,linux/arm64
file: ./docker/docker-npu/Dockerfile
build-args: |
BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
BASE_IMAGE=quay.io/ascend/cann:9.0.0-a3-ubuntu22.04-py3.11
push: ${{ github.event_name != 'pull_request' }}
tags: |
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3

View File

@@ -38,7 +38,7 @@ jobs:
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
container:
image: ascendai/cann:8.3.rc2-910b-ubuntu22.04-py3.11
image: ascendai/cann:9.0.0-910b-ubuntu22.04-py3.11
env:
HF_ENDPOINT: https://hf-mirror.com
HF_TOKEN: ${{ secrets.HF_TOKEN }}
@@ -65,8 +65,8 @@ jobs:
- name: Install dependencies
run: |
uv venv
uv pip install -r requirements/npu.txt
uv pip install -e .
uv pip install -r requirements/npu.txt
uv pip install -r requirements/dev.txt
- name: Install node
@@ -89,5 +89,7 @@ jobs:
make build
- name: Test with pytest
shell: bash
run: |
source /usr/local/Ascend/ascend-toolkit/set_env.sh
make test

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
.ai/CLAUDE.md

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
\[ English | [中文](README_zh.md) \]
@@ -52,14 +50,12 @@ Start local training:
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Documentation (ASCEND NPU)**: https://llamafactory.readthedocs.io/en/latest/multibackend/npu/index.html
- **Official Blog**: https://blog.llamafactory.net/en/
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -78,7 +74,6 @@ Read technical notes:
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [LLaMA Factory Online](#llama-factory-online)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
@@ -117,15 +112,13 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- 💡 [DataFlow × LLaMA Factory: Producing High-Quality Data for LLM Training with a Data Preparation Pipeline](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)
- 💡 [DataFlex × LLaMA Factory: A Data-Centric Dynamic Training System Built on LLaMA-Factory](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -320,7 +313,7 @@ Read technical notes:
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5/qwen3_5_nothink |
| [Qwen3.6](https://huggingface.co/Qwen) | 35B | qwen3_6/qwen3_6_nothink |
| [Qwen3.6](https://huggingface.co/Qwen) | 27B/35B | qwen3_6 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -661,10 +654,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
llamafactory-cli webui
```
### LLaMA Factory Online
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:
@@ -838,7 +827,7 @@ If you have a project that should be incorporated, please contact via email or c
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -52,16 +50,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
开始云端训练:
- **Colab免费**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW免费试用**https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online在线微调**https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **九章智算云(算力优惠活动)**https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
阅读技术文档:
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
- **微调视频教程**https://www.bilibili.com/video/BV1djgRzxEts/
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
- **框架文档(昇腾 NPU**https://ascend.github.io/docs/sources/llamafactory/
- **框架文档(昇腾 NPU**https://llamafactory.readthedocs.io/zh-cn/latest/multibackend/npu/index.html
- **官方博客**https://blog.llamafactory.net/
- **官方课程**https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [数据准备](#数据准备)
- [快速开始](#快速开始)
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
- [构建 Docker](#构建-docker)
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
- [从魔搭社区下载](#从魔搭社区下载)
@@ -119,15 +113,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
- 💡 [DataFlow × LLaMA Factory: 利用数据准备流水线产出高质量数据训练 LLM](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)(中文)| [English](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW)
- 💡 [DataFlex × LLaMA Factory: 构建在 LLaMA-Factory 之上的以数据为中心的动态训练系统](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)(中文)| [English](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv)
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
<details><summary>全部博客</summary>
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
- [LLaMA Factory微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
@@ -322,7 +314,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5/qwen3_5_nothink |
| [Qwen3.6](https://huggingface.co/Qwen) | 35B | qwen3_6/qwen3_6_nothink |
| [Qwen3.6](https://huggingface.co/Qwen) | 27B/35B | qwen3_6 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -662,10 +654,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
llamafactory-cli webui
```
### LLaMA Factory Online 在线微调
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
### 构建 Docker
CUDA 用户:
@@ -842,7 +830,7 @@ swanlab_run_name: test_run # 可选
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)

View File

@@ -1,6 +1,6 @@
# https://hub.docker.com/r/ascendai/cann/tags
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
ARG BASE_IMAGE=quay.io/ascend/cann:9.0.0-910b-ubuntu22.04-py3.11
FROM ${BASE_IMAGE}
# Installation arguments
@@ -36,6 +36,7 @@ COPY . /app
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
RUN pip uninstall -y torch torchvision torchaudio
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/triton_ascend.txt
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

View File

@@ -33,7 +33,7 @@ services:
dockerfile: ./docker/docker-npu/Dockerfile
context: ../..
args:
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
BASE_IMAGE: quay.io/ascend/cann:9.0.0-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory-a3
image: llamafactory:npu-a3

View File

@@ -96,7 +96,7 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
### 支持弹性和容错的多机指令监督微调
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID由参与该作业的所有节点共享。更多可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
要启动一个支持弹性节点和容错的多机指令微调,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个节点最多允许因为错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID由参与该作业的所有节点共享。更多细节可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/qwen3_full_sft.yaml

View File

@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer,Qwen3_5VisionBlock
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,47 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config_qwen35.yaml \
# src/train.py examples/ascend/qwen3_5_full_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35.yaml to match your NPU count
### model
model_name_or_path: Qwen/Qwen3.5-4B
trust_remote_code: true
use_v1_kernels: true
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: full
### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Qwen3.5-4B/full/sft
logging_steps: 1
save_steps: 500
max_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null

View File

@@ -0,0 +1,51 @@
# Start FSDP2 full fine-tuning on Ascend NPU
# Usage:
# accelerate launch \
# --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \
# src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
#
# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count
### model
model_name_or_path: Qwen/Qwen3.5-35B-A3B
trust_remote_code: true
use_v1_kernels: false
flash_attn: fa2
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: alpaca_en_demo
template: qwen3_5_nothink
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
packing: false
### output
output_dir: saves/Qwen3.5-35B/lora/sft
logging_steps: 1
save_steps: 2000
max_steps: 2000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 1800
resume_from_checkpoint: null
disable_gradient_checkpointing: true

View File

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXBF16 # Use with original BF16 expert weights.
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

View File

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT4 # Use with online-converted INT4 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

View File

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Adjust based on your GPU count; 4 is suitable for 4 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

View File

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 1 # Adjust based on your GPU count; 1 is suitable for 1 GPU
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

View File

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Adjust based on your GPU count; 8 is suitable for 8 GPUs
rdzv_backend: static
same_network: true
use_cpu: false
kt_config:
enabled: true
kt_backend: AMXINT8 # Use with online-converted INT8 expert weights
kt_num_threads: 96
kt_tp_enabled: true
kt_threadpool_count: 2
kt_max_cache_depth: 2
kt_share_backward_bb: true
lora_rank: 8

View File

@@ -1,10 +0,0 @@
model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
adapter_name_or_path: saves/Kllama_deepseekV2
template: deepseek
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -1,9 +0,0 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
template: deepseek3
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -1,10 +0,0 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
adapter_name_or_path: saves/Kllama_deepseekV3
template: deepseek3
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -1,10 +0,0 @@
model_name_or_path: Qwen/Qwen3-235B-A22B-Instruct-2507
adapter_name_or_path: saves/Kllama_Qwen3MoE_235bA22b
template: qwen3_nothink
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -1,69 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,68 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,139 +0,0 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9])\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
transfer_map:
10: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

View File

@@ -1,69 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,68 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cpu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,68 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,77 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,392 +0,0 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
# === Rotary Embedding Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.(?!self_attn\\.kv_b_proj).*$"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# === MLP (MoE) Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Gate Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
# === MLP Experts Replacement ===
# replace with marlin expert. Open and modify layer-num as needed.
# Each layer of malin experts takes about 6GB of GPU memory.
# !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
# GPU 0: layers 34
# - match:
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:0"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 1: layers 1517
# - match:
# name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:1"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 2: layers 3032
# - match:
# name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:2"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 3: layers 4546
# - match:
# name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:3"
# generate_op: "KExpertsMarlin"
# recursive: False
# === MLP Experts Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:2"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:2"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts
kwargs:
prefill_device: "cuda:3"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:3"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False
# === Self-Attention Replacement ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
absorb_for_prefill: False
# GPU 3: layers 4560
- match:
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
absorb_for_prefill: False
# === Overall Model Replacement with Transfer Map ===
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 means close layerwise prefill
transfer_map:
15: "cuda:1" # Layers 15+ on GPU 1
30: "cuda:2" # Layers 30+ on GPU 2
45: "cuda:3" # Layers 45+ on GPU 3
# === Default Catch-All for Other Modules ===
# GPU 0: layers 014
- match:
name: "^model\\.layers\\.([0-9]|1[0-4])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
# GPU 1: layers 1529
- match:
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
# GPU 2: layers 3044
- match:
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
- match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"

View File

@@ -1,156 +0,0 @@
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:0"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda:1"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

View File

@@ -1,77 +0,0 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View File

@@ -1,80 +0,0 @@
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# - match:
# name: "^model\\.layers\\..*$" # regular expression
# class: torch.nn.Linear # only match modules matching name and class simultaneously
# replace:
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
# kwargs:
# generate_device: "cuda"
# prefill_device: "cuda"
# generate_op: "KLinearTorch"
# prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
replace:
class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KSFTExpertsCPU"
out_device: "cuda"
backend: "AMXInt8" # or "AMXBF16" or "AMXInt8"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KQwen3MoeAttention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KQwen3MoeModel"
kwargs:
per_layer_prefill_intput_threshold: 0

View File

@@ -19,7 +19,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_deepseekV2
output_dir: saves/KT_FT_deepseekV2
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
use_kt: true
# Pair with fsdp2_kt_bf16.yaml for original BF16 checkpoints.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# kt_weight_path: /path/to/DeepSeek-V2-Lite-AMXINT8

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
model_name_or_path: deepseek-ai/DeepSeek-V3-0324-BF16 # need to convert to BF16 checkpoint first
trust_remote_code: true
### method
@@ -19,7 +19,7 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_deepseekV3
output_dir: saves/KT_FT_deepseekV3
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
use_kt: true
# Pair with fsdp2_kt_bf16.yaml for original BF16 checkpoints.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# kt_weight_path: /path/to/DeepSeek-V3-AMXINT8

View File

@@ -0,0 +1,46 @@
### model
model_name_or_path: Qwen/Qwen3.5-397B-A17B
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity, alpaca_en_demo
template: qwen3_5
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/KT_FT_qwen35Moe
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true
# For original BF16 checkpoints, start with examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# Pair the 397B path with fsdp2_kt_int8.yaml, tune cutoff_len to prepared weights and GPU memory.
# kt_weight_path: /path/to/Qwen3.5-MoE-AMXINT8

View File

@@ -11,7 +11,7 @@ lora_target: all
### dataset
dataset: identity, alpaca_en_demo
template: qwen3_nothink
template: qwen3
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
@@ -19,9 +19,9 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/Kllama_Qwen3MoE_235bA22b
output_dir: saves/KT_FT_qwen3Moe
logging_steps: 10
save_steps: 200
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
@@ -31,7 +31,7 @@ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
@@ -39,14 +39,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
use_kt: true
# Pair with examples/ktransformers/accelerate/fsdp2_kt_bf16.yaml for original BF16 checkpoints.
# For pre-converted expert weights, uncomment kt_weight_path and use fsdp2_kt_int8.yaml or fsdp2_kt_int4.yaml.
# kt_weight_path: /path/to/Qwen3-235B-A22B-Instruct-2507-AMXINT8

View File

@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: liger_kernel
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -40,7 +40,7 @@ dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
"transformers>=4.55.0,<=5.6.0,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.18.0,<=0.18.1",

View File

@@ -0,0 +1 @@
ktransformers[sft]

View File

@@ -1,4 +1,5 @@
torch==2.7.1
torch-npu==2.7.1.post2
torch-npu==2.7.1.post4
torchvision==0.22.1
torchaudio==2.7.1
decorator

View File

@@ -0,0 +1,2 @@
--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple
triton-ascend==3.2.1

View File

@@ -71,16 +71,6 @@ class ChatModel:
"SGLang not install, you may need to run `pip install sglang[all]`\n"
"or try to use HuggingFace backend: --infer_backend huggingface"
) from e
elif model_args.infer_backend == EngineName.KT:
try:
from .kt_engine import KTransformersEngine
self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
except ImportError as e:
raise ImportError(
"KTransformers not install, you may need to run `pip install ktransformers`\n"
"or try to use HuggingFace backend: --infer_backend huggingface"
) from e
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

View File

@@ -1,284 +0,0 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional
import torch
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from trl import PreTrainedModelWrapper
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
get_compute_capability,
prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager
logger = logging.get_logger(__name__)
class KTransformersEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.KT
self.can_generate = finetuning_args.stage == "sft"
tok_mod = load_tokenizer(model_args)
self.tokenizer = tok_mod["tokenizer"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
self.max_new_tokens = model_args.kt_maxlen
self.use_cuda_graph = model_args.kt_use_cuda_graph
self.mode = model_args.kt_mode
self.force_think = model_args.kt_force_think
self.chunk_size = model_args.chunk_size
try:
asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: list[str],
input_kwargs: Optional[dict[str, Any]] = {},
) -> list[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
async def _generate(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
paired = messages + [{"role": "assistant", "content": ""}]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
prompt_len = len(prompt_ids)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
if "max_new_tokens" in self.generating_args:
max_tokens = int(self.generating_args["max_new_tokens"])
elif "max_length" in self.generating_args:
gl = int(self.generating_args["max_length"])
max_tokens = gl - prompt_len if gl > prompt_len else 1
else:
max_tokens = self.max_new_tokens or 256
if max_length is not None:
max_tokens = max(max_length - prompt_len, 1)
if max_new_tokens is not None:
max_tokens = int(max_new_tokens)
max_tokens = max(1, int(max_tokens))
if self.mode == "long_context":
max_len_cfg = Config().long_context_config["max_seq_len"]
need = prompt_len + max_tokens
assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
device = next(self.model.parameters()).device
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
if self.force_think:
think = torch.tensor(
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
)
input_tensor = torch.cat([input_tensor, think], dim=1)
use_flashinfer = (
platform.system() != "Windows"
and getattr(self.model.config, "architectures", [""])[0]
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and flashinfer_enabled
and get_compute_capability() >= 8
and device_manager.gpu_vendor == GPUVendor.NVIDIA
)
def make_gen():
if use_flashinfer:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
use_flashinfer_mla=True,
num_heads=self.model.config.num_attention_heads,
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+ getattr(self.model.config, "qk_nope_head_dim", 0),
echo_stream=False,
)
else:
return prefill_and_generate_capture(
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
echo_stream=False,
)
loop = asyncio.get_running_loop()
q: asyncio.Queue[Optional[str]] = asyncio.Queue()
def producer():
try:
gen = make_gen()
if hasattr(gen, "__aiter__"):
async def drain_async():
async for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
asyncio.run(drain_async())
elif hasattr(gen, "__iter__"):
for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
else:
loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
finally:
loop.call_soon_threadsafe(q.put_nowait, None)
Thread(target=producer, daemon=True).start()
while True:
item = await q.get()
if item is None:
break
yield item
@override
async def chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
async with self.semaphore:
produced = ""
final_text = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t
produced = produced + delta
if delta:
final_text += delta
prompt_ids, _ = self.template.encode_oneturn(
self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
)
return [
Response(
response_text=final_text,
response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
prompt_length=len(prompt_ids),
finish_reason="stop",
)
]
@override
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
async with self.semaphore:
produced = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t[len(produced) :] if t.startswith(produced) else t
produced = t
if delta:
yield delta
@override
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
return await asyncio.to_thread(self._get_scores, *args)

View File

@@ -157,9 +157,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
else:
self.get_rope_func = None
def _compute_rope_position_ids(
self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
) -> None:
def _compute_rope_position_ids(self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]) -> None:
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
rope_index_kwargs = {
"input_ids": features["input_ids"],
@@ -167,8 +165,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if features["attention_mask"].sum() == 0:
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
if features["attention_mask"].sum() == 0: # for pad tokens
seq_len = features["input_ids"].shape[-1]
features["position_ids"] = (
torch.arange(seq_len).view(1, 1, seq_len).expand(3, *features["input_ids"].shape).contiguous()
)
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
return
@@ -196,9 +197,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
@@ -224,7 +223,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
# avoid continual cuseqlens breaking varlen attention @kuangdd
# https://github.com/hiyouga/LlamaFactory/issues/10452
dummy_image_right_padding_mrope = (
torch.arange(fake_input_padding_length)
.view(1, 1, fake_input_padding_length)
.expand(3, bsz, fake_input_padding_length)
)
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
dummy_mm_inputs = copy.deepcopy(mm_inputs)
@@ -232,14 +237,20 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for sample_idx in range(bsz):
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
sequence_boundaries = sample_packing.get("sequence_boundaries")
num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
num_sub_seqs = (
(len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
)
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
images_per_subseq = (
[image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
[image_subseq_ids.count(i) for i in range(num_sub_seqs)]
if image_subseq_ids and num_sub_seqs > 1
else None
)
videos_per_subseq = (
[video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
[video_subseq_ids.count(i) for i in range(num_sub_seqs)]
if video_subseq_ids and num_sub_seqs > 1
else None
)
if has_dummy_image:
mm_inputs = {}
@@ -263,7 +274,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
subseq_end = sequence_boundaries[subseq_idx + 1]
subseq_features = {
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
"attention_mask": features["attention_mask"][
sample_idx : sample_idx + 1, subseq_start:subseq_end
],
}
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
mm_inputs,
@@ -272,10 +285,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
sample_idx,
images_per_subseq,
videos_per_subseq,
subseq_idx
subseq_idx,
)
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
sample_position_ids.append(subseq_features["position_ids"])
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
@@ -284,16 +298,22 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if has_dummy_image:
mm_inputs = dummy_mm_inputs
expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
all_position_ids[0].size(0),
bsz,
seq_len,
expected_position_ids_shape = (
(bsz, seq_len)
if all_position_ids[0].dim() == 2
else (
all_position_ids[0].size(0),
bsz,
seq_len,
)
)
# Check if position_ids shape matches expected shape.
# for further usage, we should padding to the right when some padding token on the right.
if has_dummy_image:
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
features["attention_mask"] = torch.cat(
[features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1
)
if features["position_ids"].shape != expected_position_ids_shape:
raise ValueError(
@@ -380,7 +400,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4
if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4
mm_token_type_ids = mm_inputs.pop("mm_token_type_ids")
max_len = max(len(ids) for ids in mm_token_type_ids)
padded = []
@@ -405,19 +425,17 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.get_rope_func is not None:
# for mmrope situation, we should calculate position_ids and rope_deltas per sample.
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
boundaries_list = [
p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
]
boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
if has_dummy_image and has_packing:
# FIXME: too tricky, need to be refactored
# FIXME: too tricky, need to be refactored @kuangdd
features["has_dummy_image"] = True
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
if not has_packing:
self._compute_rope_position_ids(features, mm_inputs)
else:
if is_omni:
if is_omni: # TODO: support omni models for packed sequences @kuangdd
raise RuntimeError("Omni models are not supported for packed sequences for now.")
self._compute_rope_position_ids_with_packing(
@@ -471,8 +489,8 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
def __post_init__(self):
super().__post_init__()
if self.neat_packing and self.attn_implementation == "flash_attention_2":
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
if self.model is not None and getattr(self.model.config, "model_type", None) in ["gemma4", "gpt_oss"]:
raise ValueError("Neat packing is not supported for gemma4, gpt_oss models for now.")
@staticmethod
def _unpad_packed_features(features: dict[str, Any]) -> None:
@@ -493,7 +511,9 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
if key == "position_ids" and value.size(-1) == seq_len:
features[key] = value.index_select(-1, non_padding_indices)
elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
elif (
key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len
):
features[key] = value.index_select(1, non_padding_indices)
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
features[key] = value.index_select(1, non_padding_indices)
@@ -504,7 +524,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
if not has_dummy_image:
self._unpad_packed_features(features)

View File

@@ -257,8 +257,8 @@ class OpenAIDatasetConverter(DatasetConverter):
content = message[self.dataset_attr.content_tag]
if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
if "tool_calls" in message and len(message["tool_calls"]) > 0:
tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
if tool_calls := message.get("tool_calls"):
tool_calls_list = [tool["function"] for tool in tool_calls]
content = json.dumps(tool_calls_list, ensure_ascii=False)
role = self.dataset_attr.function_tag

View File

@@ -22,7 +22,8 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np
import torch
@@ -245,6 +246,14 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
def _get_video_token_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
r"""Build metadata used to expand video tokens without decoding frames."""
return None
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
@@ -642,7 +651,12 @@ class Gemma4Plugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
return {
"videos": results,
"fps_per_video": fps_per_video,
"durations": durations,
"frames_indices": frames_indices,
}
@override
def _get_mm_inputs(
@@ -674,8 +688,15 @@ class Gemma4Plugin(BasePlugin):
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
{
"fps": getattr(processor, "video_fps", 2.0),
"duration": duration,
"total_num_frames": len(video),
"frames_indices": sample_indices,
}
for video, duration, sample_indices in zip(
video_data["videos"], video_data["durations"], video_data["frames_indices"]
)
]
mm_inputs.update(
video_processor(
@@ -687,7 +708,7 @@ class Gemma4Plugin(BasePlugin):
)
)
if len(audios) != 0: # only for gemma4n
if len(audios) != 0: # only for gemma4n
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
@@ -695,11 +716,11 @@ class Gemma4Plugin(BasePlugin):
mm_inputs.update(
feature_extractor(
audios,
padding="max_length",
return_tensors="pt",
audios,
padding="max_length",
return_tensors="pt",
)
)
)
return mm_inputs
@@ -751,7 +772,10 @@ class Gemma4Plugin(BasePlugin):
num_soft_tokens_per_frame, metadata = next(video_iter)
if self.expand_mm_tokens:
timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
frame_strs = [
f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
for ts in timestamp_strs
]
video_str = " ".join(frame_strs)
else:
video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
@@ -760,7 +784,9 @@ class Gemma4Plugin(BasePlugin):
while AUDIO_PLACEHOLDER in content:
current_audio = next(audio_iter)
if self.expand_mm_tokens:
num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
num_audio_tokens = processor._compute_audio_num_tokens(
current_audio, processor.feature_extractor.sampling_rate
)
audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
else:
audio_str = f"{boa_token}{audio_token}{eoa_token}"
@@ -786,8 +812,14 @@ class Gemma4Plugin(BasePlugin):
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
# Pop metadata keys that must not be passed to the model.
for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
"_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
for key in (
"num_soft_tokens_per_image",
"num_soft_tokens_per_video",
"video_metadata",
"_gemma4_fps_per_video",
"_gemma4_frames_indices",
"_gemma4_num_audio_soft_tokens",
):
mm_inputs.pop(key, None)
mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
@@ -1409,6 +1441,225 @@ class MiniCPMVPlugin(BasePlugin):
return mm_inputs
@dataclass
class MiniCPMV4_6Plugin(BasePlugin):
"""Plugin for MiniCPM-V-4.6 with new transformers (NaViT vision + get_placeholder_mask API)."""
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
**kwargs,
) -> dict[str, "torch.Tensor"]:
image_processor = getattr(processor, "image_processor")
video_processor = getattr(processor, "video_processor", None)
mm_inputs = {}
if len(images) != 0:
# The image_processor ignores downsample_mode; target_sizes are always based on patch_size.
# downsample_mode only affects the token divisor in _build_v4_6_placeholder and model forward.
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
if video_processor is not None:
video_inputs = video_processor(videos, return_tensors="pt")
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values_videos"]
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes_videos"]
else:
video_inputs = image_processor(videos, return_tensors="pt")
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values"]
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes"]
if len(audios) != 0:
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
[audios],
chunk_input=True,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)
audio_feature_lens = [
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs})
return mm_inputs
def _build_v4_6_placeholder(
self,
image_inputs: dict[str, Any],
image_idx: int,
use_image_id: bool,
processor: "MMProcessor",
) -> str:
"""Build image placeholder for MiniCPM-V-4.6 using NaViT token count computation."""
grids = image_inputs.get("grids", [[0, 0]])
num_patches_per_image = image_inputs.get("num_patches_per_image", [1])
target_sizes = image_inputs.get("target_sizes")
downsample_mode = os.getenv("DOWNSAMPLE_MODE")
if downsample_mode is None:
image_processor = getattr(processor, "image_processor")
downsample_mode = getattr(image_processor, "downsample_mode", "16x")
token_divisor = 4 if downsample_mode == "4x" else 16
flat_index = 0
for idx in range(image_idx):
flat_index += num_patches_per_image[idx]
n_patches = num_patches_per_image[image_idx]
img_target_sizes = target_sizes[flat_index : flat_index + n_patches]
num_tokens_per_patch = img_target_sizes.prod(-1) // token_divisor
num_rows, num_cols = grids[image_idx]
image_start = getattr(processor, "image_start_token", "<image>")
image_end = getattr(processor, "image_end_token", "</image>")
slice_start = getattr(processor, "slice_start_token", "<slice>")
slice_end = getattr(processor, "slice_end_token", "</slice>")
image_id_start = getattr(processor, "image_id_start_token", "<image_id>")
image_id_end = getattr(processor, "image_id_end_token", "</image_id>")
image_token = (
getattr(processor, "image_token", None)
or getattr(getattr(processor, "tokenizer", None), "image_token", None)
or "<image>"
)
image_placeholder = image_start + "<|ph|>" * int(num_tokens_per_patch[0]) + image_end
if use_image_id:
image_placeholder = f"{image_id_start}{image_idx}{image_id_end}" + image_placeholder
slice_mode = getattr(processor, "slice_mode", True)
if slice_mode and num_rows > 0 and num_cols > 0:
per_slice_tokens = int(num_tokens_per_patch[1]) if len(num_tokens_per_patch) > 1 else 0
slice_placeholder = slice_start + "<|ph|>" * per_slice_tokens + slice_end
slices = [slice_placeholder * num_cols for _ in range(num_rows)]
image_placeholder += "\n".join(slices)
return image_placeholder.replace("<|ph|>", image_token)
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
mm_inputs, audio_inputs = {}, {}
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
use_image_id = getattr(processor, "default_use_image_id", True)
if len(videos) != 0:
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, [], processor)
for i, message in enumerate(messages):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
num_frames = 1
if "num_frames_per_video" in mm_inputs:
num_frames = sum(mm_inputs["num_frames_per_video"])
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * num_frames, 1)
num_video_tokens += 1
while AUDIO_PLACEHOLDER in content:
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
"{{audio}}", "(<audio>./</audio>)"
)
if len(images):
mm_inputs = self._get_mm_inputs(images, [], [], processor)
if len(audios):
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
if self.expand_mm_tokens and mm_inputs:
pattern = "(<image>./</image>)"
idx = 0
for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
image_placeholder = self._build_v4_6_placeholder(mm_inputs, idx, use_image_id, processor)
final_text = final_text + text_chunks[i] + image_placeholder
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if self.expand_mm_tokens and audio_inputs:
pattern = "(<audio>./</audio>)"
idx = 0
for index, message in enumerate(messages):
text = message["content"]
audio_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(audio_tags)):
audio_placeholder = audio_inputs["audio_phs"][0][idx]
final_text = final_text + text_chunks[i] + audio_placeholder
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
return messages
@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
# v4.6 does NOT use image_bound — the model finds image tokens via get_placeholder_mask
# Ensure target_sizes key name matches the model's expected input
if "target_sizes" not in mm_inputs and "tgt_sizes" in mm_inputs:
mm_inputs["target_sizes"] = mm_inputs.pop("tgt_sizes")
if "target_sizes" not in mm_inputs:
mm_inputs["target_sizes"] = torch.empty(0, 2, dtype=torch.int32)
if "pixel_values" not in mm_inputs:
mm_inputs["pixel_values"] = torch.empty(1, 3, 14, 0)
# Pass downsample_mode to model forward so it matches the placeholder divisor
_ds = os.getenv("DOWNSAMPLE_MODE")
if _ds is None:
_ds = getattr(getattr(processor, "image_processor", None), "downsample_mode", "16x")
mm_inputs["downsample_mode"] = _ds
if len(audios) > 0:
audio_inputs = self._get_mm_inputs([], [], audios, processor)
mm_inputs.update(audio_inputs)
return mm_inputs
@dataclass
class MllamaPlugin(BasePlugin):
@override
@@ -1696,7 +1947,9 @@ class Qwen2VLPlugin(BasePlugin):
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
original_fps = float(video_stream.average_rate)
# for qwen3vl video timestamp calculation
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
frames_indices.append(
[idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]
) # hack usage when do_sample_frames=False
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@@ -1715,7 +1968,205 @@ class Qwen2VLPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
return {
"videos": results,
"fps_per_video": fps_per_video,
"durations": durations,
"frames_indices": frames_indices,
}
def _get_qwen_video_size_after_regularization(
self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
) -> tuple[int, int]:
r"""Compute the frame size produced by Qwen-VL image regularization."""
if (width * height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if (width * height) < image_min_pixels:
resize_factor = math.sqrt(image_min_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if min(width, height) < 28:
width, height = max(width, 28), max(height, 28)
if width / height > 200:
width, height = height * 180, height
if height / width > 200:
width, height = width, width * 180
return width, height
def _get_qwen_video_stream_metadata(
self,
video: "VideoInput",
video_fps: float,
video_maxlen: int,
) -> Optional[dict[str, Any]]:
if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
return None
try:
container = av.open(video, "r")
except (av.FFmpegError, OSError):
return None
try:
video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
if video_stream is None:
return None
if video_stream.duration is None or video_stream.average_rate is None:
return None
average_fps = float(video_stream.average_rate)
if average_fps <= 0:
return None
sample_indices = self._get_video_sample_indices(
video_stream, video_fps=video_fps, video_maxlen=video_maxlen
)
return {
"width": video_stream.width,
"height": video_stream.height,
"average_fps": average_fps,
"sample_indices": sample_indices,
}
finally:
container.close()
def _get_qwen_video_resize(
self,
num_frames: int,
height: int,
width: int,
patch_size: int,
temporal_patch_size: int,
merge_size: int,
min_pixels: int,
max_pixels: int,
) -> tuple[int, int]:
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
return smart_resize(
height=height,
width=width,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
def _get_qwen_video_grid_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
if len(videos) == 0:
return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
if image_processor is None or video_processor is None:
return None
patch_size = getattr(video_processor, "patch_size", None)
temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
merge_size = getattr(video_processor, "merge_size", None)
size = getattr(video_processor, "size", None)
if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
return None
if isinstance(size, dict):
min_pixels = size.get("shortest_edge")
max_pixels = size.get("longest_edge")
else:
min_pixels = getattr(size, "shortest_edge", None)
max_pixels = getattr(size, "longest_edge", None)
if min_pixels is None or max_pixels is None:
return None
video_fps = getattr(processor, "video_fps", 2.0)
video_maxlen = getattr(processor, "video_maxlen", 128)
image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)
video_grid_thw = []
frames_indices = []
for video in videos:
metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
if metadata is None:
return None
width, height = self._get_qwen_video_size_after_regularization(
metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
)
num_frames = len(metadata["sample_indices"])
if num_frames % 2 != 0:
num_frames += 1
resized_size = self._get_qwen_video_resize(
num_frames,
height,
width,
patch_size,
temporal_patch_size,
merge_size,
min_pixels,
max_pixels,
)
resized_height, resized_width = resized_size
video_grid_thw.append(
[
math.ceil(num_frames / temporal_patch_size),
resized_height // patch_size,
resized_width // patch_size,
]
)
frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])
return {
"video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
"frames_indices": frames_indices,
"fps": video_fps,
}
@override
def _get_video_token_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
if video_metadata is None:
return None
return {"video_grid_thw": video_metadata["video_grid_thw"]}
def _get_mm_token_metadata(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
if len(audios) != 0:
return None
mm_inputs = {}
if len(images) != 0:
mm_inputs.update(self._get_mm_inputs(images, [], [], processor))
if len(videos) != 0:
video_inputs = self._get_video_token_metadata(videos, processor)
if video_inputs is None:
return None
mm_inputs.update(video_inputs)
return mm_inputs
@override
def _get_mm_inputs(
@@ -1768,7 +2219,10 @@ class Qwen2VLPlugin(BasePlugin):
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
if mm_inputs is None:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
@@ -1802,6 +2256,51 @@ class Qwen2VLPlugin(BasePlugin):
@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
@override
def _get_qwen_video_resize(
self,
num_frames: int,
height: int,
width: int,
patch_size: int,
temporal_patch_size: int,
merge_size: int,
min_pixels: int,
max_pixels: int,
) -> tuple[int, int]:
from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
return smart_resize(
num_frames=num_frames,
height=height,
width=width,
temporal_factor=temporal_patch_size,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
@override
def _get_video_token_metadata(
self,
videos: list["VideoInput"],
processor: "MMProcessor",
) -> Optional[dict[str, Any]]:
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
if video_metadata is None:
return None
return {
"video_grid_thw": video_metadata["video_grid_thw"],
"video_metadata": [
SimpleNamespace(
frames_indices=frames_indices,
fps=video_metadata["fps"],
)
for frames_indices in video_metadata["frames_indices"]
],
}
@override
def _get_mm_inputs(
self,
@@ -1830,8 +2329,15 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
{
"fps": getattr(processor, "video_fps", 2.0),
"duration": duration,
"total_num_frames": len(video),
"frames_indices": sample_indices,
}
for video, duration, sample_indices in zip(
videos["videos"], videos["durations"], videos["frames_indices"]
)
]
mm_inputs.update(
video_processor(
@@ -1839,7 +2345,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
video_metadata=video_metadata,
fps=getattr(processor, "video_fps", 2.0),
return_metadata=True,
do_sample_frames=False, # avoid changing frames_indices
do_sample_frames=False, # avoid changing frames_indices
)
)
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
@@ -1867,7 +2373,10 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
if mm_inputs is None:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
@@ -2405,6 +2914,7 @@ PLUGINS = {
"llava_next_video": LlavaNextVideoPlugin,
"lfm2_vl": LFMVLPlugin,
"minicpm_v": MiniCPMVPlugin,
"minicpm_v_4_6": MiniCPMV4_6Plugin,
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,

View File

@@ -27,7 +27,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
@dataclass
class PackingParams:
@@ -45,6 +46,7 @@ class PackingParams:
audio_subseq_ids: list[int]
right_padding_length: int
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
@@ -61,7 +63,8 @@ class SupervisedDatasetProcessor(DatasetProcessor):
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
)
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
discarding_history_cot = self.data_args.mask_history and not self.template.preserve_thinking
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools, discarding_history_cot)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
@@ -232,7 +235,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
if requires_packing_params:
packing_params = PackingParams(
sequence_boundaries=sequence_boundaries,
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
right_padding_length=pad_length,

View File

@@ -79,6 +79,7 @@ class Template:
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
discarding_history_cot: bool = False, # only effect reasoning template
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -441,14 +442,24 @@ class ReasoningTemplate(Template):
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
discarding_history_cot: bool = False,
) -> list[tuple[list[int], list[int]]]:
messages = deepcopy(messages)
if self.enable_thinking is False: # remove all cot
for i in range(1, len(messages), 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
if discarding_history_cot:
for i in range(1, len(messages) - 2, 2): # preserve the last cot
messages[i]["content"] = self.remove_thought(messages[i]["content"])
encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2):
if discarding_history_cot:
turn_indices = [len(messages) - 2]
else:
turn_indices = range(0, len(messages), 2)
for i in turn_indices:
if (
self.thought_words[0].strip() not in messages[i + 1]["content"]
and self.thought_words[1].strip() not in messages[i + 1]["content"]
@@ -822,6 +833,19 @@ register_template(
)
register_template(
name="hy3",
format_user=StringFormatter(slots=["<hy_User>{{content}}<hy_Assistant>"]),
format_assistant=StringFormatter(slots=["{{content}}<hy_eos>"]),
format_system=StringFormatter(slots=["{{content}}"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<hy_eos>"],
replace_eos=True,
thought_words=("<think>", "</think>"),
template_class=ReasoningTemplate,
)
register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -1007,15 +1031,17 @@ register_template(
name="gemma4",
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]), # default thought singal contained
format_system=StringFormatter(
slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
), # default thought singal contained
format_observation=StringFormatter(
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
), # seem not consistent with the chattemplate
), # seem not consistent with the chattemplate
format_tools=ToolFormatter(tool_format="gemma4"),
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<turn|>"],
default_system="You are a helpful assistant.", # important for thinking
default_system="You are a helpful assistant.", # important for thinking
thought_words=("<|channel>thought\n", "<channel|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(
@@ -1031,15 +1057,15 @@ register_template(
name="gemma4n",
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]), # default thought singal contained
format_observation=StringFormatter(
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
),
format_system=StringFormatter(
slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
), # default thought singal contained
format_observation=StringFormatter(slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]),
format_tools=ToolFormatter(tool_format="gemma4"),
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<turn|>"],
default_system="You are a helpful assistant.", # important for thinking
default_system="You are a helpful assistant.", # important for thinking
thought_words=("<|channel>thought\n", "<channel|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(
@@ -1678,6 +1704,17 @@ register_template(
)
register_template(
name="minicpm_v_4_6",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are a helpful assistant.",
mm_plugin=get_mm_plugin(name="minicpm_v_4_6", image_token="<image>", video_token="<video>"),
)
# copied from minicpm_v template
register_template(
name="minicpm_o",
@@ -2135,23 +2172,6 @@ register_template(
)
# copied from qwen3_5_nothink template
register_template(
name="qwen3_6_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
@@ -2362,4 +2382,3 @@ register_template(
efficient_eos=True,
template_class=Glm47ReasoningTemplate,
)

View File

@@ -209,6 +209,7 @@ class DefaultToolUtils(ToolUtils):
return results
class Gemma4ToolUtils(ToolUtils):
r"""Gemma-4 tool using template."""
@@ -292,7 +293,7 @@ class Gemma4ToolUtils(ToolUtils):
flags=re.DOTALL,
)
# Quote unquoted object keys so the payload can be parsed by json.loads.
normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', normalized)
try:
return json.loads(normalized)
except json.JSONDecodeError:
@@ -368,6 +369,7 @@ class Gemma4ToolUtils(ToolUtils):
return "".join(function_texts)
class GLM4ToolUtils(ToolUtils):
r"""GLM-4 tool using template."""

View File

@@ -139,7 +139,6 @@ class EngineName(StrEnum):
HF = "huggingface"
VLLM = "vllm"
SGLANG = "sglang"
KT = "ktransformers"
class DownloadSource(StrEnum):
@@ -887,6 +886,9 @@ register_model_group(
"Gemma-4-E4B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
},
"Gemma-4-12B-Thinking": {
DownloadSource.DEFAULT: "google/gemma-4-12B-it",
},
},
template="gemma4n",
multimodal=True,
@@ -1257,6 +1259,17 @@ register_model_group(
)
register_model_group(
models={
"Hy3-Preview": {
DownloadSource.DEFAULT: "tencent/Hy3-preview",
DownloadSource.MODELSCOPE: "tencent/Hy3-preview",
},
},
template="hy3",
)
register_model_group(
models={
"Index-1.9B-Base": {
@@ -1902,6 +1915,17 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM5-1B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM5-1B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM5-1B",
},
},
template="empty",
)
register_model_group(
models={
"MiniCPM-o-2.6": {
@@ -1938,6 +1962,18 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-V-4.6": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_6",
},
},
template="minicpm_v_4_6",
multimodal=True,
)
register_model_group(
models={
"Ministral-8B-Instruct-2410": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.5.dev0"
VERSION = "0.9.6.dev0"
def print_env() -> None:

View File

@@ -94,10 +94,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.55.0,<=5.2.0")
check_version("transformers>=4.55.0,<=5.6.0")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.18.0,<=0.18.1")
check_version("accelerate>=1.3.0,<=1.15.0")
check_version("peft>=0.18.0,<=0.20.0")
check_version("trl>=0.18.0,<=0.24.0")

View File

@@ -20,6 +20,7 @@ import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING
import transformers.utils.import_utils as import_utils
from packaging import version
@@ -87,7 +88,7 @@ def is_ray_available():
def is_kt_available():
return _is_package_available("ktransformers")
return _is_package_available("kt_kernel")
def is_requests_available():
@@ -126,3 +127,26 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
_orig_is_package_available = import_utils._is_package_available
class PackageAvailability(tuple):
__slots__ = ()
def __new__(cls, available: bool, pkg_version: str = "N/A"):
return super().__new__(cls, (bool(available), pkg_version))
def __bool__(self) -> bool:
return self[0]
def _patched_is_package_available(pkg_name: str, return_version: bool = False):
available, version = _orig_is_package_available(pkg_name, return_version=return_version)
return PackageAvailability(available, version)
if is_transformers_version_greater_than("5.3.0"):
import_utils._is_package_available = _patched_is_package_available

View File

@@ -190,4 +190,3 @@ class DataArguments:
def to_dict(self) -> dict[str, Any]:
return asdict(self)

View File

@@ -487,7 +487,7 @@ class FinetuningArguments(
metadata={
"help": (
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
"Only supported for the 'sft' stage with full fine-tuning."
"Only supported for the 'pt' and 'sft' stages with full fine-tuning."
)
},
)

View File

@@ -16,6 +16,7 @@
# limitations under the License.
import json
import os
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Self
@@ -460,47 +461,81 @@ class SGLangArguments:
@dataclass
class KTransformersArguments:
r"""Arguments pertaining to the KT training."""
r"""Arguments pertaining to KTransformers AMX MoE SFT training.
These fields are normalized into the transformers/accelerate KT config before training starts.
"""
use_kt: bool = field(
default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
metadata={"help": "Whether to use KTransformers AMX MoE backend for SFT training."},
)
kt_optimize_rule: str | None = field(
kt_weight_path: str | None = field(
default=None,
metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
},
metadata={"help": "Path to pre-quantized INT8 expert weights (.kt files)."},
)
cpu_infer: int | None = field(
default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."},
kt_expert_checkpoint_path: str | None = field(
default=None,
metadata={"help": "Path to expert checkpoint (safetensors) for online conversion."},
)
chunk_size: int | None = field(
default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
kt_use_lora_experts: bool | None = field(
default=None,
metadata={"help": "Whether to use GPU-side LoRA Experts."},
)
mode: str | None = field(
default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."},
kt_lora_expert_num: int | None = field(
default=None,
metadata={"help": "Number of GPU-side LoRA Experts."},
)
kt_lora_expert_intermediate_size: int | None = field(
default=None,
metadata={"help": "Intermediate size for GPU-side LoRA Experts."},
)
kt_maxlen: int = field(
default=4096,
metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
)
kt_use_cuda_graph: bool = field(
default=True,
metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
)
kt_mode: str = field(
default="normal",
metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
)
kt_force_think: bool = field(
default=False,
metadata={"help": "Force-Think Toggle For The KT Engine."},
)
def get_kt_config_dict(self, finetuning_args: Any, model_max_length: int | None) -> dict[str, Any]:
r"""Build KT config values from LLaMA-Factory model and LoRA arguments."""
kt_config = {
"kt_lora_rank": getattr(finetuning_args, "lora_rank", None),
"kt_lora_alpha": getattr(finetuning_args, "lora_alpha", None),
"kt_weight_path": self.kt_weight_path,
"kt_expert_checkpoint_path": self.kt_expert_checkpoint_path,
"kt_model_max_length": model_max_length,
"kt_use_lora_experts": self.kt_use_lora_experts,
"kt_lora_expert_num": self.kt_lora_expert_num,
"kt_lora_expert_intermediate_size": self.kt_lora_expert_intermediate_size,
}
return {key: value for key, value in kt_config.items() if value is not None}
def apply_kt_config(self, finetuning_args: Any, training_args: Any, model_max_length: int | None) -> None:
r"""Apply LLaMA-Factory KT args to transformers/accelerate KT integration points."""
if not self.use_kt:
return
kt_config = self.get_kt_config_dict(finetuning_args, model_max_length)
env_mapping = {
"kt_weight_path": "ACCELERATE_KT_WEIGHT_PATH",
"kt_expert_checkpoint_path": "ACCELERATE_KT_EXPERT_CHECKPOINT_PATH",
"kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH",
"kt_lora_rank": "ACCELERATE_KT_LORA_RANK",
"kt_lora_alpha": "ACCELERATE_KT_LORA_ALPHA",
"kt_use_lora_experts": "ACCELERATE_KT_USE_LORA_EXPERTS",
"kt_lora_expert_num": "ACCELERATE_KT_LORA_EXPERT_NUM",
"kt_lora_expert_intermediate_size": "ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE",
}
for key, env_key in env_mapping.items():
value = kt_config.get(key)
if value is not None:
os.environ[env_key] = str(value)
hf_kt = getattr(training_args, "hf_kt_config", None)
if hf_kt is None or not hasattr(hf_kt, "_kt_config") or not isinstance(hf_kt._kt_config, dict):
return
hf_kt._kt_config.update(kt_config)
gc_enabled = getattr(training_args, "gradient_checkpointing", False) or not getattr(
self, "disable_gradient_checkpointing", True
)
if gc_enabled:
hf_kt._kt_config.setdefault("kt_share_cache_pool", True)
@dataclass

View File

@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [
ModelArguments,
DataArguments,
TrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_MCA_ARGS = [
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
else:
_TRAIN_MCA_ARGS = []
@@ -192,7 +208,9 @@ def _check_extra_dependencies(
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_kt:
check_version("ktransformers", mandatory=True)
check_version("kt-kernel", mandatory=True)
check_version("transformers-kt", mandatory=True)
check_version("accelerate-kt", mandatory=True)
if model_args.use_unsloth:
check_version("unsloth", mandatory=True)
@@ -467,7 +485,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
@@ -510,6 +528,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
)
transformers.set_seed(training_args.seed)
if model_args.use_kt:
model_args.apply_kt_config(finetuning_args, training_args, model_args.model_max_length)
return model_args, data_args, training_args, finetuning_args, generating_args

View File

@@ -14,6 +14,7 @@
import json
from dataclasses import dataclass, field
from typing import Optional
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -63,6 +64,58 @@ class RayArguments:
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
@dataclass
class ProfilerArguments:
r"""Arguments for torch profiler configuration."""
enable_torch_profiler: bool = field(
default=False,
metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
)
profiler_output_dir: Optional[str] = field(
default=None,
metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
)
profiler_wait_steps: int = field(
default=1,
metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
)
profiler_warmup_steps: int = field(
default=1,
metadata={"help": "Number of profiler warm-up steps per cycle."},
)
profiler_active_steps: int = field(
default=1,
metadata={"help": "Number of steps to actively record per cycle."},
)
profiler_repeat: int = field(
default=1,
metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
)
profiler_record_shapes: bool = field(
default=True,
metadata={"help": "Whether to record tensor shapes during profiling."},
)
profiler_profile_memory: bool = field(
default=True,
metadata={"help": "Whether to profile memory usage."},
)
profiler_with_stack: bool = field(
default=True,
metadata={"help": "Whether to record stack traces during profiling."},
)
profile_modules: Optional[str] = field(
default=None,
metadata={
"help": (
"Comma-separated list of module name patterns to profile with CUDA events. "
"Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
"Reports per-module forward/backward timing statistics at each logging step."
)
},
)
@dataclass
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
@@ -87,7 +140,7 @@ class Fp8Arguments:
@dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(

View File

@@ -20,8 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -188,12 +186,6 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
@@ -202,12 +194,16 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if isinstance(model, PeftModel):
pass # already loaded via load_unsloth_peft_model in loader.py
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
if model_args.use_unsloth:
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
if peft_model is not None:
model = peft_model
if not model_args.use_unsloth: # unsloth was disabled or fell back
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
@@ -217,16 +213,6 @@ def _setup_lora_tuning(
else:
target_modules = finetuning_args.lora_target
if model_args.use_kt:
new_list = []
for m in target_modules:
if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
@@ -270,19 +256,11 @@ def _setup_lora_tuning(
}
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
if finetuning_args.finetuning_type != "lora":
raise ValueError("KTransformers only supports LoRA finetuning.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)
model = get_peft_model(model, peft_config)
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")

View File

@@ -31,11 +31,10 @@ from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
@@ -143,19 +142,13 @@ def load_model(
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
if model_args.use_kt:
from ktransformers.sft.monkey_patch_torch_module import install_patch
install_patch()
model = load_kt_pretrained_model(config, model_args)
elif model_args.use_unsloth:
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
if model is None and not lazy_load:
if model is None:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
@@ -182,9 +175,8 @@ def load_model(
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections.abc import Iterable
from contextlib import nullcontext
from typing import TYPE_CHECKING, Optional
@@ -29,7 +30,81 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
def get_embedding_vocab_size(model: "PreTrainedModel") -> int:
r"""Get the vocab size from the input embedding layer.
Handles DeepSpeed ZeRO-3 parameter sharding by gathering the embedding weight
before reading its size.
"""
embedding = model.get_input_embeddings()
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
with deepspeed.zero.GatheredParameters([embedding.weight]):
return embedding.weight.size(0)
return embedding.weight.size(0)
def _resolve_new_token_ids(
new_tokens: Optional[Iterable[str]],
tokenizer: "PreTrainedTokenizer",
embed_size: int,
) -> Optional[list[int]]:
r"""Resolve the explicit embedding-row IDs of the newly added tokens.
Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when
the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL
has vocab 151665 but embedding 151936). In that case the appended tokens land
inside the original padding zone and the tail slice points at the wrong rows.
Args:
new_tokens: Iterable of the newly added token strings.
tokenizer: The tokenizer instance.
embed_size: Current embedding size (upper bound for valid token IDs).
Returns:
A sorted list of unique, in-range token IDs, or ``None`` when no tokens are
given so that callers can fall back to the tail-slice behaviour.
"""
if not new_tokens:
return None
unk_token_id = getattr(tokenizer, "unk_token_id", None)
token_ids: set[int] = set()
for token_str in new_tokens:
token_id = tokenizer.convert_tokens_to_ids(token_str)
if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size):
logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.")
continue
token_ids.add(token_id)
return sorted(token_ids) or None
def _existing_embeddings(
embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]]
) -> "torch.Tensor":
"""Return the rows treated as 'existing' embeddings used as the init baseline.
Prefers excluding the explicit new-token rows (robust to padding). Falls back to
dropping the last ``num_new_tokens`` rows when no explicit IDs are available.
"""
if new_token_ids:
mask = torch.ones(embed_weight.size(0), dtype=torch.bool, device=embed_weight.device)
mask[torch.as_tensor(new_token_ids, device=embed_weight.device, dtype=torch.long)] = False
return embed_weight[mask]
if num_new_tokens > 0:
return embed_weight[:-num_new_tokens]
return embed_weight
def _noisy_mean_initialization(
embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None
) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
@@ -37,12 +112,23 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
token_ids: Explicit token IDs to initialize. When provided, these exact rows are
written (robust to padding). When ``None``, falls back to the last
``num_new_tokens`` rows.
"""
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True)
if token_ids:
noise_weight = torch.empty(
len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype
)
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[token_ids] = avg_weight + noise_weight
else:
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _description_based_initialization(
@@ -51,6 +137,7 @@ def _description_based_initialization(
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
@@ -61,6 +148,9 @@ def _description_based_initialization(
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
New tokens are placed by their resolved token ID rather than by tail slicing,
so the initialization is correct even when the embedding matrix was padded.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
@@ -68,6 +158,8 @@ def _description_based_initialization(
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized
rows when averaging description-token embeddings (robust to embedding padding).
add_noise: Whether to add Gaussian noise to the initialization
Example:
@@ -77,38 +169,55 @@ def _description_based_initialization(
}
"""
embedding_dim = embed_weight.size(1)
vocab_size = embed_weight.size(0)
unk_token_id = getattr(tokenizer, "unk_token_id", None)
device = embed_weight.device
# The set of rows that are NOT yet initialized (the newly added tokens). Description
# tokens that fall into this set must be excluded, otherwise we would average garbage.
# `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on
# the explicit IDs, falling back to resolving them from the description keys.
if new_token_ids is None:
new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size)
new_id_set = set(new_token_ids or [])
fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0)
for token_str, desc in descriptions.items():
# Resolve token ID for correct placement (robust to embedding padding)
token_id = tokenizer.convert_tokens_to_ids(token_str)
if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size):
logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.")
continue
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
token_ids = tokens["input_ids"][0].tolist()
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
# Keep only description tokens that already have a meaningful embedding.
valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
f"Description for token '{token_str}' contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
base_embedding = fallback_embedding
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long)
token_embeds = model.get_input_embeddings()(valid_ids_tensor)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
embed_weight[token_id] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
embed_weight[token_id] = base_embedding
def _initialize_embeddings(
@@ -118,6 +227,7 @@ def _initialize_embeddings(
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
) -> None:
"""Single source of truth for embedding initialization.
@@ -130,16 +240,18 @@ def _initialize_embeddings(
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding).
When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows.
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True
)
else:
if init_method != "noise_init":
@@ -147,20 +259,28 @@ def _initialize_embeddings(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
_noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_tokens: Optional[Iterable[str]] = None,
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
r"""Resize token embeddings (when needed) and initialize the newly added tokens.
Resizing and initialization are decoupled: even when the tokenizer vocab fits inside
the model's existing (padded) embedding matrix and no resize is triggered, the newly
added tokens still occupy uninitialized rows and must be initialized. We therefore
resolve the explicit row IDs of ``new_tokens`` and always initialize those rows.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_tokens: Iterable of the newly added token strings. Used to locate the exact
embedding rows to initialize, which is robust to pre-existing embedding padding.
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
@@ -175,44 +295,70 @@ def resize_embedding_layer(
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
current_embedding_size = get_embedding_vocab_size(model)
needs_resize = len(tokenizer) > current_embedding_size
if len(tokenizer) > current_embedding_size:
if needs_resize:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# mean_resizing=False preserves the original embedding distribution exactly.
# HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance
# of existing embeddings, which conflicts with our explicit initialization below.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# Resolve the exact rows of the new tokens. This works whether or not a resize was
# triggered (e.g. tokens added into a model's pre-existing padding zone).
new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size)
if num_new_tokens <= 0 and not new_token_ids:
return
if needs_resize:
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
else:
logger.info_rank0(
f"No resize needed (vocab fits in padded embedding {new_embedding_size}); "
f"initializing {len(new_token_ids or [])} new token(s) in place."
)
# Initialize input embeddings
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_input_embeddings().weight.data,
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
if needs_resize:
model.config.vocab_size = new_embedding_size
# Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA),
# otherwise config.vocab_size and config.text_config.vocab_size become inconsistent.
if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"):
model.config.text_config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -1,154 +0,0 @@
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util as _u
from typing import TYPE_CHECKING, Any
import torch
from ...extras import logging
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
KT_AVAILABLE = _u.find_spec("ktransformers") is not None
if KT_AVAILABLE:
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.server.config.config import Config
from ktransformers.sft.lora import inject_lora_layer
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.globals import GLOBAL_CONFIG
from ktransformers.util.utils import load_weights
logger = logging.get_logger(__name__)
def _get_kt_kwargs(
config: "PretrainedConfig",
model_name_or_path: str,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"full_finetuning": finetuning_args.finetuning_type == "full",
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": model_args.trust_remote_code,
"use_gradient_checkpointing": "ktransformers",
}
def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
r"""Optionally load pretrained model with KTransformers. Used in training."""
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
Config().cpu_infer = model_args.cpu_infer
Config().chunk_size = model_args.chunk_size
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
if model_args.mode == "long_context":
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation=attn_implementation
)
optimize_config_path = model_args.kt_optimize_rule
gguf_path = model_args.model_name_or_path
assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."
GLOBAL_CONFIG._config["mod"] = "infer"
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
return model
def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
from ktransformers.sft.peft_utils.mapping import get_peft_model
return get_peft_model(model, peft_kwargs)
def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
r"""Load peft model with KTransformers. Used in both training and inference."""
load_adapter_name_or_path = model_args.adapter_name_or_path[0]
if load_adapter_name_or_path.endswith(".gguf"):
inject_lora_layer(model, load_adapter_name_or_path)
adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
load_weights(model, adapter_gguf_loader, adapter_gguf=True)
model.train()
else:
inject_lora_layer(model, load_adapter_name_or_path)
adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
device = next(model.parameters()).device
for key in adapter_loader.tensor_file_map.keys():
try:
tensor = adapter_loader.load_tensor(key, device=device)
model_key = key.replace("base_model.model.", "")
model_key = model_key.replace(".weight", ".default.weight")
model_key = model_key.replace(".default.default.weight", ".default.weight")
param = model.get_parameter(model_key)
param.data.copy_(tensor.data)
print(f"Loaded adapter weight: {key} -> {model_key}")
except AttributeError:
print(f"Skipping {key}: not a model parameter")
except KeyError:
print(f"Key not found in model: {model_key} (original: {key})")
return model

View File

@@ -16,6 +16,7 @@ import inspect
from typing import TYPE_CHECKING
from ...extras import logging
from ...extras.misc import get_device_name
if TYPE_CHECKING:
@@ -45,7 +46,7 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively
elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "glm4v":
from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
@@ -81,6 +82,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
elif model_type == "qwen3_5":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
elif model_type == "qwen3_5_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
@@ -97,5 +100,12 @@ def apply_liger_kernel(
else:
kwargs = {}
if get_device_name() == "npu":
import torch
if "Ascend910" not in torch.npu.get_device_name(0):
kwargs["swiglu"] = False
kwargs["fused_linear_cross_entropy"] = False
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")

View File

@@ -62,6 +62,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
# deepseek v3 and kimi vl use custom code
_set_z3_leaf_modules(model, ["DeepseekV3MoE"])
if model_type == "hy_v3":
# hy3 uses custom code
_set_z3_leaf_modules(model, ["HYV3MoE"])
if model_type == "ernie4_5_moe":
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock

View File

@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""Load peft model with unsloth. Used in both training and inference."""
) -> Optional["PreTrainedModel"]:
r"""Load peft model with unsloth. Used in both training and inference.
Returns None if unsloth does not support the model type, and sets
model_args.use_unsloth = False so callers can fall back to standard loading.
"""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
return None
if not is_trainable:
FastLanguageModel.for_inference(model)

View File

@@ -44,15 +44,16 @@ class CompositeModel:
language_model_keys: list[str]
lora_conflict_keys: list[str]
def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
mm_projectors: list[torch.nn.Module] = []
for projector_key in self.projector_keys:
project_module = module
for key in projector_key.split("."):
project_module = getattr(project_module, key, None)
if project_module is None: # i,e gemma4 bigger one, there is no embed_audio
logger.warning_rank0(f"Projector key {projector_key} not found in module {module.__class__.__name__}.")
if project_module is None: # i,e gemma4 bigger one, there is no embed_audio
logger.warning_rank0(
f"Projector key {projector_key} not found in module {module.__class__.__name__}."
)
break
if project_module is not None:
@@ -320,6 +321,14 @@ _register_composite_model(
)
_register_composite_model(
model_type="minicpmv4_6",
projector_keys=["model.merger"],
vision_model_keys=["model.vision_tower"],
language_model_keys=["model.language_model", "lm_head"],
)
_register_composite_model(
model_type="minicpmo",
projector_keys=["resampler"],

View File

@@ -20,6 +20,7 @@ from peft import PeftModel
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
from ..extras import logging
from ..extras.misc import infer_optim_dtype
@@ -60,6 +61,248 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
def _check_fla_dependencies() -> None:
"""Check that the FLA dependencies required for varlen GDN forwarding are available.
Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
``causal_conv1d`` under ``fla.modules.convolution`` and the
``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
actionable message otherwise.
"""
try:
from fla.modules.convolution import causal_conv1d # noqa: F401
from fla.ops.gated_delta_rule import ( # noqa: F401
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
except ImportError as exc:
raise ImportError(
"Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
"(provides `fla.modules.convolution.causal_conv1d` and "
"`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
"Please install/upgrade it."
) from exc
def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None:
"""Patch for Qwen3.5 models on NPU by importing torch_npu to enable torch.cuda compatibility.
On NPU, torch.cuda operations will fail unless torch_npu is imported.
torch_npu provides compatibility layer that maps torch.cuda calls to NPU operations.
Also replaces chunk_gated_delta_rule with NPU-compatible implementation.
"""
import importlib.metadata
if "Ascend910" not in torch.npu.get_device_name(0):
logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.")
return
try:
importlib.metadata.version("triton_ascend")
except importlib.metadata.PackageNotFoundError:
logger.warning_rank0(
"triton_ascend not installed, skipping NPU GDN patch. "
"To enable it on NPU, reinstall Triton with the Ascend build: "
"`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. "
"Note: triton and triton_ascend cannot coexist — triton must be uninstalled first."
)
return
logger.info_rank0("triton_ascend detected for NPU compatibility.")
from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule
if model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
try:
# Qwen3.5-MoE structure: model.model.language_model.layers
for layer in model.model.language_model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0(
"Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5-MoE model."
)
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
elif model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
try:
# Qwen3.5 structure: model.model.layers
for layer in model.model.layers:
if hasattr(layer, "linear_attn"):
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
logger.info_rank0("Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5 model.")
except Exception as e:
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None:
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training.
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
"""
if is_transformers_version_greater_than("5.2.0"):
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
from torch.nn import functional as F
from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids
_check_fla_dependencies()
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
def _patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
cache_position: torch.LongTensor | None = None,
**kwargs,
) -> torch.FloatTensor:
"""Decoder layer forward that passes position_ids through to linear attention."""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids, # passing position_ids to linear attention
)
elif self.layer_type == "full_attention":
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids[None, 0], # keep [1, B, L]
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits)
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
# gdn forward (training only, cache_params is always None)
def _patch_gdn_forward(
self,
hidden_states: torch.Tensor,
cache_params=None,
cache_position: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
):
# @kuangdd fix: here attention_mask is None
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
batch_size, seq_len, _ = hidden_states.shape
# Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.
mixed_qkv = self.in_proj_qkv(hidden_states)
z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)
# FLA's causal_conv1d returns (out, final_state); we don't use the state here.
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
)
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim,
self.key_dim,
self.value_dim,
],
dim=-1,
)
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
beta = b.sigmoid()
# If the model is loaded in fp16, without the .float() here, A might be -inf
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
core_attn_out, _ = chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
**({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
)
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
output = self.out_proj(core_attn_out)
return output
if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
Qwen3_5DecoderLayer.forward = _patched_decoder_forward
Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeDecoderLayer,
Qwen3_5MoeGatedDeltaNet,
)
Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward
logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input only patch when do training.")
def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
original_forward = model.forward
@@ -214,9 +457,14 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
# Pass the explicit list of newly added tokens so their exact embedding rows can be
# located and initialized, even when they land in a model's pre-existing padding zone.
new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or [])
resize_embedding_layer(
model,
tokenizer,
new_tokens=new_tokens or None,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)
@@ -232,6 +480,13 @@ def patch_model(
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]:
if is_torch_npu_available():
patch_qwen3_5_forward_npu(model)
elif is_torch_cuda_available() and model_args.flash_attn == "fa2":
# this is the patch for packing/neat_packing for GPU GDN. And when setting packing, flash_attn must be fa2.
patch_qwen3_5_forward_gpu(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)

View File

@@ -0,0 +1,594 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets
CUBE_CORE_NUM = get_npu_properties()["num_aicore"]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(False,)),
key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
w,
v_new,
g,
gk,
h,
h0,
ht,
cu_seqlens,
chunk_offsets,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
NT_all = NT
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
# Initialize hidden states
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
if IS_VARLEN:
v = v + (i_h * T_all + bos) * V
k = k + (i_h * T_all + bos) * K
w = w + (i_h * T_all + bos) * K
g = g + i_h * T_all + bos
h = h + (i_h * NT_all + boh) * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_h * T_all + bos) * V
else:
v = v + (i_n * H + i_h) * T * V
k = k + (i_n * H + i_h) * T * K
w = w + (i_n * H + i_h) * T * K
g = g + (i_n * H + i_h) * T
h = h + (i_n * H + i_h) * NT * K * V
if SAVE_NEW_VALUE:
v_new_base = v_new + (i_n * H + i_h) * T * V
if USE_INITIAL_STATE:
h0_ptr = h0 + i_nh * K * V
if STORE_FINAL_STATE:
ht_ptr = ht + i_nh * K * V
# Load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# Main recurrence over chunks
for i_t in range(NT):
# Store current hidden state h_t
p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
# Compute v_residual = v - w @ h
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
# Apply output gate g
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_g_last = tl.load(g + last_idx)
p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None]
b_g_last_exp = tl.exp(b_g_last)
b_h1 *= b_g_last_exp
if K > 64:
b_h2 *= b_g_last_exp
if K > 128:
b_h3 *= b_g_last_exp
if K > 192:
b_h4 *= b_g_last_exp
# Apply key gate gk
if USE_GK:
o_k1 = tl.arange(0, 64).to(tl.float32)
gk_base_ptr = gk + (i_n * H + i_h) * T * K
b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0)
b_h1 *= tl.exp(b_gk_last1)[:, None]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0)
b_h2 *= tl.exp(b_gk_last2)[:, None]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0)
b_h3 *= tl.exp(b_gk_last3)[:, None]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0)
b_h4 *= tl.exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)
# Update hidden state: h += k @ v
p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
b_h4 += tl.dot(b_k, b_v)
# Store final state
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # default:64
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, u.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V).permute(0, 2, 1, 3, 4).contiguous()
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
BV = 128
v_new = torch.empty_like(u).permute(0, 2, 1, 3).contiguous() if save_new_value else None
k = k.permute(0, 2, 1, 3).contiguous()
w = w.permute(0, 2, 1, 3).contiguous()
u = u.permute(0, 2, 1, 3).contiguous()
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
NT=NT,
)
h = h.permute(0, 2, 1, 3, 4).contiguous()
v_new = v_new.permute(0, 2, 1, 3).contiguous()
return h, v_new, final_state
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
"USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=get_autotune_config(multibuffer_list=(True, False)),
key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
q,
k,
w,
g,
gk,
dht,
dh0,
do,
dh,
dv,
dv2,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_FINAL_STATE_GRADIENT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_all = T
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
if K > 64:
b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
if K > 128:
b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
if K > 192:
b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
w += (bos * H + i_h) * K
do += (bos * H + i_h) * V
dv += (bos * H + i_h) * V
dv2 += (bos * H + i_h) * V
dh += (boh * H + i_h) * K * V
if USE_GK:
gk += (bos * H + i_h) * K
if USE_INITIAL_STATE:
dh0 += i_nh * K * V
if USE_FINAL_STATE_GRADIENT:
dht += i_nh * K * V
stride_v = H * V
stride_h = H * K * V
stride_k = H * K
if USE_FINAL_STATE_GRADIENT:
p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
if K > 64:
p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
if K > 128:
p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
if K > 192:
p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))
for i_t in range(NT - 1, -1, -1):
p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
if IS_VARLEN:
bos_g = i_h * T_all + bos
else:
bos_g = (i_n * H + i_h) * T_all
bg_last = tl.load(g + bos_g + last_idx)
bg_last_exp = tl.exp(bg_last)
p_g = tl.make_block_ptr(
base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
# Update dv
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0)
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
if K > 64:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
if K > 128:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
if K > 192:
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0)
b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None]
b_dv += tl.load(p_dv, boundary_check=(0, 1))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# Update dh
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
if USE_G:
b_dh1 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh1 *= tl.exp(b_gk_last1[:, None])
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 64:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh2 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh2 *= tl.exp(b_gk_last2[:, None])
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 128:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh3 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh3 *= tl.exp(b_gk_last3[:, None])
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 192:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
b_dh4 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
if USE_GK:
b_dh4 *= tl.exp(b_gk_last4[:, None])
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_bwd_dhu(
q: torch.Tensor,
k: torch.Tensor,
w: torch.Tensor,
do: torch.Tensor,
dv: torch.Tensor,
g: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
h0: torch.Tensor | None = None,
dht: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *q.shape, do.shape[-1]
# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
dh = q.new_empty(B, NT, H, K, V)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
dv2 = torch.empty_like(dv)
BV = 128
g = g.permute(0, 2, 1).contiguous()
chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)](
q=q,
k=k,
w=w,
g=g,
gk=gk,
dht=dht,
dh0=dh0,
do=do,
dh=dh,
dv=dv,
dv2=dv2,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BV=BV,
)
return dh, dh0, dv2

View File

@@ -0,0 +1,347 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional
import torch
from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
return g, o, A, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
):
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=False,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dv = chunk_bwd_dv_local(
q=q,
k=k,
g=g,
do=do,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
q=q,
k=k,
w=w,
g=g,
h0=initial_state,
dht=dht,
do=do,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dq, dk, dw, dg = chunk_bwd_dqkwg(
q=q,
k=k,
v=v_new,
w=w,
g=g,
h=h,
dv=dv,
do=do,
dh=dh,
chunk_size=chunk_size,
scale=scale,
cu_seqlens=cu_seqlens,
)
dk2, dv, db, dg2 = prepare_wy_repr_bwd(
k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size
)
dk.add_(dk2)
dg.add_(dg2)
if dg.dtype != torch.float32:
raise ValueError(f"dg current type is {dg.dtype} , should be float32")
dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False)
return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
chunk_size: int = 64,
):
q_rstd, k_rstd = None, None
g, o, A, final_state = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
ctx.chunk_size = chunk_size
return o.to(q.dtype), final_state
@staticmethod
@input_guard
@autocast_custom_bwd
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
A=A,
scale=ctx.scale,
initial_state=initial_state,
do=do,
dht=dht,
cu_seqlens=cu_seqlens,
chunk_size=ctx.chunk_size,
)
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
head_first: bool = False,
):
r"""Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
betas of shape `[B, T, H]`.
scale (Optional[float]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
use_qk_l2norm_in_kernel (bool):
Whether to apply L2norm to the q/k tensor internally. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
This argument has been deprecated.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
""" # noqa: D205
if q.dtype != k.dtype or k.dtype != v.dtype:
raise ValueError(
f"q current type is {q.dtype} , k current type is {k.dtype} ,v current type is {v.dtype} , they should are equal"
)
if q.dtype == torch.float32:
raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.")
if len(beta.shape) != 3:
raise ValueError(
f"beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if head_first:
warnings.warn(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
original_dtype = x.dtype
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
# Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype
return (x * inv_norm).to(original_dtype)
if use_qk_l2norm_in_kernel:
q = l2norm(q, dim=-1, eps=1e-6)
k = l2norm(k, dim=-1, eps=1e-6)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size
)
return o, final_state

View File

@@ -0,0 +1,617 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"USE_DW": lambda args: args["dw"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dqkwg(
q,
k,
v,
h,
g,
g_gamma,
do,
dh,
dq,
dk,
dg,
w,
dv,
dw,
cu_seqlens,
chunk_indices,
scale,
B: tl.constexpr,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
USE_DW: tl.constexpr,
IS_VARLEN: tl.constexpr,
gdiff,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
total = B * T_max
T = eos - bos
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
total = B * T_max
NK = tl.cdiv(K, BK)
for i_k in range(NK):
if USE_G:
dg_k = dg + i_k * total * H
for i_h in range(H):
v_h = v + (bos * H + i_h) * V
do_h = do + (bos * H + i_h) * V
h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V
dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V
q_h = q + (bos * H + i_h) * K
k_h = k + (bos * H + i_h) * K
dq_h = dq + (bos * H + i_h) * K
dk_h = dk + (bos * H + i_h) * K
if USE_DW:
w_h = w + (bos * H + i_h) * K # noqa: F841
dw_h = dw + (bos * H + i_h) * K
dv_h = dv + (bos * H + i_h) * V
if USE_G:
if IS_VARLEN:
dg_h = dg_k + i_h * T_max + bos
g_h = g + i_h * T_max + bos
else:
dg_h = dg_k + (i_b * H + i_h) * T_max
g_h = g + (i_b * H + i_h) * T_max
b_dg_last = tl.zeros(
[
1,
],
dtype=tl.float32,
)
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
b_g_last = b_gamma * min(BT, T - i_t * BT)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
if USE_G:
b_dg_last += tl.sum(b_h * b_dh)
b_ds += tl.dot(b_do, tl.trans(b_v))
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
if USE_DW:
p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
if USE_DW:
p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_dg = tl.zeros(
[
BT,
],
dtype=tl.float32,
)
p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1)
b_dg_last *= tl.exp(b_g_last)
b_dq = b_dq * tl.exp(b_g)[:, None] * scale
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None]
b_dg -= tl.sum(b_k * b_dk, axis=1)
b_dg_last += tl.sum(b_dk * b_k)
if IS_VARLEN:
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
else:
p_gdiff = tl.make_block_ptr(
gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT,
(BT, BT),
(BT, 1),
(0, 0),
(BT, BT),
(1, 0),
)
gdiff_ = tl.load(p_gdiff)
b_ds = b_ds * gdiff_ * scale
b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
b_dg += tl.sum(b_ds2, axis=1)
b_dg -= tl.sum(b_ds2, axis=0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
last_index_local = min(BT, T - i_t * BT) - 1
if last_index_local >= 0:
is_last_mask = tl.arange(0, BT) == last_index_local
b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg)
else:
pass
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
elif USE_G_GAMMA:
b_dq = b_dq * exp(b_g)[:, None] * scale
b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
else:
b_ds = tl.where(m_A, b_ds, 0)
b_ds = b_ds.to(b_k.dtype)
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dv_local(
q,
k,
g,
g_gamma,
do,
dv,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_b = tl.program_id(0), tl.program_id(1)
T_max = T
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
for i_h in range(H):
offset_kh = (bos * H + i_h) * K
offset_vh = (bos * H + i_h) * V
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, b_q)
if USE_G:
if IS_VARLEN:
offset_g = i_h * T_max + bos
else:
offset_g = i_b * H * T_max + i_h * T_max
p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
else:
b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T,
H: tl.constexpr,
N: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_G_GAMMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
T_max = T
for i_v in range(tl.cdiv(V, BV)):
for i_n in range(N):
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
base_chunks_per_pid = NT // total_cores
remainder = NT % total_cores
if core_id < remainder:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * base_chunks_per_pid + remainder
# offset calculation
for i_h in range(0, H):
q_offset = (bos * Hg + i_h // (H // Hg)) * K
k_offset = (bos * Hg + i_h // (H // Hg)) * K
v_offset = (bos * H + i_h) * V
o_offset = (bos * H + i_h) * V
for i_t in range(start_idx, start_idx + chunks_this_pid):
i_tg = boh + i_t
h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, b_h)
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k)
if USE_G:
if IS_VARLEN:
p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
else:
p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
if USE_G_GAMMA:
b_gamma = tl.load(g_gamma + i_h)
b_g = b_gamma * (tl.arange(0, BT) + 1)
o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_bwd_dqkwg(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
do: torch.Tensor,
h: torch.Tensor,
dh: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
w: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128 if cu_seqlens is None else 64
BV = 64
NK = triton.cdiv(K, BK)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
g = g.transpose(1, 2).contiguous()
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
dw = torch.empty_like(w) if w is not None else None
grid = (NT, B)
if cu_seqlens is None:
if NT * BT == T:
g_ = g.reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
else:
diff = NT * BT - T
g_ = torch.cat((g, torch.zeros(B, H, diff).to(g.device)), dim=-1).reshape(B, H, NT, BT)
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
g_diff = g_diff.clamp(-60, 60).exp()
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
bias = torch.arange(0, BT).to(g.device)
o_t = (NT - 1) * BT + bias
m_t = o_t < T
m_A = m_t[:, None] & m_t
g_diff[:, :, -1] *= m_A
else:
g_diff = None
chunk_bwd_kernel_dqkwg[grid](
q=q,
k=k,
v=v,
h=h,
g=g,
g_gamma=g_gamma,
do=do,
dh=dh,
dv=dv,
w=w,
dw=dw,
dq=dq,
dk=dk,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
B=B,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
gdiff=g_diff,
)
if dg is not None:
dg = dg.sum(0)
dg = dg.transpose(1, 2).contiguous()
return dq, dk, dw, dg
def chunk_bwd_dv_local(
q: torch.Tensor,
k: torch.Tensor,
do: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: float = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, H, K, V = *k.shape, do.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
BK = 128
BV = 128
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous()
dv = torch.empty_like(do)
grid = (NT, B)
chunk_bwd_kernel_dv_local[grid](
q=q,
k=k,
g=g,
g_gamma=g_gamma,
do=do,
dv=dv,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)
return dv
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None,
g_gamma: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # noqa: F841
if scale is None:
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
if cu_seqlens is None:
N, chunk_offsets = B, None
else:
N, chunk_offsets = (
len(cu_seqlens) - 1,
prepare_chunk_offsets(cu_seqlens, BT),
)
def grid(meta):
return (triton.cdiv(V, meta["BV"]), N * H)
g = g.transpose(1, 2).contiguous()
h = h.contiguous()
CV_kernel_num = 24
chunk_fwd_kernel_o[(CV_kernel_num,)](
q,
k,
v,
h,
g,
g_gamma,
o,
cu_seqlens,
chunk_offsets,
scale,
T=T,
H=H,
N=N,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=128,
BV=128,
)
return o
bwd_chunk_dqkwg = chunk_bwd_dqkwg
bwd_chunk_dv_local = chunk_bwd_dv_local

View File

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
USE_G: tl.constexpr,
NT,
B,
TOTAL_TASKS,
):
core_id = tl.program_id(0)
num_blocks = tl.num_programs(0)
T_max = T
base_tasks_per_block = TOTAL_TASKS // num_blocks
remainder_tasks = TOTAL_TASKS % num_blocks
if core_id < remainder_tasks:
tasks_this_core = base_tasks_per_block + 1
start_idx = core_id * tasks_this_core
else:
tasks_this_core = base_tasks_per_block
start_idx = core_id * base_tasks_per_block + remainder_tasks
for idx in range(start_idx, start_idx + tasks_this_core):
i_b = idx // NT
local_idx = idx % NT
if IS_VARLEN:
i_n = tl.load(chunk_indices + local_idx * 2).to(tl.int32)
i_t = tl.load(chunk_indices + local_idx * 2 + 1).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_local = eos - bos
else:
bos, eos = 0, T
i_t = local_idx
T_local = T
for i_h in range(H):
k_batch_off = i_b * T_max * H * K
beta_batch_off = i_b * H * T_max
g_batch_off = i_b * H * T_max
A_batch_off = i_b * T_max * H * BT
p_beta = tl.make_block_ptr(
beta + beta_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + k_batch_off + (bos * H + i_h) * K,
(T_local, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
dot_product = tl.dot(b_k, tl.trans(b_k))
o_t = i_t * BT + tl.arange(0, BT)
o_t = o_t.to(tl.float32)
T_mask = (o_t < T_local).to(tl.float32)
row_indices = tl.arange(0, BT)[:, None]
col_indices = tl.arange(0, BT)[None, :]
tril_mask = (row_indices > col_indices).to(tl.float32)
tril_mask = tril_mask * T_mask[:, None]
masked_dot = dot_product * tril_mask
b_A += masked_dot
if USE_G:
p_g = tl.make_block_ptr(
g + g_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0)
b_A *= tl.exp(b_g_diff)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(
A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"])
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_i, i_j = i_c // NC, i_c % NC
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j)
if should_compute:
k_ptr = k + (bos * H + i_h) * K
g_ptr = g + (bos * H + i_h) * K
A_ptr = A + (bos * H + i_h) * BT
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
)
b_kt = tl.make_block_ptr(
k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
p_gk = tl.make_block_ptr(
g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
)
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
b_g = tl.load(p_g, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :])
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kt = tl.load(b_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk)
b_A += tl.dot(b_k, b_kt)
b_A *= b_beta[:, None]
p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
k,
g,
beta,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
for i_h in range(H):
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_val = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_val = T
should_compute = i_t * BT + i_i * BC < T_val
if should_compute:
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + o_i) < T_val
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_g = tl.make_block_ptr(
g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
)
p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None]
b_g = tl.load(p_g, boundary_check=(0, 1))
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)):
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1)
# 转化成f32
o_i_tmp = o_i.to(tl.float32)
b_A = tl.where(o_i_tmp > j, b_A, 0.0)
tl.store(A + o_A + j, b_A, mask=m_A)
p_kt += H * K
p_gk += H * K
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
beta: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
BK = 128
kernel_num = 24
if gk is None:
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)](
k=k,
g=g,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BK=BK,
NT=NT,
B=B,
TOTAL_TASKS=B * NT,
)
return A
BC = min(16, BT)
NC = triton.cdiv(BT, BC)
BK = max(triton.next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B)
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
k=k,
g=gk,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A

View File

@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import prepare_chunk_indices
@triton.heuristics(
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
scale,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BLOCK_T: tl.constexpr,
REVERSE: tl.constexpr,
HAS_SCALE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr,
CHUNK_SIZE: tl.constexpr = 64,
):
i_block, i_b = tl.program_id(0), tl.program_id(1)
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
if IS_VARLEN:
i_s, i_block = (
tl.load(chunk_indices + i_block * 2).to(tl.int32),
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
b_s = tl.trans(b_s, (1, 0, 2))
b_o = tl.cumsum(b_s, axis=0)
if REVERSE:
b_z = tl.sum(b_s, axis=0)
b_o = -b_o + b_z[None] + b_s
if HAS_SCALE:
b_o *= scale
b_o = tl.trans(b_o, (1, 0, 2))
b_o = tl.reshape(b_o, (BLOCK_T, H))
tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,))
return
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
B, T, H = g.shape
if chunk_size != 2 ** (chunk_size.bit_length() - 1):
raise ValueError(f"chunk_size must be a power of 2, chunk_size is{chunk_size}")
# We adjust the tiling strategy to prevent overflow in in backward passes and context parallel scenarios
# while maximizing UB utilization where possible.
# The tiling strategy is as follows:
# 1. BT must be greater than or equal to chunk_size.
# 2. UB estimation varies directly with H.
# 3. BT in reverse mode is smaller than in forward mode.
BT = max(chunk_size, triton.next_power_of_2((1 << 11 if reverse else 1 << 12) // H))
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B)
chunk_local_cumsum_scalar_kernel[grid](
s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
BLOCK_T=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
CHUNK_SIZE=chunk_size,
)
return g
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is not None:
if g.shape[0] != 1:
raise ValueError(
f"Only batch size 1 is supported when cu_seqlens are provided, current size is{g.shape[0]}"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype,
)
else:
raise ValueError(
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise"
)

View File

@@ -0,0 +1,272 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = pid_bh // H, pid_bh % H
base_t = pid_t * TPP
if IS_VARLEN:
i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T_eff = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
T_eff = T
o_i = tl.arange(0, 16) # noqa: F841
o_i_fp32 = tl.arange(0, 16).to(tl.float32)
m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
m_I = o_i_fp32[:, None] == o_i_fp32[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * BT
for tpp in tl.static_range(0, TPP):
tile_t = base_t + tpp
tile_row = tile_t * 16
offset = (tile_t * 16) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0))
b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
b_A_raw = desc.load([tile_row, offset]).to(tl.float32)
b_A_neg = -b_A_raw
b_A = b_A_neg * m_A
for i in range(2, min(16, T_eff - tile_row)):
slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
b_a_update = b_a_val + dot_prod
b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
TPP: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
b_Ai_11 += m_I
b_Ai_22 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_64x64_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 64)
m_I = o_i[:, None] == o_i[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * 64
offset = (i_t * 64) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0))
b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64])
desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64])
b_A = -desc.load([i_t * 64, offset]).to(tl.float32)
for i in range(2, min(64, T - i_t * 64)):
b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
b_A = tl.where((o_i == i)[:, None], b_a, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0))
tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
else:
desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@input_guard
def solve_tril(
A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
"""Compute the inverse of the matrix I + A
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Returns:
(I + A)^-1 with the same shape as A
""" # noqa: D205
if A.shape[-1] not in [16, 32, 64]:
raise ValueError(f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}")
output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype)
if BT == 16:
merge_fn = solve_tril_16x16_kernel
elif BT == 32:
merge_fn = merge_16x16_to_32x32_inverse_kernel
elif BT == 64:
merge_fn = solve_tril_64x64_kernel
merge_fn[NT, B * H](
A=A,
Ai=Ai,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
USE_TMA=False,
DOT_PRECISION=FLA_TRIL_PRECISION,
)
return Ai

View File

@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import itertools
import logging
import os
import warnings
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional
import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
import triton.runtime.driver as driver
from packaging import version
logger = logging.getLogger(__name__)
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any:
"""A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the outputs of the decorated function for the most recent
set of input tensors, up to `maxsize` entries. If the function is called again with
the same input tensors, it will return the cached result.
When maxsize=1 (default), the behavior is identical to caching only the most recent result.
Can be used as @tensor_cache or @tensor_cache(maxsize=n).
Args:
fn (Callable[..., torch.Tensor], optional):
The function to be decorated when used without parentheses.
maxsize (int):
Maximum number of input combinations to cache. Default is 1.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with caching.
"""
if maxsize < 1:
raise ValueError("maxsize must be at least 1")
def _is_match(a: Any, b: Any) -> bool:
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return a is b
try:
return a == b
except Exception:
return a is b
def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
cache: list = []
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache):
if len(args) == len(cached_args) and len(kwargs) == len(cached_kwargs):
if all(_is_match(a, b) for a, b in zip(args, cached_args)) and all(
k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items()
):
if i != 0:
cache.insert(0, cache.pop(i))
return cached_result
result = fn(*args, **kwargs)
cache.insert(0, (args, kwargs, result))
if len(cache) > maxsize:
cache.pop()
return result
return wrapper
if fn is not None:
return _make_wrapper(fn)
return _make_wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache(maxsize=3)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
def get_abs_err(x, y):
return (x.detach() - y.detach()).flatten().abs().max().item()
def get_err_ratio(x, y):
err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
base = (x.detach()).flatten().square().mean().sqrt().item()
return err / (base + 1e-8)
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
abs_atol = get_abs_err(ref, tri)
msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
logger.info(msg)
error_rate = get_err_ratio(ref, tri)
if abs_atol <= err_atol:
return
if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
if error_rate > ratio:
warnings.warn(msg)
else:
assert error_rate < ratio, msg
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
# For Triton 3.3.x
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
# For Triton 3.4.x and later
make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
Just make triton compiler happy.
"""
@triton.jit
def make_tensor_descriptor(
base,
shape,
strides,
block_shape,
_builder=None,
):
return None
@functools.cache
def get_available_device() -> str:
try:
return triton.runtime.driver.active.get_current_target().backend
except BaseException:
_cpu_device_warning()
return "cpu"
def map_triton_backend_to_torch_device() -> str:
backend = get_available_device() # 'cuda' | 'hip' | 'xpu' | 'cpu' | ...
return {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(backend, backend)
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = get_available_device()
is_amd = device_platform == "hip"
is_nvidia = device_platform == "cuda"
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
)
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_tma_supported = (
(is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9)
and os.environ.get("FLA_NO_USE_TMA", "0") != "1"
and (
hasattr(triton.language, "_experimental_make_tensor_descriptor")
or hasattr(triton.language, "make_tensor_descriptor")
)
)
if is_nvidia and not is_tf32_supported:
# Make old card happy, since triton will use tf32 by default.
# This is a workaround for old nvidia card.
os.environ["TRITON_F32_DEFAULT"] = "ieee"
@functools.cache
def check_pytorch_version(version_s: str = "2.4") -> bool:
return version.parse(torch.__version__) >= version.parse(version_s)
if check_pytorch_version("2.4"):
device = "cuda" if device == "cpu" else device
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
def custom_device_ctx(index: int):
return device_torch_lib.device(index)
else:
assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0."
autocast_custom_fwd = device_torch_lib.amp.custom_fwd
autocast_custom_bwd = device_torch_lib.amp.custom_bwd
def custom_device_ctx(index: int):
return torch.cuda.device(index)
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""A decorator to make sure all input tensors are contiguous and set the device based on input tensors."""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = custom_device_ctx(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
def _cpu_device_warning():
warnings.warn(("Triton is not supported on current platform, roll back to CPU."), stacklevel=1)
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
exp = tldevice.fast_expf
exp2 = tldevice.exp2
log = tldevice.fast_logf
log2 = tldevice.fast_log2f
else:
exp = tl.exp
exp2 = tl.math.exp2
log = tl.log
log2 = tl.log2
def get_all_max_shared_mem():
try:
return [
triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
for i in range(device_torch_lib.device_count())
]
except BaseException:
_cpu_device_warning()
return [-1]
class Backend(Enum):
ADA = 101376 # RTX 4090
AMPERE = 166912 # A100
HOPPER = 232448 # H100
DEFAULT = 102400 # Default
@classmethod
def get_shared_memory(cls, arch: str) -> int:
try:
return cls[arch.upper()].value
except KeyError:
return cls.DEFAULT.value
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
try:
device_shared_mem_list = get_all_max_shared_mem()
max_shared_memory = device_shared_mem_list[tensor_idx]
return max_shared_memory >= Backend.get_shared_memory(arch)
except Exception:
return False
def get_autotune_config(
multibuffer_list: tuple = (False,),
unit_flag_list: tuple = (False,),
limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,),
limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",),
set_workspace_multibuffer_list: tuple = (2, 4),
enable_hivm_auto_cv_balance_list: tuple = (True,),
tile_mix_vector_loop_num_list: tuple = (2, 4),
tile_mix_cube_loop_num_list: tuple = (2, 4),
):
configs = []
for (
multibuffer,
unit_flag,
limit_auto_multi_buffer_only_for_local_buffer,
limit_auto_multi_buffer_of_local_buffer,
) in itertools.product(
list(multibuffer_list),
list(unit_flag_list),
list(limit_auto_multi_buffer_only_for_local_buffer_list),
list(limit_auto_multi_buffer_of_local_buffer_list),
):
base_config_dict = {
"multibuffer": multibuffer,
"unit_flag": unit_flag,
"limit_auto_multi_buffer_only_for_local_buffer": limit_auto_multi_buffer_only_for_local_buffer,
"limit_auto_multi_buffer_of_local_buffer": limit_auto_multi_buffer_of_local_buffer,
}
if limit_auto_multi_buffer_only_for_local_buffer:
configs.append(triton.Config(base_config_dict))
else:
for (
set_workspace_multibuffer,
enable_hivm_auto_cv_balance,
tile_mix_vector_loop,
tile_mix_cube_loop,
) in itertools.product(
list(set_workspace_multibuffer_list),
list(enable_hivm_auto_cv_balance_list),
list(tile_mix_vector_loop_num_list),
list(tile_mix_cube_loop_num_list),
):
full_config_dict = base_config_dict.copy()
full_config_dict.update(
{
"set_workspace_multibuffer": set_workspace_multibuffer,
"enable_hivm_auto_cv_balance": enable_hivm_auto_cv_balance,
"tile_mix_vector_loop": tile_mix_vector_loop,
"tile_mix_cube_loop": tile_mix_cube_loop,
}
)
configs.append(triton.Config(full_config_dict))
return configs
def get_npu_properties():
return driver.active.utils.get_device_properties(torch.npu.current_device())

View File

@@ -0,0 +1,387 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import exp, prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
k,
v,
beta,
g,
A,
dw,
du,
dk,
dv,
dbeta,
dg,
cu_seqlens,
chunk_indices,
T,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
else:
i_t = idx
bos, eos = i_b * T, i_b * T + T
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
for i_h in range(0, H):
if IS_VARLEN:
offset = bos + i_h * T_max
else:
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
b_beta = tl.load(p_beta, boundary_check=(0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_exp = tl.exp(b_g)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dg = tl.zeros([BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dw = tl.make_block_ptr(
dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
b_dk_beta_g = tl.dot(b_A, b_dw)
b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None]
b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1)
b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_dv = tl.make_block_ptr(
dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_du = tl.make_block_ptr(
du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta))
b_dv_beta = tl.dot(b_A, b_du)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
b_dA = b_dA.to(k.dtype.element_ty)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_dk = tl.make_block_ptr(
dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_k_beta, tl.trans(b_k))
b_dk_beta = tl.dot(b_dA, b_k)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_dA_A = b_dA * b_A
b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0)
p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
gk,
cu_seqlens,
chunk_indices,
T_tmp,
B,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
NT: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
core_id = tl.program_id(0)
total_cores = tl.num_programs(0)
T_max = T_tmp
base_chunks_per_pid = NT // total_cores
remainder_chunks = NT % total_cores
if core_id < remainder_chunks:
chunks_this_pid = base_chunks_per_pid + 1
start_idx = core_id * chunks_this_pid
else:
chunks_this_pid = base_chunks_per_pid
start_idx = core_id * chunks_this_pid + remainder_chunks
for idx in range(start_idx, start_idx + chunks_this_pid):
for i_b in range(B):
for i_h in range(0, H):
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + idx * 2).to(tl.int32),
tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
offset = bos + i_h * T_max
T = eos - bos
else:
T = T_tmp
i_t = idx
bos, eos = i_b * T, i_b * T + T
offset = bos * H + i_h * T_max
p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_A = tl.load(p_A, boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
if USE_G:
p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
if USE_G:
b_kb *= b_g[:, None]
if USE_GK:
p_gk = tl.make_block_ptr(
gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1)))
b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 128
BV = 128
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g = g.transpose(1, 2).contiguous() if g is not None else None
beta = beta.transpose(1, 2).contiguous()
w = torch.empty_like(k)
u = torch.empty_like(v)
cv_kernel_num = 24
recompute_w_u_fwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T_tmp=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
return w, u
def prepare_wy_repr_bwd(
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
dw: torch.Tensor,
du: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
chunk_size: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 128
BV = 128
beta = beta.transpose(1, 2).contiguous()
g = g.transpose(1, 2).contiguous()
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.empty_like(beta)
dg = torch.empty_like(g)
cv_kernel_num = 24
prepare_wy_repr_bwd_kernel[(cv_kernel_num,)](
k=k,
v=v,
beta=beta,
g=g,
A=A,
dw=dw,
du=du,
dk=dk,
dv=dv,
dbeta=dbeta,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
K=K,
V=V,
NT=NT,
BT=BT,
BK=BK,
BV=BV,
)
dbeta = dbeta.transpose(1, 2).contiguous()
dg = dg.transpose(1, 2).contiguous()
return dk, dv, dbeta, dg
bwd_prepare_wy_repr = prepare_wy_repr_bwd
fwd_recompute_w_u = recompute_w_u_fwd

View File

@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import fnmatch
import json
import os
import signal
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
@@ -31,7 +33,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
from ..extras.packages import is_safetensors_available
@@ -338,6 +340,96 @@ class LogCallback(TrainerCallback):
self.thread_pool.submit(self._write_log, args.output_dir, logs)
class TorchProfilerCallback(TrainerCallback):
r"""A callback for collecting torch.profiler traces during training.
Activated by setting ``enable_torch_profiler: true`` in the YAML config.
Configuration fields (in YAML):
profiler_output_dir where to write traces (default: <output_dir>/profiler)
profiler_wait_steps steps to skip at start of each cycle (default: 1)
profiler_warmup_steps profiler warm-up steps per cycle (default: 1)
profiler_active_steps steps to record per cycle (default: 1)
profiler_repeat number of cycles; 0 = forever (default: 1)
profiler_record_shapes record tensor shapes (default: true)
profiler_profile_memory profile memory usage (default: true)
profiler_with_stack record stack traces (default: true)
Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
``<profiler_output_dir>/rank_<N>/``.
"""
def __init__(self, training_args: "TrainingArguments") -> None:
self.profiler = None
self.profiler_args = training_args
@staticmethod
def _get_rank() -> int:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
return 0
@override
def on_train_begin(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.stop()
self.profiler = None
pa = self.profiler_args
output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
rank = self._get_rank()
trace_dir = os.path.join(output_dir, f"rank_{rank}")
os.makedirs(trace_dir, exist_ok=True)
activities = [torch.profiler.ProfilerActivity.CPU]
try:
if is_torch_cuda_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
if is_torch_npu_available():
activities.append(torch.profiler.ProfilerActivity.NPU)
except Exception:
pass
self.profiler = torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
wait=pa.profiler_wait_steps,
warmup=pa.profiler_warmup_steps,
active=pa.profiler_active_steps,
repeat=pa.profiler_repeat,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
record_shapes=pa.profiler_record_shapes,
profile_memory=pa.profiler_profile_memory,
with_stack=pa.profiler_with_stack,
)
self.profiler.start()
logger.info_rank0(
f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
)
@override
def on_step_end(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.step()
@override
def on_train_end(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.stop()
self.profiler = None
logger.info_rank0("TorchProfiler stopped.")
class ReporterCallback(TrainerCallback):
r"""A callback for reporting training status to external logger."""
@@ -394,3 +486,143 @@ class ReporterCallback(TrainerCallback):
"generating_args": self.generating_args.to_dict(),
}
)
class ModuleProfilerCallback(TrainerCallback):
r"""Profile forward/backward time of specified modules using accelerator events.
Hooks are registered on modules matching the user-provided name patterns.
Timing statistics are logged at each trainer logging step.
Usage in YAML config:
profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
Supports fnmatch wildcards:
profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
"""
@staticmethod
def _get_accelerator():
"""Detect available accelerator and return (event_factory, synchronize_fn)."""
if is_torch_cuda_available():
return torch.cuda.Event, torch.cuda.synchronize
if is_torch_npu_available():
return torch.npu.Event, torch.npu.synchronize
return None, None
def __init__(self, profile_modules: str) -> None:
self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
self._create_event, self._synchronize = self._get_accelerator()
self._handles: list[Any] = []
self._forward_times: dict[str, list[float]] = defaultdict(list)
self._backward_times: dict[str, list[float]] = defaultdict(list)
self._pending_forward: dict[str, tuple] = {}
self._pending_backward: dict[str, tuple] = {}
@property
def enabled(self) -> bool:
return self._create_event is not None
def _match(self, name: str) -> bool:
return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
def _make_forward_pre_hook(self, name: str):
def hook(module, input):
start = self._create_event(enable_timing=True)
end = self._create_event(enable_timing=True)
start.record()
self._pending_forward[name] = (start, end)
return hook
def _make_forward_hook(self, name: str):
def hook(module, input, output):
pair = self._pending_forward.get(name)
if pair is not None:
pair[1].record()
return hook
def _make_backward_pre_hook(self, name: str):
def hook(module, grad_output):
start = self._create_event(enable_timing=True)
end = self._create_event(enable_timing=True)
start.record()
self._pending_backward[name] = (start, end)
return hook
def _make_backward_hook(self, name: str):
def hook(module, grad_input, grad_output):
pair = self._pending_backward.get(name)
if pair is not None:
pair[1].record()
return hook
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self.enabled:
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
return
model = kwargs.get("model")
if model is None:
return
matched = []
for name, module in model.named_modules():
if not name or not self._match(name):
continue
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
matched.append(name)
if matched:
logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
)
else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self.enabled:
return
self._synchronize()
for name, (start, end) in self._pending_forward.items():
self._forward_times[name].append(start.elapsed_time(end))
self._pending_forward.clear()
for name, (start, end) in self._pending_backward.items():
self._backward_times[name].append(start.elapsed_time(end))
self._pending_backward.clear()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self._forward_times and not self._backward_times:
return
lines = ["[ModuleProfiler] Timing (ms):"]
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
for name in all_names:
fwd = self._forward_times.get(name, [])
bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines))
self._forward_times.clear()
self._backward_times.clear()
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
for handle in self._handles:
handle.remove()
self._handles.clear()

View File

@@ -1,62 +0,0 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import torch
from ktransformers.sft.lora import KTrainer # type: ignore
from typing_extensions import override
from ..trainer_utils import get_batch_logps, nested_detach
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel
class KDPOTrainer(KTrainer, CustomDPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits = all_logits.to("cpu")
labels = labels.to(all_logits.device)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

View File

@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
self.running = RunningMoments(self.accelerator)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -62,15 +62,7 @@ def run_dpo(
else:
ref_model = None
if model_args.use_kt:
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
from .ktrainer import KDPOTrainer as CustomDPOTrainer
GLOBAL_CONFIG._config["mod"] = "sft"
else:
from .trainer import CustomDPOTrainer
from .trainer import CustomDPOTrainer
# Initialize our Trainer
trainer = CustomDPOTrainer(

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft
from .workflow import run_pt, run_sft
__all__ = ["run_sft"]
__all__ = ["run_pt", "run_sft"]

View File

@@ -0,0 +1,222 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperParallel distributed trainer for LlamaFactory."""
import logging
import os
import types
from contextlib import nullcontext
from typing import Any, Optional
import torch
from hyper_parallel.integration.llamafactory import (
HSDPModule,
HyperParallelArguments,
export_to_hf_format,
fsdp2_prepare_model,
hsdp_sync_stream,
load_hsdp_model,
load_hsdp_optimizer_and_scheduler,
save_hsdp_checkpoint,
wrap_optimizer_with_skip_dtensor_dispatch,
)
from hyper_parallel.integration.llamafactory import (
clip_grad_norm_ as hp_clip_grad_norm_,
)
from torch import nn
from ..sft.trainer import CustomSeq2SeqTrainer
logger = logging.getLogger(__name__)
class HyperParallelTrainer(CustomSeq2SeqTrainer):
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
prediction, sampler, etc.) and only overrides HSDP-specific behavior.
"""
def __init__(
self,
hp_args: HyperParallelArguments,
finetuning_args=None,
processor=None,
ref_model: Optional[nn.Module] = None,
**kwargs,
):
self._hp_args = hp_args
# Let CustomSeq2SeqTrainer handle everything except ref_model —
# Custom would prepare it with accelerate's fsdp2_prepare_model,
# but we need HP's version instead.
super().__init__(
finetuning_args=finetuning_args,
processor=processor,
ref_model=None,
**kwargs,
)
if not getattr(self.accelerator, "is_fsdp2", False):
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
# Prepare ref_model with HP's fsdp2_prepare_model
self.ref_model = ref_model
if self.ref_model is not None:
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
self._orig_fsdp2_prepare_model = None
self._accelerator_patches_active = False
def _activate_accelerator_patches(self) -> None:
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
if self._accelerator_patches_active:
return
import accelerate.accelerator as acc_module # pylint: disable=C0415
hp_args = self._hp_args
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
def _hp_fsdp2_prepare_model(accelerator, model):
return fsdp2_prepare_model(accelerator, model, hp_args)
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
if getattr(accelerator, "is_fsdp2", False):
accelerator.unscale_gradients()
parameter_list = list(parameters)
parameter_ids = {id(param) for param in parameter_list}
for model in accelerator._models: # pylint: disable=protected-access
if not isinstance(model, HSDPModule):
continue
model_param_ids = {id(param) for param in model.parameters()}
if parameter_ids and parameter_ids.issubset(model_param_ids):
return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)
return self._orig_accelerator_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
self._accelerator_patches_active = True
def _restore_accelerator_patches(self) -> None:
"""Restore original Accelerate methods."""
if not self._accelerator_patches_active:
return
import accelerate.accelerator as acc_module # pylint: disable=C0415
if self._orig_fsdp2_prepare_model is not None:
acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
self._accelerator_patches_active = False
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
"""Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct."""
del dataloader
if isinstance(model, HSDPModule):
return model
if training and getattr(self.accelerator, "is_fsdp2", False):
return model
return super()._wrap_model(model, training=training)
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
"""Skip redundant device moves for HSDP-wrapped models."""
if isinstance(model, HSDPModule):
return model
if device is None:
return model
return model.to(device)
def train(self, *args, **kwargs):
"""Activate HP patches during training and restore afterwards."""
self._activate_accelerator_patches()
try:
return super().train(*args, **kwargs)
finally:
self._restore_accelerator_patches()
def training_step(
self,
model: nn.Module,
inputs: dict[str, Any],
num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
"""Standard training step with HSDP gradient synchronization."""
model.train()
inputs = self._prepare_inputs(inputs)
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
if isinstance(model, HSDPModule):
model.set_is_last_backward(sync_gradients)
model.set_requires_gradient_sync(sync_gradients)
compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
with compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
if self.args.n_gpu > 1:
loss = loss.mean()
if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
loss = loss / self.args.gradient_accumulation_steps
self.accelerator.backward(loss)
if isinstance(model, HSDPModule) and sync_gradients:
hsdp_sync_stream()
return loss.detach()
def create_optimizer(self):
"""Create optimizer and wrap step with SkipDTensorDispatch."""
optimizer = super().create_optimizer()
wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
return optimizer
def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
"""Save model/optimizer shards per-rank and scheduler."""
save_hsdp_checkpoint(
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
output_dir=output_dir,
should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
)
def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
"""Load model from HSDP sharded checkpoint."""
target = model if model is not None else self.model
loaded = load_hsdp_model(target, resume_from_checkpoint)
if not loaded:
return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
self._pending_hsdp_checkpoint = resume_from_checkpoint
return None
def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
"""Load optimizer/scheduler from per-rank checkpoint files."""
ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
if ckpt_dir is None:
return
load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
"""Save model weights in HuggingFace-compatible format."""
save_dir = output_dir or self.args.output_dir
os.makedirs(save_dir, exist_ok=True)
export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)

View File

@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import HyperParallelTrainer
if TYPE_CHECKING:
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
r"""Load HyperParallel arguments and apply LlamaFactory-side overrides.
When activation optimization is enabled, skip native gradient checkpointing
so HP can install its own via ``setup_activation_optimization``.
"""
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
if hp_args.activation_mode != "none":
model_args.disable_gradient_checkpointing = True
return hp_args
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**tokenizer_module,
)
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if isinstance(dataset_module.get("eval_dataset"), dict):
for key in dataset_module["eval_dataset"].keys():
try:
perplexity = math.exp(metrics[f"eval_{key}_loss"])
except OverflowError:
perplexity = float("inf")
metrics[f"eval_{key}_perplexity"] = perplexity
else:
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["eval_perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -43,15 +130,7 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
if not is_hyper_parallel_available():
raise ImportError(
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
)
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
HyperParallelArguments,
HyperParallelTrainer,
)
hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@@ -96,25 +175,6 @@ def run_sft(
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
callbacks = list(callbacks or [])
processor = tokenizer_module.get("processor")
if processor is not None:
callbacks.append(SaveProcessorCallback(processor))
compute_loss_func = None
if finetuning_args.use_dft_loss:
compute_loss_func = dft_loss_func
elif finetuning_args.use_eaft_loss:
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
elif finetuning_args.use_asft_loss:
from functools import partial
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
@@ -124,20 +184,11 @@ def run_sft(
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
compute_loss_func=compute_loss_func,
**dataset_module,
**tokenizer_module,
**metric_module,
)
if finetuning_args.use_badam:
from types import MethodType
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
trainer.add_callback(BAdamCallback)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

View File

@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -92,7 +92,8 @@ def _data_collator_wrapper(data_collator: Any):
def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
model_type = mca_config.get("hf_model_type", None)
else:
@@ -110,7 +111,14 @@ def _check_model_support(model_args: "ModelArguments"):
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
if getattr(model.config, "hf_model_type", None) not in [
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
]:
return
params_to_freeze = []

View File

@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -103,37 +103,18 @@ def run_sft(
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
# Initialize our Trainer
if model_args.use_kt:
from ktransformers.sft.lora import KTrainer # type: ignore
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
GLOBAL_CONFIG._config["mod"] = "sft"
trainer = KTrainer(
model=model,
args=training_args,
tokenizer=tokenizer_module,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**metric_module,
)
trainer.model_accepts_loss_kwargs = False
model.config.use_cache = False
else:
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,
)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,
)
# Training
if training_args.do_train:

View File

@@ -103,7 +103,7 @@ def create_modelcard_and_push(
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
if model_args.use_kt:
kwargs["tags"] = kwargs["tags"] + ["ktransformers"]
kwargs["tags"] = kwargs["tags"] + ["kt-kernel"]
if not training_args.do_train:
pass

View File

@@ -32,7 +32,13 @@ from ..extras.packages import (
)
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .callbacks import (
LogCallback,
ModuleProfilerCallback,
PissaConvertCallback,
ReporterCallback,
TorchProfilerCallback,
)
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -74,16 +80,27 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
if getattr(training_args, "enable_torch_profiler", False):
callbacks.append(TorchProfilerCallback(training_args))
if getattr(training_args, "profile_modules", None):
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
if not is_hyper_parallel_available():
raise ImportError(
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
)
from .hyper_parallel import run_sft as run_sft_hp
if finetuning_args.stage == "pt":
from .hyper_parallel import run_pt as run_pt_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
else:
from .hyper_parallel import run_sft as run_sft_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():
@@ -182,7 +199,15 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
if not is_transformers_version_greater_than("5.0.0"):
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
model.save_pretrained(**save_kwargs)
try:
model.save_pretrained(**save_kwargs)
except NotImplementedError as err:
raise RuntimeError(
"Failed to export model: weight conversion reversal is not supported for this model architecture "
"(NotImplementedError in transformers.core_model_loading.reverse_op). "
"This is a known issue with transformers>=5.0 for certain model types (e.g. Mistral/Ministral). "
"Workarounds: (1) use transformers<5.0, or (2) report the issue to the transformers repository."
) from err
if model_args.export_hub_model_id is not None:
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [
"AttentionFunction",
"BatchingStrategy",
"DataArguments",
"InputArgument",

View File

@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args:
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
return tuple(parsed_args)
return model_args, data_args, training_args, sample_args
if __name__ == "__main__":

View File

@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.SDPA,
metadata={
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
)
def __post_init__(self) -> None:
supported_flash_attn = [item.value for item in AttentionFunction]
if self.flash_attn not in supported_flash_attn:
raise ValueError(
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
)
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
@@ -116,3 +120,9 @@ class TrainingArguments:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
if self.max_steps is None or self.max_steps <= 0:
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
if self.save_epochs is not None:
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")

View File

@@ -34,7 +34,7 @@ import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..config import BatchingStrategy, TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
@@ -134,6 +134,9 @@ class BaseTrainer:
global_step=self.global_step,
epoch=self._resume_epoch,
)
# Keep callback state aligned with checkpoint-resumed trainer counters.
self.state.global_step = self.global_step
self.state.epoch = self._resume_epoch
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
@@ -144,13 +147,19 @@ class BaseTrainer:
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
raise ValueError(
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
)
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None:
if (
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
):
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
@@ -234,6 +243,7 @@ class BaseTrainer:
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
for micro_batches in self.train_batch_generator:
self.global_step += 1
@@ -269,26 +279,13 @@ class BaseTrainer:
# deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm()
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
grad_norm = grad_norm**2
grad_norm = DistributedInterface().all_reduce(grad_norm, op=ReduceOp.SUM, dim=Dim.CP)
grad_norm = grad_norm**0.5
parameters = self.model.parameters()
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
parameters = list(parameters)
grads = [p.grad for p in parameters if p.grad is not None]
grad_norm = _get_total_norm(grads)
grad_norm = grad_norm.to(self.device)
_clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
grad_norm = grad_norm.full_tensor().item()
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.args.max_grad_norm
).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
@@ -316,7 +313,7 @@ class BaseTrainer:
if self.global_step % self.args.logging_steps == 0:
logs = {
"epoch": epoch,
"step": self.global_step,
"step": self.state.global_step,
"loss": step_loss,
"grad_norm": grad_norm,
"learning_rate": current_lr,
@@ -348,7 +345,9 @@ class BaseTrainer:
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
model_to_save.save_pretrained(
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
)
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -120,6 +120,7 @@ class ModelEngine:
init_device = DistributedInterface().current_device
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
@@ -143,6 +144,12 @@ class ModelEngine:
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
self.model_config.num_labels = 1
self.model_config.classifier_dropout = 0.0
text_config = getattr(self.model_config, "text_config", None)
if text_config is not None:
text_config.num_labels = 1
text_config.classifier_dropout = 0.0
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
@@ -158,6 +165,7 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
attn_implementation=self.args.flash_attn,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
@@ -182,9 +190,12 @@ class ModelEngine:
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model, include_kernels=self.args.kernel_config.get("include_kernels")
)
kernel_config = self.args.kernel_config
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
if kernel_config.name == "liger_kernel":
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
kernel_kwargs["require_logits"] = self.is_train
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
return model

Some files were not shown because too many files have changed in this diff Show More