mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 12:48:55 +08:00
Compare commits
19 Commits
2322bf1cc2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| | 8669a22e9c | | |
| | 897a44386c | | |
| | 7a1e9630f2 | | |
| | cabe59a343 | | |
| | 9ca4026efe | | |
| | 0b7aaf8f6a | | |
| | 8a4f6a3da5 | | |
| | 409e8a477f | | |
| | 053d43c0ac | | |
| | a98a1ef101 | | |
| | 8ef7335b6a | | |
| | 7af909522a | | |
| | e016d2480e | | |
| | 7d719182c9 | | |
| | 01398eb18d | | |
| | 8e68764b65 | | |
| | 16ff5a23cb | | |
| | bdcb92d035 | | |
| | 7e20db5735 | | |
21
README.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
@@ -52,14 +50,12 @@ Start local training:
|
||||
Start cloud training:
|
||||
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
Read technical notes:
|
||||
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
|
||||
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
|
||||
- **Documentation (ASCEND NPU)**: https://llamafactory.readthedocs.io/en/latest/multibackend/npu/index.html
|
||||
- **Official Blog**: https://blog.llamafactory.net/en/
|
||||
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
||||
@@ -78,7 +74,6 @@ Read technical notes:
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Quickstart](#quickstart)
|
||||
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
||||
- [LLaMA Factory Online](#llama-factory-online)
|
||||
- [Build Docker](#build-docker)
|
||||
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
|
||||
- [Download from ModelScope Hub](#download-from-modelscope-hub)
|
||||
@@ -117,15 +112,13 @@ Read technical notes:
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
|
||||
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- 💡 [DataFlow × LLaMA Factory: Producing High-Quality Data for LLM Training with a Data Preparation Pipeline](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd)
|
||||
- 💡 [DataFlex × LLaMA Factory: A Data-Centric Dynamic Training System Built on LLaMA-Factory](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv) (English) | [中文](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e)
|
||||
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
|
||||
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
|
||||
|
||||
<details><summary>All Blogs</summary>
|
||||
|
||||
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
|
||||
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
|
||||
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
|
||||
@@ -661,10 +654,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online
|
||||
|
||||
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
|
||||
|
||||
### Build Docker
|
||||
|
||||
For CUDA users:
|
||||
@@ -838,7 +827,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||
|
||||
22
README_zh.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 Join our [WeChat group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU user group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), or [LLaMA Factory Online user group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png).
👋 Join our [WeChat group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU user group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg).
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
@@ -52,16 +50,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online (online fine-tuning)**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory

Read technical documentation:
- **Getting-started tutorial**: https://zhuanlan.zhihu.com/p/695287607
- **Fine-tuning video tutorial**: https://www.bilibili.com/video/BV1djgRzxEts/
- **Framework documentation**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Framework documentation (Ascend NPU)**: https://ascend.github.io/docs/sources/llamafactory/
- **Framework documentation (Ascend NPU)**: https://llamafactory.readthedocs.io/zh-cn/latest/multibackend/npu/index.html
- **Official blog**: https://blog.llamafactory.net/
- **Official course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory

> [!NOTE]
> All websites other than the links above are unauthorized third-party websites. Please vet them carefully.
|
||||
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- [Data Preparation](#数据准备)
- [Quickstart](#快速开始)
- [Fine-Tuning with LLaMA Board GUI](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online Fine-Tuning](#llama-factory-online-在线微调)
- [Build Docker](#构建-docker)
- [Deploy OpenAI API with vLLM](#利用-vllm-部署-openai-api)
- [Download from ModelScope Hub](#从魔搭社区下载)
|
||||
@@ -119,15 +113,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000B-scale models with two 4090-class GPUs + CPU](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (Chinese)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9) (Chinese)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Build a GPT-OSS role-playing model using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- 💡 [DataFlow × LLaMA Factory: Producing High-Quality Data for LLM Training with a Data Preparation Pipeline](https://wcny4qa9krto.feishu.cn/wiki/LlMxweUAJimrmykRD5qcGuswnHd) (Chinese) | [English](https://wcny4qa9krto.feishu.cn/wiki/LWkkwTDBfiiRKqkDSvucG6yjnbW)
- 💡 [DataFlex × LLaMA Factory: A Data-Centric Dynamic Training System Built on LLaMA-Factory](https://wcny4qa9krto.feishu.cn/wiki/H2A9wSsbCinzavkT2oyc2C5Vn0e) (Chinese) | [English](https://wcny4qa9krto.feishu.cn/wiki/OlREwPQWdi9K6ZkJNHIciLhtnkv)
- [Building LLM Model Hub, a One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [Enhancing visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)

<details><summary>All Blogs</summary>

- [Fine-tune a Llama3.1-70B medical diagnosis model using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for autonomous driving scenarios using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning DeepSeek-R1-Distill-Qwen-7B into a news headline classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [Building Model Hub, a one-stop code-free model fine-tuning and deployment platform based on Amazon SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory multimodal fine-tuning practice: fine-tuning Qwen2-VL to build a cultural tourism model](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
|
||||
@@ -662,10 +654,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online Fine-Tuning

See this [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory) for details.

### Build Docker

For CUDA users:
|
||||
@@ -842,7 +830,7 @@ swanlab_run_name: test_run # optional
|
||||
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||
|
||||
@@ -36,6 +36,7 @@ COPY . /app
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
RUN pip uninstall -y torch torchvision torchaudio
|
||||
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||
RUN pip install --no-cache-dir -r requirements/triton_ascend.txt
|
||||
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
20
examples/accelerate/fsdp2_config_qwen35_moe.yaml
Normal file
@@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer,Qwen3_5MoeVisionBlock
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8 # Change to match your NPU count (e.g., 8 for A2, 16 for A3)
rdzv_backend: static
same_network: true
use_cpu: false
51
examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
Normal file
@@ -0,0 +1,51 @@
|
||||
# Start FSDP2 LoRA fine-tuning on Ascend NPU
|
||||
# Usage:
|
||||
# accelerate launch \
|
||||
# --config_file examples/accelerate/fsdp2_config_qwen35_moe.yaml \
|
||||
# src/train.py examples/ascend/qwen3_5moe_lora_sft_fsdp2.yaml
|
||||
#
|
||||
# Note: Change `num_processes` in fsdp2_config_qwen35_moe.yaml to match your NPU count
|
||||
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen3.5-35B-A3B
|
||||
trust_remote_code: true
|
||||
use_v1_kernels: false
|
||||
flash_attn: fa2
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_rank: 8
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: alpaca_en_demo
|
||||
template: qwen3_5_nothink
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
packing: false
|
||||
|
||||
### output
|
||||
output_dir: saves/Qwen3.5-35B/lora/sft
|
||||
logging_steps: 1
|
||||
save_steps: 2000
|
||||
max_steps: 2000
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 1.0e-5
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 1800
|
||||
resume_from_checkpoint: null
|
||||
disable_gradient_checkpointing: true
|
||||
@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink


kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal

cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml


### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
28
examples/v1/train_full/train_full_liger_kernel.yaml
Normal file
@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: liger_kernel
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
2
requirements/triton_ascend.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
--extra-index-url https://triton-ascend.osinfra.cn/pypi/simple
|
||||
triton-ascend==3.2.1
|
||||
@@ -886,6 +886,9 @@ register_model_group(
|
||||
"Gemma-4-E4B-Thinking": {
|
||||
DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
|
||||
},
|
||||
"Gemma-4-12B-Thinking": {
|
||||
DownloadSource.DEFAULT: "google/gemma-4-12B-it",
|
||||
},
|
||||
},
|
||||
template="gemma4n",
|
||||
multimodal=True,
|
||||
@@ -1912,6 +1915,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM5-1B-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM5-1B",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM5-1B",
|
||||
},
|
||||
},
|
||||
template="empty",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-o-2.6": {
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
VERSION = "0.9.5.dev0"
|
||||
VERSION = "0.9.6.dev0"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -487,7 +487,7 @@ class FinetuningArguments(
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
|
||||
"Only supported for the 'sft' stage with full fine-tuning."
|
||||
"Only supported for the 'pt' and 'sft' stages with full fine-tuning."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -194,10 +194,16 @@ def _setup_lora_tuning(
|
||||
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
if model_args.use_unsloth:
|
||||
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||
if isinstance(model, PeftModel):
|
||||
pass # already loaded via load_unsloth_peft_model in loader.py
|
||||
else:
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||
if model_args.use_unsloth:
|
||||
peft_model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||
if peft_model is not None:
|
||||
model = peft_model
|
||||
|
||||
if not model_args.use_unsloth: # unsloth was disabled or fell back
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||
|
||||
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
from .model_utils.misc import register_autoclass
|
||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model, load_unsloth_peft_model
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||
|
||||
@@ -142,14 +142,13 @@ def load_model(
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
model = None
|
||||
lazy_load = False
|
||||
if model_args.use_unsloth:
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
lazy_load = True
|
||||
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
|
||||
elif is_trainable:
|
||||
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
|
||||
|
||||
if model is None and not lazy_load:
|
||||
if model is None:
|
||||
init_kwargs["config"] = config
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
init_kwargs["torch_dtype"] = "auto"
|
||||
@@ -176,9 +175,8 @@ def load_model(
|
||||
if model_args.mixture_of_depths == "convert":
|
||||
model = convert_pretrained_model_to_mod(model, config, model_args)
|
||||
|
||||
if not lazy_load:
|
||||
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
@@ -29,7 +30,81 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
|
||||
def get_embedding_vocab_size(model: "PreTrainedModel") -> int:
    r"""Get the vocab size from the input embedding layer.

    Handles DeepSpeed ZeRO-3 parameter sharding by gathering the embedding weight
    before reading its size.
    """
    embedding = model.get_input_embeddings()
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        with deepspeed.zero.GatheredParameters([embedding.weight]):
            return embedding.weight.size(0)

    return embedding.weight.size(0)
|
||||
|
||||
|
||||
def _resolve_new_token_ids(
    new_tokens: Optional[Iterable[str]],
    tokenizer: "PreTrainedTokenizer",
    embed_size: int,
) -> Optional[list[int]]:
    r"""Resolve the explicit embedding-row IDs of the newly added tokens.

    Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when
    the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL
    has vocab 151665 but embedding 151936). In that case the appended tokens land
    inside the original padding zone and the tail slice points at the wrong rows.

    Args:
        new_tokens: Iterable of the newly added token strings.
        tokenizer: The tokenizer instance.
        embed_size: Current embedding size (upper bound for valid token IDs).

    Returns:
        A sorted list of unique, in-range token IDs, or ``None`` when no tokens are
        given so that callers can fall back to the tail-slice behaviour.
    """
    if not new_tokens:
        return None

    unk_token_id = getattr(tokenizer, "unk_token_id", None)
    token_ids: set[int] = set()
    for token_str in new_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token_str)
        if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size):
            logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.")
            continue

        token_ids.add(token_id)

    return sorted(token_ids) or None
|
||||
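The padding pitfall described in the `_resolve_new_token_ids` docstring can be checked with plain arithmetic. The sketch below is illustrative only and not part of the commit: the vocab and embedding sizes are the Qwen2.5-VL figures quoted in the docstring, while the three-token count, the helper names `tail_slice_rows` and `appended_token_ids`, and the assumption that the tokenizer assigns contiguous IDs directly after its current vocab are hypothetical.

```python
def tail_slice_rows(embed_size: int, num_new_tokens: int) -> list[int]:
    """Row indices selected by embed_weight[-num_new_tokens:] (assumes num_new_tokens > 0)."""
    return list(range(embed_size - num_new_tokens, embed_size))


def appended_token_ids(tokenizer_vocab_size: int, num_new_tokens: int) -> list[int]:
    """IDs handed out when tokens are appended after the current vocab (assumed contiguous)."""
    return list(range(tokenizer_vocab_size, tokenizer_vocab_size + num_new_tokens))


TOKENIZER_VOCAB = 151665  # tokenizer vocab size quoted in the docstring
EMBED_ROWS = 151936       # padded embedding rows quoted in the docstring
NUM_NEW = 3               # hypothetical number of added tokens

tail_rows = tail_slice_rows(EMBED_ROWS, NUM_NEW)
real_ids = appended_token_ids(TOKENIZER_VOCAB, NUM_NEW)

print(tail_rows)                        # [151933, 151934, 151935]
print(real_ids)                         # [151665, 151666, 151667]
print(set(tail_rows) & set(real_ids))   # set()
```

Under these assumptions the tail slice and the real token rows do not overlap at all, which is why the new code resolves IDs through the tokenizer instead of slicing the end of the matrix.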
|
||||
|
||||
def _existing_embeddings(
    embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]]
) -> "torch.Tensor":
    """Return the rows treated as 'existing' embeddings used as the init baseline.

    Prefers excluding the explicit new-token rows (robust to padding). Falls back to
    dropping the last ``num_new_tokens`` rows when no explicit IDs are available.
    """
    if new_token_ids:
        mask = torch.ones(embed_weight.size(0), dtype=torch.bool, device=embed_weight.device)
        mask[torch.as_tensor(new_token_ids, device=embed_weight.device, dtype=torch.long)] = False
        return embed_weight[mask]

    if num_new_tokens > 0:
        return embed_weight[:-num_new_tokens]

    return embed_weight
|
||||
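A pure-Python analogue of `_existing_embeddings` may help show why the choice of baseline rows matters. This is a sketch under stated assumptions, not the commit's code: nested lists stand in for the tensor, and `existing_rows`, `column_mean`, and the toy table are hypothetical names and data invented for the illustration.

```python
def existing_rows(rows, num_new_tokens, new_token_ids=None):
    """Select the 'existing' rows: exclude explicit IDs if given, else drop the tail."""
    if new_token_ids:
        excluded = set(new_token_ids)
        return [row for idx, row in enumerate(rows) if idx not in excluded]
    if num_new_tokens > 0:
        return rows[:-num_new_tokens]
    return rows


def column_mean(rows):
    """Mean over rows, one value per embedding dimension."""
    return [sum(col) / len(rows) for col in zip(*rows)]


# Toy 5x2 table: rows 0-1 are trained, row 2 is a new (uninitialized) token,
# rows 3-4 are padding that already existed in the embedding matrix.
table = [[1.0, 1.0], [3.0, 3.0], [100.0, 100.0], [0.0, 0.0], [0.0, 0.0]]

by_id = existing_rows(table, num_new_tokens=1, new_token_ids=[2])
by_tail = existing_rows(table, num_new_tokens=1)

print(column_mean(by_id))    # [1.0, 1.0]
print(column_mean(by_tail))  # [26.0, 26.0]
```

With explicit IDs the uninitialized row is left out of the mean; the tail-slice fallback keeps it and drops a padding row instead, which skews the baseline in this toy case.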
|
||||
|
||||
def _noisy_mean_initialization(
|
||||
embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None
|
||||
) -> None:
|
||||
"""Initialize new token embeddings with mean + Gaussian noise.
|
||||
|
||||
This is the default initialization method used by LlamaFactory.
|
||||
@@ -37,12 +112,23 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
|
||||
Args:
|
||||
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
|
||||
num_new_tokens: Number of new tokens added at the end of the embedding matrix
|
||||
token_ids: Explicit token IDs to initialize. When provided, these exact rows are
|
||||
written (robust to padding). When ``None``, falls back to the last
|
||||
``num_new_tokens`` rows.
|
||||
"""
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True)
|
||||
|
||||
if token_ids:
|
||||
noise_weight = torch.empty(
|
||||
len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype
|
||||
)
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[token_ids] = avg_weight + noise_weight
|
||||
else:
|
||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
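The ID-based variant of the noisy mean initialization above can be illustrated without PyTorch. The following is a minimal sketch, not the project's implementation: `existing_mean` and `noisy_mean_init` are hypothetical names, and the embedding matrix is modelled as a plain list of rows. As in the diff, only the explicitly listed IDs are excluded from the mean, and the noise standard deviation is `1 / sqrt(embedding_dim)`.

```python
import math
import random


def existing_mean(rows, new_ids):
    """Column-wise mean over every row whose index is not in new_ids."""
    skip = set(new_ids)
    kept = [row for i, row in enumerate(rows) if i not in skip]
    dim = len(rows[0])
    return [sum(row[d] for row in kept) / len(kept) for d in range(dim)]


def noisy_mean_init(rows, new_ids, rng):
    """Overwrite rows[new_ids] in place with mean + Gaussian noise."""
    dim = len(rows[0])
    mean = existing_mean(rows, new_ids)
    std = 1.0 / math.sqrt(dim)  # same scale as in the diff above
    for i in new_ids:
        rows[i] = [m + rng.gauss(0.0, std) for m in mean]


# Rows 0-2 stand in for trained embeddings; row 3 is a newly added token
# whose current contents are meaningless.
rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.9, 9.9]]
mean_before = existing_mean(rows, new_ids=[3])
noisy_mean_init(rows, new_ids=[3], rng=random.Random(0))
```

Because rows are addressed by ID, the same call works when the new token does not sit at the tail of the matrix.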
def _description_based_initialization(
|
||||
@@ -51,6 +137,7 @@ def _description_based_initialization(
|
||||
descriptions: dict[str, str],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model: "PreTrainedModel",
|
||||
new_token_ids: Optional[list[int]] = None,
|
||||
add_noise: bool = False,
|
||||
) -> None:
|
||||
"""Initialize new token embeddings based on textual descriptions.
|
||||
@@ -61,6 +148,9 @@ def _description_based_initialization(
|
||||
3. Averages them to initialize the new token's embedding
|
||||
4. Optionally adds Gaussian noise
|
||||
|
||||
New tokens are placed by their resolved token ID rather than by tail slicing,
|
||||
so the initialization is correct even when the embedding matrix was padded.
|
||||
|
||||
Args:
|
||||
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
|
||||
num_new_tokens: Number of new tokens added
|
||||
@@ -68,6 +158,8 @@ def _description_based_initialization(
|
||||
e.g., {"<think>": "A token representing reasoning process"}
|
||||
tokenizer: The tokenizer instance
|
||||
model: The model instance (used to get input embeddings)
|
||||
new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized
|
||||
rows when averaging description-token embeddings (robust to embedding padding).
|
||||
add_noise: Whether to add Gaussian noise to the initialization
|
||||
|
||||
Example:
|
||||
@@ -77,38 +169,55 @@ def _description_based_initialization(
|
||||
}
|
||||
"""
|
||||
embedding_dim = embed_weight.size(1)
|
||||
vocab_size = embed_weight.size(0)
|
||||
unk_token_id = getattr(tokenizer, "unk_token_id", None)
|
||||
device = embed_weight.device
|
||||
|
||||
# The set of rows that are NOT yet initialized (the newly added tokens). Description
|
||||
# tokens that fall into this set must be excluded, otherwise we would average garbage.
|
||||
# `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on
|
||||
# the explicit IDs, falling back to resolving them from the description keys.
|
||||
if new_token_ids is None:
|
||||
new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size)
|
||||
|
||||
new_id_set = set(new_token_ids or [])
|
||||
fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0)
|
||||
|
||||
for token_str, desc in descriptions.items():
|
||||
# Resolve token ID for correct placement (robust to embedding padding)
|
||||
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||
if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size):
|
||||
logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.")
|
||||
continue
|
||||
|
||||
for i, desc in enumerate(descriptions.values()):
|
||||
# Tokenize description text
|
||||
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
|
||||
|
||||
with torch.no_grad():
|
||||
token_ids = tokens["input_ids"][0]
|
||||
# Move to the same device as embed_weight
|
||||
device = embed_weight.device
|
||||
token_ids = token_ids.to(device)
|
||||
token_ids = tokens["input_ids"][0].tolist()
|
||||
|
||||
# Filter out new tokens (they don't have valid embeddings yet)
|
||||
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
|
||||
# Keep only description tokens that already have a meaningful embedding.
|
||||
valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size]
|
||||
|
||||
if len(valid_token_ids) == 0:
|
||||
# Fallback: use mean of all existing embeddings
|
||||
logger.warning_rank0(
|
||||
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
|
||||
f"Description for token '{token_str}' contains no valid tokens. "
|
||||
"Using mean of existing embeddings."
|
||||
)
|
||||
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
|
||||
base_embedding = fallback_embedding
|
||||
else:
|
||||
# Get embeddings of description tokens and average them
|
||||
token_embeds = model.get_input_embeddings()(valid_token_ids)
|
||||
valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long)
|
||||
token_embeds = model.get_input_embeddings()(valid_ids_tensor)
|
||||
base_embedding = token_embeds.mean(dim=0)
|
||||
|
||||
# Add noise if requested (ensure correct device and dtype)
|
||||
if add_noise:
|
||||
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
|
||||
embed_weight[-num_new_tokens + i] = base_embedding + noise
|
||||
embed_weight[token_id] = base_embedding + noise
|
||||
else:
|
||||
embed_weight[-num_new_tokens + i] = base_embedding
|
||||
embed_weight[token_id] = base_embedding
|
||||
|
||||
|
||||
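The change in how description tokens are filtered can be shown in isolation. This is a hedged sketch with hypothetical helper names (`valid_description_ids`, `tail_boundary_ids`) and made-up token IDs; it only contrasts the set-membership rule from the new code with the tail-boundary rule it replaces.

```python
def valid_description_ids(desc_ids, new_ids, vocab_size):
    """New rule: drop IDs of newly added tokens and out-of-range IDs."""
    new_id_set = set(new_ids)
    return [t for t in desc_ids if t not in new_id_set and 0 <= t < vocab_size]


def tail_boundary_ids(desc_ids, tokenizer_len, num_new_tokens):
    """Old rule: keep IDs below len(tokenizer) - num_new_tokens."""
    return [t for t in desc_ids if t < tokenizer_len - num_new_tokens]


# Illustrative numbers: three tokens (IDs 100-102) were added to a tokenizer
# that now has 103 entries, and the embedding matrix already had 128 rows,
# so no resize happened and the resize delta is 0.
desc_ids = [5, 101, 42]
old = tail_boundary_ids(desc_ids, tokenizer_len=103, num_new_tokens=0)
new = valid_description_ids(desc_ids, new_ids=[100, 101, 102], vocab_size=128)
```

With a resize delta of 0 the old rule keeps ID 101, an uninitialized row, while the set-based rule drops it.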
def _initialize_embeddings(
|
||||
@@ -118,6 +227,7 @@ def _initialize_embeddings(
|
||||
new_special_tokens_config: Optional[dict],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model: "PreTrainedModel",
|
||||
new_token_ids: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
"""Single source of truth for embedding initialization.
|
||||
|
||||
@@ -130,16 +240,18 @@ def _initialize_embeddings(
|
||||
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
|
||||
tokenizer: The tokenizer instance
|
||||
model: The model instance
|
||||
new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding).
|
||||
When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows.
|
||||
"""
|
||||
if init_method == "desc_init" and new_special_tokens_config:
|
||||
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
|
||||
_description_based_initialization(
|
||||
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
|
||||
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False
|
||||
)
|
||||
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
|
||||
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
|
||||
_description_based_initialization(
|
||||
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
|
||||
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True
|
||||
)
|
||||
else:
|
||||
if init_method != "noise_init":
|
||||
@@ -147,20 +259,28 @@ def _initialize_embeddings(
|
||||
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
|
||||
)
|
||||
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
|
||||
_noisy_mean_initialization(embed_weight, num_new_tokens)
|
||||
_noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids)
|
||||
|
||||
|
||||
def resize_embedding_layer(
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
new_tokens: Optional[Iterable[str]] = None,
|
||||
new_special_tokens_config: Optional[dict] = None,
|
||||
init_special_tokens: str = "noise_init",
|
||||
) -> None:
|
||||
r"""Resize token embeddings and initialize new tokens.
|
||||
r"""Resize token embeddings (when needed) and initialize the newly added tokens.
|
||||
|
||||
Resizing and initialization are decoupled: even when the tokenizer vocab fits inside
|
||||
the model's existing (padded) embedding matrix and no resize is triggered, the newly
|
||||
added tokens still occupy uninitialized rows and must be initialized. We therefore
|
||||
resolve the explicit row IDs of ``new_tokens`` and always initialize those rows.
|
||||
|
||||
Args:
|
||||
model: The model to resize
|
||||
tokenizer: The tokenizer (used to get target vocab size)
|
||||
new_tokens: Iterable of the newly added token strings. Used to locate the exact
|
||||
embedding rows to initialize, which is robust to pre-existing embedding padding.
|
||||
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
|
||||
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
|
||||
"""
|
||||
@@ -175,44 +295,70 @@ def resize_embedding_layer(
|
||||
else:
|
||||
context_maybe_zero3 = nullcontext()
|
||||
|
||||
with context_maybe_zero3:
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
current_embedding_size = get_embedding_vocab_size(model)
|
||||
needs_resize = len(tokenizer) > current_embedding_size
|
||||
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if needs_resize:
|
||||
if getattr(model, "quantization_method", None):
|
||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
raise ValueError("Current model does not support resizing embedding layers.")
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
with context_maybe_zero3:
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
# mean_resizing=False preserves the original embedding distribution exactly.
|
||||
# HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance
|
||||
# of existing embeddings, which conflicts with our explicit initialization below.
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False)
|
||||
|
||||
with context_maybe_zero3:
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
|
||||
# Resolve the exact rows of the new tokens. This works whether or not a resize was
|
||||
# triggered (e.g. tokens added into a model's pre-existing padding zone).
|
||||
new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size)
|
||||
|
||||
if num_new_tokens <= 0 and not new_token_ids:
|
||||
return
|
||||
|
||||
if needs_resize:
|
||||
logger.info_rank0(
|
||||
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
|
||||
)
|
||||
else:
|
||||
logger.info_rank0(
|
||||
f"No resize needed (vocab fits in padded embedding {new_embedding_size}); "
|
||||
f"initializing {len(new_token_ids or [])} new token(s) in place."
|
||||
)
|
||||
|
||||
# Initialize input embeddings
|
||||
# Initialize input embeddings
|
||||
_initialize_embeddings(
|
||||
model.get_input_embeddings().weight.data,
|
||||
num_new_tokens,
|
||||
init_special_tokens,
|
||||
new_special_tokens_config,
|
||||
tokenizer,
|
||||
model,
|
||||
new_token_ids=new_token_ids,
|
||||
)
|
||||
|
||||
# Initialize output embeddings if not tied
|
||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||
_initialize_embeddings(
|
||||
model.get_input_embeddings().weight.data,
|
||||
model.get_output_embeddings().weight.data,
|
||||
num_new_tokens,
|
||||
init_special_tokens,
|
||||
new_special_tokens_config,
|
||||
tokenizer,
|
||||
model,
|
||||
new_token_ids=new_token_ids,
|
||||
)
|
||||
|
||||
# Initialize output embeddings if not tied
|
||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||
_initialize_embeddings(
|
||||
model.get_output_embeddings().weight.data,
|
||||
num_new_tokens,
|
||||
init_special_tokens,
|
||||
new_special_tokens_config,
|
||||
tokenizer,
|
||||
model,
|
||||
)
|
||||
|
||||
if needs_resize:
|
||||
model.config.vocab_size = new_embedding_size
|
||||
# Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA),
|
||||
# otherwise config.vocab_size and config.text_config.vocab_size become inconsistent.
|
||||
if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"):
|
||||
model.config.text_config.vocab_size = new_embedding_size
|
||||
|
||||
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
|
||||
|
||||
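The reason the resize delta cannot stand in for the number of new tokens follows from the padding arithmetic. Below is a small sketch under stated assumptions: `pad_to_multiple` and `plan_resize` are hypothetical helpers that mirror the `needs_resize` check and the `pad_to_multiple_of=64` rounding in the diff, and the sizes are illustrative.

```python
def pad_to_multiple(n, multiple=64):
    """Round n up to the next multiple (the effect of pad_to_multiple_of)."""
    return ((n + multiple - 1) // multiple) * multiple


def plan_resize(tokenizer_len, embedding_rows, multiple=64):
    """Return (needs_resize, new_embedding_rows, resize_delta)."""
    needs_resize = tokenizer_len > embedding_rows
    new_rows = pad_to_multiple(tokenizer_len, multiple) if needs_resize else embedding_rows
    return needs_resize, new_rows, new_rows - embedding_rows


# Case 1: 3 tokens added, but they fit into existing padding -> delta is 0.
fits = plan_resize(tokenizer_len=103, embedding_rows=128)
# Case 2: 2 tokens added past the end -> delta is 64, not 2.
grows = plan_resize(tokenizer_len=130, embedding_rows=128)
```

In both cases the delta differs from the count of added tokens, which is why the new code resolves explicit row IDs instead.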
@@ -16,6 +16,7 @@ import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import get_device_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -81,6 +82,8 @@ def apply_liger_kernel(
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
|
||||
elif model_type == "qwen3_5":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
|
||||
elif model_type == "qwen3_5_moe":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe as apply_liger_kernel
|
||||
elif model_type == "gpt_oss":
|
||||
try:
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
||||
@@ -97,5 +100,12 @@ def apply_liger_kernel(
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
if get_device_name() == "npu":
|
||||
import torch
|
||||
|
||||
if "Ascend910" not in torch.npu.get_device_name(0):
|
||||
kwargs["swiglu"] = False
|
||||
kwargs["fused_linear_cross_entropy"] = False
|
||||
|
||||
apply_liger_kernel(**kwargs)
|
||||
logger.info_rank0("Liger kernel has been applied to the model.")
|
||||
|
||||
@@ -84,8 +84,12 @@ def load_unsloth_peft_model(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
) -> "PreTrainedModel":
|
||||
r"""Load peft model with unsloth. Used in both training and inference."""
|
||||
) -> Optional["PreTrainedModel"]:
|
||||
r"""Load peft model with unsloth. Used in both training and inference.
|
||||
|
||||
Returns None if unsloth does not support the model type, and sets
|
||||
model_args.use_unsloth = False so callers can fall back to standard loading.
|
||||
"""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
|
||||
@@ -95,7 +99,9 @@ def load_unsloth_peft_model(
|
||||
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model_args.use_unsloth = False
|
||||
return None
|
||||
|
||||
if not is_trainable:
|
||||
FastLanguageModel.for_inference(model)
|
||||
|
||||
@@ -20,6 +20,7 @@ from peft import PeftModel
|
||||
from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import infer_optim_dtype
|
||||
@@ -84,7 +85,60 @@ def _check_fla_dependencies() -> None:
|
||||
) from exc
|
||||
|
||||
|
||||
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
||||
def patch_qwen3_5_forward_npu(model: "PreTrainedModel") -> None:
|
||||
"""Patch for Qwen3.5 models on NPU by importing torch_npu to enable torch.cuda compatibility.
|
||||
|
||||
On NPU, torch.cuda operations will fail unless torch_npu is imported.
|
||||
torch_npu provides a compatibility layer that maps torch.cuda calls to NPU operations.
|
||||
|
||||
Also replaces chunk_gated_delta_rule with NPU-compatible implementation.
|
||||
"""
|
||||
import importlib.metadata
|
||||
|
||||
if "Ascend910" not in torch.npu.get_device_name(0):
|
||||
logger.warning_rank0("Currently only 910B series NPUs are supported for the NPU GDN patch.")
|
||||
return
|
||||
|
||||
try:
|
||||
importlib.metadata.version("triton_ascend")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
logger.warning_rank0(
|
||||
"triton_ascend not installed, skipping NPU GDN patch. "
|
||||
"To enable it on NPU, reinstall Triton with the Ascend build: "
|
||||
"`pip uninstall -y triton && pip install -r requirements/triton_ascend.txt`. "
|
||||
"Note: triton and triton_ascend cannot coexist — triton must be uninstalled first."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info_rank0("triton_ascend detected for NPU compatibility.")
|
||||
|
||||
from ..third_party.triton.chunk_gated_delta_rule import chunk_gated_delta_rule as npu_chunk_gated_delta_rule
|
||||
|
||||
if model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
|
||||
try:
|
||||
# Qwen3.5-MoE structure: model.model.language_model.layers
|
||||
for layer in model.model.language_model.layers:
|
||||
if hasattr(layer, "linear_attn"):
|
||||
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
|
||||
|
||||
logger.info_rank0(
|
||||
"Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5-MoE model."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
|
||||
elif model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
|
||||
try:
|
||||
# Qwen3.5 structure: model.model.layers
|
||||
for layer in model.model.layers:
|
||||
if hasattr(layer, "linear_attn"):
|
||||
layer.linear_attn.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
|
||||
|
||||
logger.info_rank0("Replaced chunk_gated_delta_rule with NPU-compatible implementation for Qwen3.5 model.")
|
||||
except Exception as e:
|
||||
logger.warning_rank0(f"Failed to replace chunk_gated_delta_rule for NPU: {e}")
|
||||
|
||||
|
||||
def patch_qwen3_5_forward_gpu(model: "PreTrainedModel") -> None:
|
||||
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input. Applied only during training.
|
||||
|
||||
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
|
||||
@@ -162,8 +216,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
||||
if position_ids is not None and position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
|
||||
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
|
||||
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
|
||||
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
|
||||
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
|
||||
if position_ids is not None and batch_size == 1:
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
|
||||
else:
|
||||
cu_seqlens = None
|
||||
|
||||
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
|
||||
# standard causal-conv1d path that the upstream forward uses.
|
||||
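The `cu_seqlens` comment in the hunk above relies on packed samples being folded into one row whose position IDs restart at 0 for each segment. The sketch below shows that relationship in plain Python. It is an assumption about the general idea, not a copy of `prepare_fa_kwargs_from_position_ids`, and `cu_seqlens_from_position_ids` is a hypothetical name.

```python
def cu_seqlens_from_position_ids(position_ids):
    """Cumulative sequence boundaries for one packed row of position IDs.

    A new segment starts wherever the position ID resets to 0; the final
    boundary is the total number of tokens in the row.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    return starts + [len(position_ids)]


# Three packed segments of lengths 3, 2 and 4.
packed = cu_seqlens_from_position_ids([0, 1, 2, 0, 1, 0, 1, 2, 3])
# A single unpacked sequence yields just [0, length].
single = cu_seqlens_from_position_ids([0, 1, 2, 3])
```

For a batch with more than one row there is no single packed sequence to describe, which is consistent with the patch passing `cu_seqlens=None` in that case.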
@@ -397,9 +457,14 @@ def patch_model(
|
||||
prepare_valuehead_model(model)
|
||||
|
||||
if model_args.resize_vocab:
|
||||
# Pass the explicit list of newly added tokens so their exact embedding rows can be
|
||||
# located and initialized, even when they land in a model's pre-existing padding zone.
|
||||
new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or [])
|
||||
|
||||
resize_embedding_layer(
|
||||
model,
|
||||
tokenizer,
|
||||
new_tokens=new_tokens or None,
|
||||
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
|
||||
init_special_tokens=model_args.init_special_tokens,
|
||||
)
|
||||
@@ -415,8 +480,12 @@ def patch_model(
|
||||
autocast_projector_dtype(model, model_args)
|
||||
add_z3_leaf_module(model)
|
||||
|
||||
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
|
||||
patch_qwen3_5_forward(model)
|
||||
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"]:
|
||||
if is_torch_npu_available():
|
||||
patch_qwen3_5_forward_npu(model)
|
||||
elif is_torch_cuda_available() and model_args.flash_attn == "fa2":
|
||||
# Patch for packing/neat_packing with the GPU GDN kernels. Packing requires flash_attn to be fa2.
|
||||
patch_qwen3_5_forward_gpu(model)
|
||||
|
||||
if not model_args.use_unsloth:
|
||||
print_attn_implementation(model.config)
|
||||
|
||||
594
src/llamafactory/third_party/triton/chunk_delta_h.py
vendored
Normal file
@@ -0,0 +1,594 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_autotune_config, get_npu_properties, prepare_chunk_indices, prepare_chunk_offsets
|
||||
|
||||
|
||||
CUBE_CORE_NUM = get_npu_properties()["num_aicore"]
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_GK": lambda args: args["gk"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.autotune(
|
||||
configs=get_autotune_config(multibuffer_list=(False,)),
|
||||
key=["H", "K", "V", "BT"],
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
gk,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_GK: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
T_all = T
|
||||
NT_all = NT
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# Initialize hidden states
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
|
||||
if IS_VARLEN:
|
||||
v = v + (i_h * T_all + bos) * V
|
||||
k = k + (i_h * T_all + bos) * K
|
||||
w = w + (i_h * T_all + bos) * K
|
||||
g = g + i_h * T_all + bos
|
||||
h = h + (i_h * NT_all + boh) * K * V
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new_base = v_new + (i_h * T_all + bos) * V
|
||||
else:
|
||||
v = v + (i_n * H + i_h) * T * V
|
||||
k = k + (i_n * H + i_h) * T * K
|
||||
w = w + (i_n * H + i_h) * T * K
|
||||
g = g + (i_n * H + i_h) * T
|
||||
h = h + (i_n * H + i_h) * NT * K * V
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new_base = v_new + (i_n * H + i_h) * T * V
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
h0_ptr = h0 + i_nh * K * V
|
||||
if STORE_FINAL_STATE:
|
||||
ht_ptr = ht + i_nh * K * V
|
||||
|
||||
# Load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(h0_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# Main recurrence over chunks
|
||||
for i_t in range(NT):
|
||||
# Store current hidden state h_t
|
||||
p_h1 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# Compute v_residual = v - w @ h
|
||||
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new = tl.make_block_ptr(v_new_base, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
|
||||
# Apply output gate g
|
||||
if USE_G:
|
||||
m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
|
||||
b_g_last = tl.load(g + last_idx)
|
||||
p_g = tl.make_block_ptr(g, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_v *= (m_t * tl.exp(b_g_last - b_g))[:, None]
|
||||
b_g_last_exp = tl.exp(b_g_last)
|
||||
b_h1 *= b_g_last_exp
|
||||
if K > 64:
|
||||
b_h2 *= b_g_last_exp
|
||||
if K > 128:
|
||||
b_h3 *= b_g_last_exp
|
||||
if K > 192:
|
||||
b_h4 *= b_g_last_exp
|
||||
|
||||
# Apply key gate gk
|
||||
if USE_GK:
|
||||
o_k1 = tl.arange(0, 64).to(tl.float32)
|
||||
gk_base_ptr = gk + (i_n * H + i_h) * T * K
|
||||
b_gk_last1 = tl.load(gk_base_ptr + last_idx * K + o_k1, mask=(o_k1 < K), other=0.0)
|
||||
b_h1 *= tl.exp(b_gk_last1)[:, None]
|
||||
if K > 64:
|
||||
o_k2 = 64 + o_k1
|
||||
b_gk_last2 = tl.load(gk_base_ptr + last_idx * K + o_k2, mask=(o_k2 < K), other=0.0)
|
||||
b_h2 *= tl.exp(b_gk_last2)[:, None]
|
||||
if K > 128:
|
||||
o_k3 = 128 + o_k1
|
||||
b_gk_last3 = tl.load(gk_base_ptr + last_idx * K + o_k3, mask=(o_k3 < K), other=0.0)
|
||||
b_h3 *= tl.exp(b_gk_last3)[:, None]
|
||||
if K > 192:
|
||||
o_k4 = 192 + o_k1
|
||||
b_gk_last4 = tl.load(gk_base_ptr + last_idx * K + o_k4, mask=(o_k4 < K), other=0.0)
|
||||
b_h4 *= tl.exp(b_gk_last4)[:, None]
|
||||
|
||||
b_v = b_v.to(k.dtype.element_ty)
|
||||
|
||||
# Update hidden state: h += k @ v
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
if USE_GK:
|
||||
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (0, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = (b_k * tl.exp(b_gk_last1[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
|
||||
b_h1 += tl.dot(b_k, b_v)
|
||||
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
if USE_GK:
|
||||
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (64, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = (b_k * tl.exp(b_gk_last2[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
|
||||
b_h2 += tl.dot(b_k, b_v)
|
||||
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
if USE_GK:
|
||||
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (128, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = (b_k * tl.exp(b_gk_last3[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
|
||||
b_h3 += tl.dot(b_k, b_v)
|
||||
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
if USE_GK:
|
||||
p_gk = tl.make_block_ptr(gk_base_ptr, (K, T), (1, K), (192, i_t * BT), (64, BT), (0, 1))
|
||||
b_k = (b_k * tl.exp(b_gk_last4[:, None] - tl.load(p_gk, boundary_check=(0, 1)))).to(b_k.dtype)
|
||||
b_h4 += tl.dot(b_k, b_v)
|
||||
|
||||
# Store final state
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
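The per-chunk recurrence in the kernel above can be followed more easily in a scalar reference. This is a simplified sketch, assuming `K = V = 1`, no key gate `gk`, and that `g` holds chunk-local cumulative log-decay values (an assumption about how the caller prepares `g`). `chunk_recurrence_scalar` is a hypothetical name and is not part of the vendored file.

```python
import math


def chunk_recurrence_scalar(k, v, w, g, chunk_size, h0=0.0):
    """Scalar reference of the chunked state update.

    Returns (state at the start of each chunk, residual values, final state).
    """
    total = len(k)
    h = h0
    h_per_chunk, v_new = [], []
    for start in range(0, total, chunk_size):
        end = min(start + chunk_size, total)
        h_per_chunk.append(h)  # state visible to this chunk
        g_last = g[end - 1]    # gate value at the last valid position
        update = 0.0
        for t in range(start, end):
            residual = v[t] - w[t] * h  # v_new = v - w @ h
            v_new.append(residual)      # stored before gating, as in the kernel
            update += k[t] * residual * math.exp(g_last - g[t])
        h = h * math.exp(g_last) + update  # decay the state, then add k^T @ v_new
    return h_per_chunk, v_new, h


one_chunk = chunk_recurrence_scalar([1, 2], [3, 4], [0, 0], [0, 0], chunk_size=2)
two_chunks = chunk_recurrence_scalar([1, 1], [1, 1], [0, 1], [0, 0], chunk_size=1)
```

In the second call the state written by the first chunk cancels the second value through `w`, so the residual for that step is zero and the state stays unchanged.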
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
gk: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # default:64
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
B, T, H, K, V = *k.shape, u.shape[-1]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V).permute(0, 2, 1, 3, 4).contiguous()
|
||||
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
BV = 128
|
||||
|
||||
v_new = torch.empty_like(u).permute(0, 2, 1, 3).contiguous() if save_new_value else None
|
||||
k = k.permute(0, 2, 1, 3).contiguous()
|
||||
w = w.permute(0, 2, 1, 3).contiguous()
|
||||
u = u.permute(0, 2, 1, 3).contiguous()
|
||||
g = g.permute(0, 2, 1).contiguous()
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[(triton.cdiv(V, BV), N * H)](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
gk=gk,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BV=BV,
|
||||
NT=NT,
|
||||
)
|
||||
h = h.permute(0, 2, 1, 3, 4).contiguous()
|
||||
v_new = v_new.permute(0, 2, 1, 3).contiguous()
|
||||
return h, v_new, final_state
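The wrapper above derives `N`, `NT`, and `chunk_offsets` from `cu_seqlens` through `prepare_chunk_indices` and `prepare_chunk_offsets`, whose bodies are not shown in this diff. The sketch below is an assumption about what they compute, inferred from how the kernels read them (a pair `(sequence index, chunk index)` per chunk, and a running chunk count per sequence). It uses plain Python lists instead of tensors, and the function name `chunk_layout` is hypothetical.

```python
def cdiv(a, b):
    """Ceiling division, the same quantity as triton.cdiv."""
    return -(-a // b)


def chunk_layout(cu_seqlens, chunk_size):
    """Describe how packed variable-length sequences split into chunks.

    Returns (indices, offsets): `indices` lists one (sequence, local chunk)
    pair per chunk, and `offsets[i]` is the number of chunks that precede
    sequence i. This is an illustrative reimplementation, not the vendored code.
    """
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    chunks_per_seq = [cdiv(length, chunk_size) for length in lengths]
    indices = [(i_n, i_t) for i_n, count in enumerate(chunks_per_seq) for i_t in range(count)]
    offsets = [0]
    for count in chunks_per_seq:
        offsets.append(offsets[-1] + count)
    return indices, offsets


if __name__ == "__main__":
    indices, offsets = chunk_layout([0, 100, 164, 300], chunk_size=64)
    print("N =", len(offsets) - 1, "NT =", len(indices))
    print(indices)
    print(offsets)
```

Under this reading, `NT = len(chunk_indices)` is the total number of chunks across all packed sequences, and `chunk_offsets[i_n]` is the `boh` value each program uses to find its first state slot.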
|
||||
|
||||
|
||||
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
        "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=get_autotune_config(multibuffer_list=(True, False)),
    key=["H", "K", "V", "BT", "BV", "USE_G", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
    q,
    k,
    w,
    g,
    gk,
    dht,
    dh0,
    do,
    dh,
    dv,
    dv2,
    cu_seqlens,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    T_all = T
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([64, BV], dtype=tl.float32)

    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    w += (bos * H + i_h) * K
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V
    dv2 += (bos * H + i_h) * V
    dh += (boh * H + i_h) * K * V
    if USE_GK:
        gk += (bos * H + i_h) * K

    if USE_INITIAL_STATE:
        dh0 += i_nh * K * V
    if USE_FINAL_STATE_GRADIENT:
        dht += i_nh * K * V

    stride_v = H * V
    stride_h = H * K * V
    stride_k = H * K

    if USE_FINAL_STATE_GRADIENT:
        p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
        if K > 64:
            p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
        if K > 128:
            p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
        if K > 192:
            p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))
|
||||
|
||||
    for i_t in range(NT - 1, -1, -1):
        p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            if IS_VARLEN:
                bos_g = i_h * T_all + bos
            else:
                bos_g = (i_n * H + i_h) * T_all
            bg_last = tl.load(g + bos_g + last_idx)
            bg_last_exp = tl.exp(bg_last)
            p_g = tl.make_block_ptr(
                base=g + bos_g, shape=(T,), strides=(1,), offsets=(i_t * BT,), block_shape=(BT,), order=(0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_g_exp = tl.exp(b_g)

        p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        b_do = tl.load(p_do, boundary_check=(0, 1))

        # Update dv
        p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        if USE_GK:
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0)
        b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))

        if K > 64:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0)
            b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))

        if K > 128:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0)
            b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))

        if K > 192:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            if USE_GK:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0)
            b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))

        if USE_G:
            m_t = (i_t * BT + tl.arange(0, BT)).to(tl.float32) < T
            b_dv *= (m_t * tl.exp(bg_last - b_g))[:, None]
        b_dv += tl.load(p_dv, boundary_check=(0, 1))

        tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        # Update dh
        p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        if USE_G:
            b_dh1 *= bg_last_exp
            b_q = b_q * b_g_exp[None, :]
        if USE_GK:
            b_dh1 *= tl.exp(b_gk_last1[:, None])
        b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
        if K > 64:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if USE_G:
                b_dh2 *= bg_last_exp
                b_q = b_q * b_g_exp[None, :]
            if USE_GK:
                b_dh2 *= tl.exp(b_gk_last2[:, None])
            b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
        if K > 128:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if USE_G:
                b_dh3 *= bg_last_exp
                b_q = b_q * b_g_exp[None, :]
            if USE_GK:
                b_dh3 *= tl.exp(b_gk_last3[:, None])
            b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
        if K > 192:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            if USE_G:
                b_dh4 *= bg_last_exp
                b_q = b_q * b_g_exp[None, :]
            if USE_GK:
                b_dh4 *= tl.exp(b_gk_last4[:, None])
            b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
|
||||
|
||||
    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
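Reading the backward kernel above with `USE_G` enabled and `USE_GK` disabled, one reverse-time step for chunk $t$ can be summarized as below. The notation is introduced here for illustration and does not appear in the source: $D_t$ is the accumulator written to `dh` for chunk $t$, $g$ is the vector of chunk-local cumulative log-gates, $g_{\mathrm{last}}$ is its last valid entry, $m$ is the mask of rows that lie inside the sequence, and $s$ is `scale`.

```latex
\begin{aligned}
\widetilde{dV}_t &= \operatorname{diag}\!\big(m \odot e^{\,g_{\mathrm{last}} - g}\big)\, K_t\, D_t \;+\; dV_t
  && \text{(stored to \texttt{dv2})} \\
D_{t-1} &= e^{\,g_{\mathrm{last}}}\, D_t \;+\; s\,\big(\operatorname{diag}(e^{\,g})\, Q_t\big)^{\top} dO_t \;-\; W_t^{\top}\, \widetilde{dV}_t
  && \text{(carried to the next iteration)}
\end{aligned}
```

The loop starts from $D_{NT-1} = dH_T$ when a final-state gradient `dht` is supplied (zero otherwise), and the value left after the last iteration is what the epilogue writes to `dh0`.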
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_bwd_dhu(
    q: torch.Tensor,
    k: torch.Tensor,
    w: torch.Tensor,
    do: torch.Tensor,
    dv: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    h0: torch.Tensor | None = None,
    dht: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    B, T, H, K, V = *q.shape, do.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    BT = 64
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

    dh = q.new_empty(B, NT, H, K, V)
    # dh0 is only produced when an initial state was supplied.
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.empty_like(dv)

    BV = 128

    # g is declared optional, so only permute it to the head-first layout when present.
    if g is not None:
        g = g.permute(0, 2, 1).contiguous()

    chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[(triton.cdiv(V, BV), N * H)](
        q=q,
        k=k,
        w=w,
        g=g,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dh, dh0, dv2
|
||||
347
src/llamafactory/third_party/triton/chunk_gated_delta_rule.py
vendored
Normal file
@@ -0,0 +1,347 @@
|
||||
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional

import torch

from .chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
from .wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
):
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens, head_first=False)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    return g, o, A, final_state
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
):
    # Recompute the forward intermediates (w, u, h, v_new) instead of saving them.
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, _ = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=False,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dv = chunk_bwd_dv_local(
        q=q,
        k=k,
        g=g,
        do=do,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
        q=q,
        k=k,
        w=w,
        g=g,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dq, dk, dw, dg = chunk_bwd_dqkwg(
        q=q,
        k=k,
        v=v_new,
        w=w,
        g=g,
        h=h,
        dv=dv,
        do=do,
        dh=dh,
        chunk_size=chunk_size,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    dk2, dv, db, dg2 = prepare_wy_repr_bwd(
        k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size
    )
    dk.add_(dk2)
    dg.add_(dg2)
    if dg.dtype != torch.float32:
        raise ValueError(f"dg has dtype {dg.dtype}, but float32 is expected.")
    # g was cumulatively summed per chunk in the forward pass, so its gradient needs the reverse cumsum.
    dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False)
    return dq, dk, dv, db, dg, dh0
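The forward pass calls `chunk_local_cumsum` on `g`, and the backward pass calls it again on `dg` with `reverse=True`. Its implementation is not part of this diff, so the sketch below is an assumption based on the name and the call sites: a cumulative sum that restarts at every chunk boundary, with the reverse variant summing from the end of each chunk. It works on a plain Python list for a single head, and the name `chunk_local_cumsum_reference` is hypothetical.

```python
def chunk_local_cumsum_reference(x, chunk_size, reverse=False):
    """Cumulative sum that restarts at every chunk boundary.

    Illustrative stand-in for the vendored Triton helper; it handles one
    1-D sequence and ignores cu_seqlens and head layout.
    """
    out = []
    for start in range(0, len(x), chunk_size):
        chunk = x[start:start + chunk_size]
        if reverse:
            chunk = chunk[::-1]
        running, sums = 0.0, []
        for value in chunk:
            running += value
            sums.append(running)
        out.extend(sums[::-1] if reverse else sums)
    return out


if __name__ == "__main__":
    g = [1.0, 2.0, 3.0, 4.0, 5.0]
    print(chunk_local_cumsum_reference(g, chunk_size=2))
    print(chunk_local_cumsum_reference(g, chunk_size=2, reverse=True))
```

The reverse form is the transpose of the forward form as a linear map, which is why it appears when the gradient with respect to the raw gates is recovered from the gradient with respect to the cumulative ones.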
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
        chunk_size: int = 64,
    ):
        q_rstd, k_rstd = None, None
        g, o, A, final_state = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size,
        )
        ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens)
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
        q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
        dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            A=A,
            scale=ctx.scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            cu_seqlens=cu_seqlens,
            chunk_size=ctx.chunk_size,
        )
        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
|
||||
|
||||
|
||||
@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    head_first: bool = False,
):
    r"""Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]`.
        scale (Optional[float]):
            Scale factor applied to the queries.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        use_qk_l2norm_in_kernel (bool):
            Whether to apply L2 normalization to the q/k tensors internally. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        chunk_size (int):
            Chunk length used by the chunked kernels. Default: `64`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
            This argument has been deprecated.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths (the kernels in this package require K <= 256)
        >>> B, T, H, K, V = 4, 2048, 4, 128, 128
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """  # noqa: D205
    if q.dtype != k.dtype or k.dtype != v.dtype:
        raise ValueError(
            f"q, k, and v must have the same dtype, but got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}."
        )
    if q.dtype == torch.float32:
        raise ValueError("ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.")
    if len(beta.shape) != 3:
        raise ValueError(
            f"beta has {len(beta.shape)} dimensions, but it must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
        )

    if head_first:
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5

    def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
        """This function is intended to align with the l2norm implementation in the FLA library."""
        original_dtype = x.dtype
        inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
        # Counteract verl's autocast promotion (bf16 -> fp32) by restoring original dtype
        return (x * inv_norm).to(original_dtype)

    if use_qk_l2norm_in_kernel:
        q = l2norm(q, dim=-1, eps=1e-6)
        k = l2norm(k, dim=-1, eps=1e-6)

    # The normalization is already applied above, so the autograd function receives False for that flag.
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, False, chunk_size
    )
    return o, final_state
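For checking the chunked kernels on tiny inputs, a token-by-token reference is handy. The sketch below is a plain-Python rendering of the gated delta rule recurrence as it is commonly stated (decay the state by `exp(g_t)`, correct the value by the state's current prediction for `k_t`, write the rank-one update, then read with `q_t`). It is an assumption that this matches the vendored kernels exactly; it handles a single head and a single sequence, and the name `gated_delta_rule_reference` is hypothetical.

```python
import math


def gated_delta_rule_reference(q, k, v, g, beta, scale=None, state=None):
    """Sequential reference for one head.

    q, k: T x K nested lists; v: T x V nested lists; g (log-space gates) and
    beta: length-T lists. Returns (outputs as T x V lists, final K x V state).
    """
    T, K, V = len(q), len(q[0]), len(v[0])
    scale = K ** -0.5 if scale is None else scale
    S = [row[:] for row in state] if state is not None else [[0.0] * V for _ in range(K)]
    outputs = []
    for t in range(T):
        decay = math.exp(g[t])
        S = [[decay * s for s in row] for row in S]  # forget gate
        predicted = [sum(S[i][j] * k[t][i] for i in range(K)) for j in range(V)]
        delta = [beta[t] * (v[t][j] - predicted[j]) for j in range(V)]  # delta-rule correction
        for i in range(K):
            for j in range(V):
                S[i][j] += k[t][i] * delta[j]  # rank-one state update
        outputs.append([scale * sum(S[i][j] * q[t][i] for i in range(K)) for j in range(V)])
    return outputs, S


if __name__ == "__main__":
    o, S = gated_delta_rule_reference(
        q=[[1.0], [1.0]], k=[[1.0], [1.0]], v=[[2.0], [2.0]],
        g=[0.0, math.log(0.5)], beta=[1.0, 1.0], scale=1.0,
    )
    print(o, S)
```

With `beta = 1` and a unit-norm key, each step overwrites whatever the state stored for that key, so the second token reads back the value `2.0` even though the gate halved the state first.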
|
||||
617
src/llamafactory/third_party/triton/chunk_o.py
vendored
Normal file
@@ -0,0 +1,617 @@
|
||||
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import triton
import triton.language as tl

from .utils import exp, prepare_chunk_indices, prepare_chunk_offsets
|
||||
|
||||
|
||||
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
        "USE_DW": lambda args: args["dw"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_bwd_kernel_dqkwg(
    q,
    k,
    v,
    h,
    g,
    g_gamma,
    do,
    dh,
    dq,
    dk,
    dg,
    w,
    dv,
    dw,
    cu_seqlens,
    chunk_indices,
    scale,
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    USE_DW: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    gdiff,
):
    i_t, i_b = tl.program_id(0), tl.program_id(1)
    T_max = T
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        total = B * T_max
        T = eos - bos
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        total = B * T_max

    NK = tl.cdiv(K, BK)
    for i_k in range(NK):
        if USE_G:
            dg_k = dg + i_k * total * H

        for i_h in range(H):
            v_h = v + (bos * H + i_h) * V
            do_h = do + (bos * H + i_h) * V
            h_h = h + (i_tg * H + i_h).to(tl.int64) * K * V
            dh_h = dh + (i_tg * H + i_h).to(tl.int64) * K * V
            q_h = q + (bos * H + i_h) * K
            k_h = k + (bos * H + i_h) * K
            dq_h = dq + (bos * H + i_h) * K
            dk_h = dk + (bos * H + i_h) * K

            if USE_DW:
                w_h = w + (bos * H + i_h) * K  # noqa: F841
                dw_h = dw + (bos * H + i_h) * K
                dv_h = dv + (bos * H + i_h) * V

            if USE_G:
                if IS_VARLEN:
                    dg_h = dg_k + i_h * T_max + bos
                    g_h = g + i_h * T_max + bos
                else:
                    dg_h = dg_k + (i_b * H + i_h) * T_max
                    g_h = g + (i_b * H + i_h) * T_max
                b_dg_last = tl.zeros([1], dtype=tl.float32)

            if USE_G_GAMMA:
                b_gamma = tl.load(g_gamma + i_h)
                b_g = b_gamma * (tl.arange(0, BT) + 1)
                b_g_last = b_gamma * min(BT, T - i_t * BT)

            b_dq = tl.zeros([BT, BK], dtype=tl.float32)
            b_dk = tl.zeros([BT, BK], dtype=tl.float32)
            b_ds = tl.zeros([BT, BT], dtype=tl.float32)
            b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None

            for i_v in range(tl.cdiv(V, BV)):
                p_v = tl.make_block_ptr(v_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_do = tl.make_block_ptr(do_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_h = tl.make_block_ptr(h_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
                p_dh = tl.make_block_ptr(dh_h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))

                b_v = tl.load(p_v, boundary_check=(0, 1))
                b_do = tl.load(p_do, boundary_check=(0, 1))
                b_h = tl.load(p_h, boundary_check=(0, 1))
                b_dh = tl.load(p_dh, boundary_check=(0, 1))

                if USE_G:
                    b_dg_last += tl.sum(b_h * b_dh)

                b_ds += tl.dot(b_do, tl.trans(b_v))
                b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
                b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))

                if USE_DW:
                    p_dv = tl.make_block_ptr(dv_h, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                    b_dv = tl.load(p_dv, boundary_check=(0, 1))
                    b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

            if USE_DW:
                p_dw = tl.make_block_ptr(dw_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
                tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))

            tl.debug_barrier()

            p_q = tl.make_block_ptr(q_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))

            p_dq = tl.make_block_ptr(dq_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk_h, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

            o_t = i_t * BT + tl.arange(0, BT)
            m_t = o_t < T
            m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)

            if USE_G:
                b_dg = tl.zeros([BT], dtype=tl.float32)
                p_g = tl.make_block_ptr(g_h, (T,), (1,), (i_t * BT,), (BT,), (0,))
                b_g = tl.load(p_g, boundary_check=(0,))
                b_g_last = tl.load(g_h + (min(i_t * BT + BT, T) - 1) * 1)
                b_dg_last *= tl.exp(b_g_last)

                b_dq = b_dq * tl.exp(b_g)[:, None] * scale
                b_dg += tl.sum(b_dq * b_q, axis=1)

                b_dk = b_dk * tl.where(m_t, tl.exp(-b_g + b_g_last), 0)[:, None]
                b_dg -= tl.sum(b_k * b_dk, axis=1)
                b_dg_last += tl.sum(b_dk * b_k)

                if IS_VARLEN:
                    b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
                else:
                    p_gdiff = tl.make_block_ptr(
                        gdiff + i_b * H * NT * BT * BT + i_h * NT * BT * BT + i_t * BT * BT,
                        (BT, BT),
                        (BT, 1),
                        (0, 0),
                        (BT, BT),
                        (1, 0),
                    )
                    gdiff_ = tl.load(p_gdiff)
                    b_ds = b_ds * gdiff_ * scale

                b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
                b_dg += tl.sum(b_ds2, axis=1)
                b_dg -= tl.sum(b_ds2, axis=0)

                b_ds = b_ds.to(b_k.dtype)
                b_dq += tl.dot(b_ds, b_k)
                b_dk += tl.dot(tl.trans(b_ds), b_q)
                p_dg = tl.make_block_ptr(dg_h, (T,), (1,), (i_t * BT,), (BT,), (0,))

                last_index_local = min(BT, T - i_t * BT) - 1
                if last_index_local >= 0:
                    is_last_mask = tl.arange(0, BT) == last_index_local
                    b_dg = tl.where(is_last_mask, b_dg + b_dg_last, b_dg)
                else:
                    pass

                tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))

            elif USE_G_GAMMA:
                b_dq = b_dq * exp(b_g)[:, None] * scale
                b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
                b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
                b_ds = b_ds.to(b_k.dtype)
                b_dq += tl.dot(b_ds, b_k)
                b_dk += tl.dot(tl.trans(b_ds), b_q)
                tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

            else:
                b_ds = tl.where(m_A, b_ds, 0)
                b_ds = b_ds.to(b_k.dtype)
                b_dq += tl.dot(b_ds, b_k)
                b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
                b_dq *= scale
                tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
                tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_bwd_kernel_dv_local(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
g_gamma,
|
||||
do,
|
||||
dv,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_G_GAMMA: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_b = tl.program_id(0), tl.program_id(1)
|
||||
T_max = T
|
||||
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
for i_h in range(H):
|
||||
offset_kh = (bos * H + i_h) * K
|
||||
offset_vh = (bos * H + i_h) * V
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + offset_kh, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_q = tl.make_block_ptr(q + offset_kh, (K, T), (1, H * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_A += tl.dot(b_k, b_q)
|
||||
|
||||
if USE_G:
|
||||
if IS_VARLEN:
|
||||
offset_g = i_h * T_max + bos
|
||||
else:
|
||||
offset_g = i_b * H * T_max + i_h * T_max
|
||||
|
||||
p_g = tl.make_block_ptr(g + offset_g, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
|
||||
if USE_G_GAMMA:
|
||||
b_gamma = tl.load(g_gamma + i_h)
|
||||
b_g = b_gamma * (tl.arange(0, BT) + 1)
|
||||
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
m_t = o_t < T
|
||||
m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
|
||||
if USE_G:
|
||||
b_A = tl.where(m_A, b_A * tl.exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
|
||||
else:
|
||||
b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty)
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_do = tl.make_block_ptr(do + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + offset_vh, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_G_GAMMA": lambda args: args["g_gamma"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
g_gamma,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_G_GAMMA: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
T_max = T
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
for i_n in range(N):
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
core_id = tl.program_id(0)
|
||||
total_cores = tl.num_programs(0)
|
||||
base_chunks_per_pid = NT // total_cores
|
||||
remainder = NT % total_cores
|
||||
|
||||
if core_id < remainder:
|
||||
chunks_this_pid = base_chunks_per_pid + 1
|
||||
start_idx = core_id * chunks_this_pid
|
||||
else:
|
||||
chunks_this_pid = base_chunks_per_pid
|
||||
start_idx = core_id * base_chunks_per_pid + remainder
|
||||
|
||||
# offset calculation
|
||||
for i_h in range(0, H):
|
||||
q_offset = (bos * Hg + i_h // (H // Hg)) * K
|
||||
k_offset = (bos * Hg + i_h // (H // Hg)) * K
|
||||
v_offset = (bos * H + i_h) * V
|
||||
o_offset = (bos * H + i_h) * V
|
||||
|
||||
for i_t in range(start_idx, start_idx + chunks_this_pid):
|
||||
i_tg = boh + i_t
|
||||
h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + q_offset, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||
)
|
||||
p_k = tl.make_block_ptr(
|
||||
k + k_offset, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
||||
)
|
||||
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
if IS_VARLEN:
|
||||
p_g = tl.make_block_ptr(g + bos + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
else:
|
||||
p_g = tl.make_block_ptr(g + bos * H + i_h * T_max, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_o = b_o * exp(b_g)[:, None]
|
||||
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
|
||||
if USE_G_GAMMA:
|
||||
b_gamma = tl.load(g_gamma + i_h)
|
||||
b_g = b_gamma * (tl.arange(0, BT) + 1)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_A = o_i[:, None] >= o_i[None, :]
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v + v_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + o_offset, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by triton v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_bwd_dqkwg(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
dh: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
g_gamma: Optional[torch.Tensor] = None,
|
||||
dv: Optional[torch.Tensor] = None,
|
||||
w: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
scale: float = 1.0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
|
||||
BK = 128 if cu_seqlens is None else 64
|
||||
BV = 64
|
||||
NK = triton.cdiv(K, BK)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
# g is optional: only transpose it and build the gate-difference table when it is provided
g = g.transpose(1, 2).contiguous() if g is not None else None
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
dw = torch.empty_like(w) if w is not None else None
grid = (NT, B)

if cu_seqlens is None and g is not None:
|
||||
if NT * BT == T:
|
||||
g_ = g.reshape(B, H, NT, BT)
|
||||
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
|
||||
g_diff = g_diff.clamp(-60, 60).exp()
|
||||
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
|
||||
else:
|
||||
diff = NT * BT - T
|
||||
g_ = torch.cat((g, torch.zeros(B, H, diff).to(g.device)), dim=-1).reshape(B, H, NT, BT)
|
||||
g_diff = g_[:, :, :, :, None] - g_[:, :, :, None, :]
|
||||
g_diff = g_diff.clamp(-60, 60).exp()
|
||||
g_diff[:, :, :] *= torch.tril(torch.ones(BT, BT), diagonal=0).to(g.device)
|
||||
bias = torch.arange(0, BT).to(g.device)
|
||||
o_t = (NT - 1) * BT + bias
|
||||
m_t = o_t < T
|
||||
m_A = m_t[:, None] & m_t
|
||||
g_diff[:, :, -1] *= m_A
|
||||
else:
|
||||
g_diff = None
|
||||
|
||||
chunk_bwd_kernel_dqkwg[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
h=h,
|
||||
g=g,
|
||||
g_gamma=g_gamma,
|
||||
do=do,
|
||||
dh=dh,
|
||||
dv=dv,
|
||||
w=w,
|
||||
dw=dw,
|
||||
dq=dq,
|
||||
dk=dk,
|
||||
dg=dg,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
scale=scale,
|
||||
B=B,
|
||||
T=T,
|
||||
H=H,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
gdiff=g_diff,
|
||||
)
|
||||
|
||||
if dg is not None:
|
||||
dg = dg.sum(0)
|
||||
dg = dg.transpose(1, 2).contiguous()
|
||||
return dq, dk, dw, dg
|
||||
|
||||
|
||||
def chunk_bwd_dv_local(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
g_gamma: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, H, K, V = *k.shape, do.shape[-1]
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
|
||||
BK = 128
|
||||
BV = 128
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
|
||||
g = g.transpose(1, 2).contiguous() if g is not None else None
|
||||
dv = torch.empty_like(do)
|
||||
grid = (NT, B)
|
||||
chunk_bwd_kernel_dv_local[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
g=g,
|
||||
g_gamma=g_gamma,
|
||||
do=do,
|
||||
dv=dv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
scale=scale,
|
||||
T=T,
|
||||
H=H,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
return dv
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
g_gamma: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) # noqa: F841
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
if cu_seqlens is None:
|
||||
N, chunk_offsets = B, None
|
||||
else:
|
||||
N, chunk_offsets = (
|
||||
len(cu_seqlens) - 1,
|
||||
prepare_chunk_offsets(cu_seqlens, BT),
|
||||
)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta["BV"]), N * H)
|
||||
|
||||
g = g.transpose(1, 2).contiguous() if g is not None else None
|
||||
h = h.contiguous()
|
||||
CV_kernel_num = 24
|
||||
chunk_fwd_kernel_o[(CV_kernel_num,)](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
g_gamma,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
N=N,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=128,
|
||||
BV=128,
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
bwd_chunk_dqkwg = chunk_bwd_dqkwg
|
||||
bwd_chunk_dv_local = chunk_bwd_dv_local
|
||||
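The kernel `chunk_fwd_kernel_o` above is launched with a fixed number of programs and divides the chunk range among them, giving the first `remainder` programs one extra chunk each. The following is a minimal pure-Python sketch of that arithmetic only; `partition_chunks` is a hypothetical helper name used for illustration and is not part of this diff.

```python
def partition_chunks(num_chunks: int, num_programs: int) -> list[range]:
    """Split range(num_chunks) into contiguous, near-equal slices, one per program.

    Mirrors the start_idx / chunks_this_pid computation in the kernel:
    the first `remainder` programs each take `base + 1` chunks, the rest take `base`.
    """
    base, remainder = divmod(num_chunks, num_programs)
    slices = []
    for core_id in range(num_programs):
        if core_id < remainder:
            count = base + 1
            start = core_id * count
        else:
            count = base
            start = core_id * base + remainder
        slices.append(range(start, start + count))
    return slices


if __name__ == "__main__":
    for core_id, chunk_range in enumerate(partition_chunks(10, 4)):
        print(core_id, list(chunk_range))
```

When there are fewer chunks than programs, the trailing programs receive empty ranges and their loops simply do not execute, which appears to be why a constant launch size such as `CV_kernel_num = 24` can be used for any sequence length.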
359
src/llamafactory/third_party/triton/chunk_scaled_dot_kkt.py
vendored
Normal file
@@ -0,0 +1,359 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
g,
|
||||
beta,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
NT,
|
||||
B,
|
||||
TOTAL_TASKS,
|
||||
):
|
||||
core_id = tl.program_id(0)
|
||||
num_blocks = tl.num_programs(0)
|
||||
T_max = T
|
||||
|
||||
base_tasks_per_block = TOTAL_TASKS // num_blocks
|
||||
remainder_tasks = TOTAL_TASKS % num_blocks
|
||||
|
||||
if core_id < remainder_tasks:
|
||||
tasks_this_core = base_tasks_per_block + 1
|
||||
start_idx = core_id * tasks_this_core
|
||||
else:
|
||||
tasks_this_core = base_tasks_per_block
|
||||
start_idx = core_id * base_tasks_per_block + remainder_tasks
|
||||
|
||||
for idx in range(start_idx, start_idx + tasks_this_core):
|
||||
i_b = idx // NT
|
||||
local_idx = idx % NT
|
||||
|
||||
if IS_VARLEN:
|
||||
i_n = tl.load(chunk_indices + local_idx * 2).to(tl.int32)
|
||||
i_t = tl.load(chunk_indices + local_idx * 2 + 1).to(tl.int32)
|
||||
bos = tl.load(cu_seqlens + i_n).to(tl.int32)
|
||||
eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T_local = eos - bos
|
||||
else:
|
||||
bos, eos = 0, T
|
||||
i_t = local_idx
|
||||
T_local = T
|
||||
|
||||
for i_h in range(H):
|
||||
k_batch_off = i_b * T_max * H * K
|
||||
beta_batch_off = i_b * H * T_max
|
||||
g_batch_off = i_b * H * T_max
|
||||
A_batch_off = i_b * T_max * H * BT
|
||||
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + beta_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + k_batch_off + (bos * H + i_h) * K,
|
||||
(T_local, K),
|
||||
(H * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
dot_product = tl.dot(b_k, tl.trans(b_k))
|
||||
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
o_t = o_t.to(tl.float32)
|
||||
T_mask = (o_t < T_local).to(tl.float32)
|
||||
|
||||
row_indices = tl.arange(0, BT)[:, None]
|
||||
col_indices = tl.arange(0, BT)[None, :]
|
||||
tril_mask = (row_indices > col_indices).to(tl.float32)
|
||||
tril_mask = tril_mask * T_mask[:, None]
|
||||
masked_dot = dot_product * tril_mask
|
||||
b_A += masked_dot
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(
|
||||
g + g_batch_off + bos + i_h * T_max, (T_local,), (1,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_g_diff = tl.minimum(tl.maximum(b_g_diff, -50.0), 50.0)
|
||||
b_A *= tl.exp(b_g_diff)
|
||||
b_A *= b_beta[:, None]
|
||||
|
||||
p_A = tl.make_block_ptr(
|
||||
A + A_batch_off + (bos * H + i_h) * BT, (T_local, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.autotune(configs=[triton.Config({"BK": BK}) for BK in [32, 64]], key=["BC"])
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
|
||||
k,
|
||||
g,
|
||||
beta,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_i, i_j = i_c // NC, i_c % NC
|
||||
|
||||
for i_h in range(H):
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T_val = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
T_val = T
|
||||
|
||||
should_compute = (i_t * BT + i_i * BC < T_val) and (i_i > i_j)
|
||||
|
||||
if should_compute:
|
||||
k_ptr = k + (bos * H + i_h) * K
|
||||
g_ptr = g + (bos * H + i_h) * K
|
||||
A_ptr = A + (bos * H + i_h) * BT
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T_val,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
b_A = tl.zeros([BC, BC], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(
|
||||
k_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
|
||||
)
|
||||
p_g = tl.make_block_ptr(
|
||||
g_ptr, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
|
||||
)
|
||||
b_kt = tl.make_block_ptr(
|
||||
k_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
|
||||
)
|
||||
p_gk = tl.make_block_ptr(
|
||||
g_ptr, (K, T_val), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
|
||||
)
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
m_k = o_k < K
|
||||
b_gn = tl.load(g_ptr + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.exp(b_g - b_gn[None, :])
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kt = tl.load(b_kt, boundary_check=(0, 1)) * tl.exp(b_gn[:, None] - b_gk)
|
||||
b_A += tl.dot(b_k, b_kt)
|
||||
b_A *= b_beta[:, None]
|
||||
|
||||
p_A = tl.make_block_ptr(A_ptr, (T_val, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
|
||||
k,
|
||||
g,
|
||||
beta,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_i, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
for i_h in range(H):
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T_val = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
T_val = T
|
||||
|
||||
should_compute = i_t * BT + i_i * BC < T_val
|
||||
|
||||
if should_compute:
|
||||
o_i = tl.arange(0, BC)
|
||||
o_k = tl.arange(0, BK)
|
||||
m_k = o_k < K
|
||||
m_A = (i_t * BT + i_i * BC + o_i) < T_val
|
||||
o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
|
||||
|
||||
p_k = tl.make_block_ptr(
|
||||
k + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
|
||||
)
|
||||
p_g = tl.make_block_ptr(
|
||||
g + (bos * H + i_h) * K, (T_val, K), (H * K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)
|
||||
)
|
||||
p_beta = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
|
||||
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1)) * tl.load(p_beta, mask=m_A, other=0)[:, None]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
|
||||
p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
|
||||
p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
|
||||
|
||||
for j in range(0, min(BC, T_val - i_t * BT - i_i * BC)):
|
||||
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
|
||||
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
|
||||
b_A = tl.sum(b_k * b_kt[None, :] * tl.exp(b_g - b_gk[None, :]), 1)
|
||||
# Convert to float32
|
||||
o_i_tmp = o_i.to(tl.float32)
|
||||
b_A = tl.where(o_i_tmp > j, b_A, 0.0)
|
||||
|
||||
tl.store(A + o_A + j, b_A, mask=m_A)
|
||||
p_kt += H * K
|
||||
p_gk += H * K
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
gk: Optional[torch.Tensor] = None,
|
||||
beta: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
r"""Compute beta * K * K^T.

Args:
    k (torch.Tensor):
        The key tensor of shape `[B, T, H, K]`.
    beta (torch.Tensor):
        The beta tensor of shape `[B, T, H]`.
    g (torch.Tensor):
        The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
    gk (torch.Tensor):
        The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
    cu_seqlens (torch.LongTensor):
        The cumulative sequence lengths of the input tensor. Default: `None`.
    chunk_size (int):
        The chunk size. Default: 64.
    output_dtype (torch.dtype):
        The dtype of the output tensor. Default: `torch.float32`.

Returns:
    beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
|
||||
B, T, H, K = k.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
beta = beta.transpose(1, 2).contiguous()
|
||||
g = g.transpose(1, 2).contiguous() if g is not None else None
|
||||
BK = 128
|
||||
kernel_num = 24
|
||||
|
||||
if gk is None:
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(kernel_num,)](
|
||||
k=k,
|
||||
g=g,
|
||||
beta=beta,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
NT=NT,
|
||||
B=B,
|
||||
TOTAL_TASKS=B * NT,
|
||||
)
|
||||
return A
|
||||
|
||||
BC = min(16, BT)
|
||||
NC = triton.cdiv(BT, BC)
|
||||
BK = max(triton.next_power_of_2(K), 16)
|
||||
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
grid = (NT, NC * NC, B)
|
||||
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
|
||||
k=k,
|
||||
g=gk,
|
||||
beta=beta,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BC=BC,
|
||||
NC=NC,
|
||||
)
|
||||
|
||||
grid = (NT, NC, B)
|
||||
chunk_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
|
||||
k=k,
|
||||
g=gk,
|
||||
beta=beta,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BC=BC,
|
||||
BK=BK,
|
||||
)
|
||||
return A
|
||||
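For readers checking `chunk_scaled_dot_kkt_fwd` against its docstring, the scalar-gate path (`gk is None`) can be sketched in plain Python for a single batch element and head. This is an illustrative reference under stated assumptions, not the vendored implementation: `scaled_dot_kkt_reference` is a hypothetical name, inputs are nested lists, and `beta` is assumed to be applied whether or not `g` is given.

```python
import math


def scaled_dot_kkt_reference(k, beta, g=None, chunk_size=64):
    """Reference for one (batch, head) pair.

    k:    T x K nested list of keys
    beta: length-T list
    g:    optional length-T list of cumulative log-gates
    Returns a T x chunk_size nested list A where, inside each chunk,
    A[i][j - chunk_start] = beta[i] * <k[i], k[j]> * exp(g[i] - g[j]) for j < i.
    """
    T = len(k)
    A = [[0.0] * chunk_size for _ in range(T)]
    for i in range(T):
        chunk_start = (i // chunk_size) * chunk_size
        # Only strictly-lower-triangular entries within the same chunk are filled.
        for j in range(chunk_start, i):
            value = sum(a * b for a, b in zip(k[i], k[j]))
            if g is not None:
                # The kernel clamps the gate difference to [-50, 50] before exp.
                diff = max(-50.0, min(50.0, g[i] - g[j]))
                value *= math.exp(diff)
            A[i][j - chunk_start] = beta[i] * value
    return A


if __name__ == "__main__":
    keys = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
    print(scaled_dot_kkt_reference(keys, [1.0, 2.0, 3.0], chunk_size=4))
```

A loop-based reference like this is slow, but it can serve as a ground truth when comparing kernel output on small random inputs.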
147
src/llamafactory/third_party/triton/cumsum.py
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{"HAS_SCALE": lambda args: args["scale"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
HAS_SCALE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
CHUNK_SIZE: tl.constexpr = 64,
|
||||
):
|
||||
i_block, i_b = tl.program_id(0), tl.program_id(1)
|
||||
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
|
||||
|
||||
if IS_VARLEN:
|
||||
i_s, i_block = (
|
||||
tl.load(chunk_indices + i_block * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32),
|
||||
)
|
||||
|
||||
bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0,)).to(tl.float32)
|
||||
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
|
||||
b_s = tl.trans(b_s, (1, 0, 2))
|
||||
b_o = tl.cumsum(b_s, axis=0)
|
||||
if REVERSE:
|
||||
b_z = tl.sum(b_s, axis=0)
|
||||
b_o = -b_o + b_z[None] + b_s
|
||||
if HAS_SCALE:
|
||||
b_o *= scale
|
||||
b_o = tl.trans(b_o, (1, 0, 2))
|
||||
b_o = tl.reshape(b_o, (BLOCK_T, H))
|
||||
|
||||
tl.store(ptr_o, b_o.to(ptr_o.dtype.element_ty), boundary_check=(0,))
|
||||
return
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
) -> torch.Tensor:
|
||||
B, T, H = g.shape
|
||||
if chunk_size != 2 ** (chunk_size.bit_length() - 1):
|
||||
raise ValueError(f"chunk_size must be a power of 2, but chunk_size is {chunk_size}")
|
||||
# We adjust the tiling strategy to prevent overflow in backward passes and context-parallel scenarios
# while maximizing UB utilization where possible.
|
||||
# The tiling strategy is as follows:
|
||||
# 1. BT must be greater than or equal to chunk_size.
|
||||
# 2. UB estimation varies directly with H.
|
||||
# 3. BT in reverse mode is smaller than in forward mode.
|
||||
BT = max(chunk_size, triton.next_power_of_2((1 << 11 if reverse else 1 << 12) // H))
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B)
|
||||
chunk_local_cumsum_scalar_kernel[grid](
|
||||
s=g_org,
|
||||
o=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BLOCK_T=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
CHUNK_SIZE=chunk_size,
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cu_seqlens is not None:
|
||||
if g.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"Only batch size 1 is supported when cu_seqlens are provided, current size is {g.shape[0]}"
|
||||
)
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(
|
||||
g=g,
|
||||
chunk_size=chunk_size,
|
||||
reverse=reverse,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=head_first,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H); only scalar (per-head) gates are supported here"
|
||||
)
|
||||
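The behavior of `chunk_local_cumsum_scalar_kernel` can be summarized as a cumulative sum that restarts at every chunk boundary, with `reverse=True` producing inclusive suffix sums inside each chunk. Below is a minimal pure-Python sketch for a single 1-D sequence; `chunk_local_cumsum_reference` is a hypothetical name for illustration, and it ignores the batch and head dimensions handled by the kernel.

```python
from itertools import accumulate


def chunk_local_cumsum_reference(values, chunk_size, reverse=False, scale=None):
    """Cumulative sum computed independently inside each chunk of `chunk_size` items.

    With reverse=True each position holds the sum of itself and all later items
    in the same chunk, matching `-cumsum + total + value` in the kernel.
    """
    out = []
    for start in range(0, len(values), chunk_size):
        chunk = values[start:start + chunk_size]
        if reverse:
            sums = list(accumulate(reversed(chunk)))[::-1]
        else:
            sums = list(accumulate(chunk))
        out.extend(sums)
    if scale is not None:
        out = [x * scale for x in out]
    return out


if __name__ == "__main__":
    print(chunk_local_cumsum_reference([1, 2, 3, 4, 5], chunk_size=2))
    print(chunk_local_cumsum_reference([1, 2, 3, 4, 5], chunk_size=2, reverse=True))
```

The last chunk may be shorter than `chunk_size`; the sketch handles that by slicing, whereas the kernel relies on `boundary_check` when loading and storing.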
272
src/llamafactory/third_party/triton/solve_tril.py
vendored
Normal file
@@ -0,0 +1,272 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import Optional

import torch
import triton
import triton.language as tl

from .utils import input_guard, make_tensor_descriptor, prepare_chunk_indices


FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = pid_bh // H, pid_bh % H

    base_t = pid_t * TPP

    if IS_VARLEN:
        i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_eff = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_eff = T

    o_i = tl.arange(0, 16)  # noqa: F841
    o_i_fp32 = tl.arange(0, 16).to(tl.float32)
    m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
    m_I = o_i_fp32[:, None] == o_i_fp32[None, :]

    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * BT

    for tpp in tl.static_range(0, TPP):
        tile_t = base_t + tpp
        tile_row = tile_t * 16

        offset = (tile_t * 16) % BT

        if not USE_TMA:
            p_A = tl.make_block_ptr(A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0))
            b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
        else:
            desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
            desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
            b_A_raw = desc.load([tile_row, offset]).to(tl.float32)

        b_A_neg = -b_A_raw
        b_A = b_A_neg * m_A
        for i in range(2, min(16, T_eff - tile_row)):
            slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
            b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
            dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
            b_a_update = b_a_val + dot_prod
            b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
        b_A += m_I

        if not USE_TMA:
            p_Ai = tl.make_block_ptr(Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0))
            tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        else:
            desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))


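# Illustrative sketch (not part of the vendored file): the row loop in the
# 16x16 kernel can be read as forward substitution for (I + A)^-1, where A is
# strictly lower triangular. The pure-Python helper below, with the
# hypothetical name `invert_unit_lower`, mirrors that recurrence on nested
# lists. It assumes a small dense matrix and ignores tiling, masking of
# partial tiles, and dtype handling.

```python
def invert_unit_lower(a):
    """Return (I + A)^-1 for a strictly lower-triangular square matrix A."""
    n = len(a)
    # Start from N = -A restricted to the strictly lower triangle.
    inv = [[-a[i][j] if i > j else 0.0 for j in range(n)] for i in range(n)]
    # Rows 0 and 1 are already final; every later row only needs the rows above it:
    # N[i, :] = -A[i, :] - sum_{j < i} A[i, j] * N[j, :]
    for i in range(2, n):
        row = inv[i][:]
        for j in range(i):
            for c in range(n):
                row[c] += inv[i][j] * inv[j][c]
        inv[i] = row
    # Add the identity to obtain (I + A)^-1 = I + N.
    for i in range(n):
        inv[i][i] += 1.0
    return inv


example = [
    [0.0, 0.0, 0.0, 0.0],
    [2.0, 0.0, 0.0, 0.0],
    [3.0, 4.0, 0.0, 0.0],
    [5.0, 6.0, 7.0, 0.0],
]
print(invert_unit_lower(example)[3])  # -> [-28.0, 22.0, -7.0, 1.0]
```

# The loop starts at row 2 for the same reason as in the kernel: the first two
# rows of -A are already the correct rows of the inverse (minus the identity).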
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)

    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)

    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)

    b_Ai_11 += m_I
    b_Ai_22 += m_I

    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)

    b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)

    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))


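# Illustrative sketch (not part of the vendored file): the 32x32 merge kernel
# relies on the block identity
#     [[L11, 0], [A21, L22]]^-1 = [[L11^-1, 0], [-L22^-1 @ A21 @ L11^-1, L22^-1]].
# The helper names `matmul` and `merge_block_inverse` below are hypothetical,
# and the example uses 2x2 blocks on nested lists instead of 16x16 tiles.

```python
def matmul(x, y):
    """Plain dense matrix product on nested lists."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))] for i in range(len(x))]


def merge_block_inverse(inv11, a21, inv22):
    """Assemble the inverse of [[L11, 0], [A21, L22]] from the two diagonal-block inverses."""
    m = len(inv11)
    # Off-diagonal block: -L22^-1 @ A21 @ L11^-1
    inv21 = [[-v for v in row] for row in matmul(matmul(inv22, a21), inv11)]
    top = [inv11[i] + [0.0] * m for i in range(m)]
    bottom = [inv21[i] + inv22[i] for i in range(m)]
    return top + bottom


inv11 = [[1.0, 0.0], [-2.0, 1.0]]  # inverse of [[1, 0], [2, 1]]
inv22 = [[1.0, 0.0], [-7.0, 1.0]]  # inverse of [[1, 0], [7, 1]]
a21 = [[3.0, 4.0], [5.0, 6.0]]
for row in merge_block_inverse(inv11, a21, inv22):
    print(row)
```

# Only the two diagonal blocks need the sequential row recurrence; the
# off-diagonal block comes from two matrix products, which is what the pair of
# `tl.dot` calls in the kernel computes.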
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def solve_tril_64x64_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 64)
    m_I = o_i[:, None] == o_i[None, :]

    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * 64

    offset = (i_t * 64) % BT
    if not USE_TMA:
        p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 64, offset), (64, 64), (1, 0))
        b_A = -tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [64, 64])
        desc_o = make_tensor_descriptor(Ai, [T, 64], [H * 64, 1], [64, 64])
        b_A = -desc.load([i_t * 64, offset]).to(tl.float32)

    for i in range(2, min(64, T - i_t * 64)):
        b_a = -tl.load(A + (i_t * 64 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        b_A = tl.where((o_i == i)[:, None], b_a, b_A)
    b_A += m_I
    if not USE_TMA:
        p_Ai = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (64, 64), (1, 0))
        tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * 64, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))


@input_guard
def solve_tril(
    A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Compute the inverse of the matrix I + A.
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A
    """  # noqa: D205
    if A.shape[-1] not in [16, 32, 64]:
        raise ValueError(f"The last dimension (BT) of A must be 16, 32, or 64, but got {A.shape[-1]}")
    output_dtype = A.dtype if output_dtype is None else output_dtype

    B, T, H, BT = A.shape
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    Ai = torch.zeros_like(A, dtype=output_dtype)

    if BT == 16:
        merge_fn = solve_tril_16x16_kernel
    elif BT == 32:
        merge_fn = merge_16x16_to_32x32_inverse_kernel
    elif BT == 64:
        merge_fn = solve_tril_64x64_kernel

    merge_fn[NT, B * H](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        USE_TMA=False,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai
359
src/llamafactory/third_party/triton/utils.py
vendored
Normal file
@@ -0,0 +1,359 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import itertools
import logging
import os
import warnings
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional

import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
import triton.runtime.driver as driver
from packaging import version


logger = logging.getLogger(__name__)

FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"


def tensor_cache(fn: Optional[Callable[..., torch.Tensor]] = None, *, maxsize: int = 1) -> Any:
    """A decorator that caches the most recent results of a function with tensor inputs.

    This decorator will store the outputs of the decorated function for the most recent
    set of input tensors, up to `maxsize` entries. If the function is called again with
    the same input tensors, it will return the cached result.

    When maxsize=1 (default), the behavior is identical to caching only the most recent result.
    Can be used as @tensor_cache or @tensor_cache(maxsize=n).

    Args:
        fn (Callable[..., torch.Tensor], optional):
            The function to be decorated when used without parentheses.
        maxsize (int):
            Maximum number of input combinations to cache. Default is 1.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with caching.
    """
    if maxsize < 1:
        raise ValueError("maxsize must be at least 1")

    def _is_match(a: Any, b: Any) -> bool:
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            return a is b
        try:
            return a == b
        except Exception:
            return a is b

    def _make_wrapper(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        cache: list = []

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for i, (cached_args, cached_kwargs, cached_result) in enumerate(cache):
                if len(args) == len(cached_args) and len(kwargs) == len(cached_kwargs):
                    if all(_is_match(a, b) for a, b in zip(args, cached_args)) and all(
                        k in cached_kwargs and _is_match(v, cached_kwargs[k]) for k, v in kwargs.items()
                    ):
                        if i != 0:
                            cache.insert(0, cache.pop(i))
                        return cached_result

            result = fn(*args, **kwargs)
            cache.insert(0, (args, kwargs, result))
            if len(cache) > maxsize:
                cache.pop()
            return result

        return wrapper

    if fn is not None:
        return _make_wrapper(fn)
    return _make_wrapper


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache(maxsize=3)
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)


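# Illustrative sketch (not part of the vendored file): `prepare_chunk_indices`
# returns one (sequence index, chunk index within that sequence) pair per
# chunk. The pure-Python helper below, with the hypothetical name
# `chunk_indices_reference`, reproduces that layout on plain lists. It assumes
# every sequence is non-empty.

```python
def chunk_indices_reference(cu_seqlens, chunk_size):
    """List [sequence_id, chunk_id] pairs for packed sequences split into fixed-size chunks."""
    pairs = []
    for seq_id in range(len(cu_seqlens) - 1):
        length = cu_seqlens[seq_id + 1] - cu_seqlens[seq_id]
        num_chunks = -(-length // chunk_size)  # ceiling division
        pairs.extend([seq_id, chunk_id] for chunk_id in range(num_chunks))
    return pairs


# Two packed sequences of lengths 5 and 7, chunked by 4 tokens.
print(chunk_indices_reference([0, 5, 12], 4))  # -> [[0, 0], [0, 1], [1, 0], [1, 1]]
```

# The kernels read these pairs as `chunk_indices + i * 2` (sequence index) and
# `chunk_indices + i * 2 + 1` (chunk index), so the number of pairs is also the
# launch grid size along the chunk axis.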
def get_abs_err(x, y):
    return (x.detach() - y.detach()).flatten().abs().max().item()


def get_err_ratio(x, y):
    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
    base = (x.detach()).flatten().square().mean().sqrt().item()
    return err / (base + 1e-8)


def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
    abs_atol = get_abs_err(ref, tri)
    msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
    logger.info(msg)
    error_rate = get_err_ratio(ref, tri)
    if abs_atol <= err_atol:
        return
    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
        if error_rate > ratio:
            warnings.warn(msg)
    else:
        assert error_rate < ratio, msg


if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    """
    Fallback implementation when TMA is not supported.
    Returns None to indicate TMA descriptors are unavailable.
    Just make triton compiler happy.
    """

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        return None


@functools.cache
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return "cpu"


def map_triton_backend_to_torch_device() -> str:
    backend = get_available_device()  # 'cuda' | 'hip' | 'xpu' | 'cpu' | ...
    return {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(backend, backend)


device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = get_available_device()
is_amd = device_platform == "hip"
is_nvidia = device_platform == "cuda"
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
)

is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_tma_supported = (
    (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9)
    and os.environ.get("FLA_NO_USE_TMA", "0") != "1"
    and (
        hasattr(triton.language, "_experimental_make_tensor_descriptor")
        or hasattr(triton.language, "make_tensor_descriptor")
    )
)

if is_nvidia and not is_tf32_supported:
    # Triton uses TF32 by default, which older NVIDIA cards do not support.
    # Force IEEE fp32 as a workaround for those cards.
    os.environ["TRITON_F32_DEFAULT"] = "ieee"


@functools.cache
def check_pytorch_version(version_s: str = "2.4") -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


if check_pytorch_version("2.4"):
    device = "cuda" if device == "cpu" else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)
else:
    assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0."
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)


def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """A decorator to make sure all input tensors are contiguous and set the device based on input tensors."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
        contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


def _cpu_device_warning():
    warnings.warn("Triton is not supported on the current platform; falling back to CPU.", stacklevel=1)


@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)


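# Illustrative sketch (not part of the vendored file): `prepare_chunk_offsets`
# is a running total of per-sequence chunk counts, so entry i is the index of
# the first chunk that belongs to sequence i. The helper name
# `chunk_offsets_reference` is hypothetical and works on plain lists.

```python
def chunk_offsets_reference(cu_seqlens, chunk_size):
    """Cumulative number of chunks before each packed sequence (plus the total at the end)."""
    offsets = [0]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        num_chunks = -(-(end - start) // chunk_size)  # ceiling division
        offsets.append(offsets[-1] + num_chunks)
    return offsets


# Sequences of lengths 5, 7 and 1 with 4-token chunks need 2, 2 and 1 chunks.
print(chunk_offsets_reference([0, 5, 12, 13], 4))  # -> [0, 2, 4, 5]
```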
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp = tldevice.fast_expf
    exp2 = tldevice.exp2
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    exp2 = tl.math.exp2
    log = tl.log
    log2 = tl.log2


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


def get_autotune_config(
    multibuffer_list: tuple = (False,),
    unit_flag_list: tuple = (False,),
    limit_auto_multi_buffer_only_for_local_buffer_list: tuple = (False,),
    limit_auto_multi_buffer_of_local_buffer_list: tuple = ("no-l0c",),
    set_workspace_multibuffer_list: tuple = (2, 4),
    enable_hivm_auto_cv_balance_list: tuple = (True,),
    tile_mix_vector_loop_num_list: tuple = (2, 4),
    tile_mix_cube_loop_num_list: tuple = (2, 4),
):
    configs = []
    for (
        multibuffer,
        unit_flag,
        limit_auto_multi_buffer_only_for_local_buffer,
        limit_auto_multi_buffer_of_local_buffer,
    ) in itertools.product(
        list(multibuffer_list),
        list(unit_flag_list),
        list(limit_auto_multi_buffer_only_for_local_buffer_list),
        list(limit_auto_multi_buffer_of_local_buffer_list),
    ):
        base_config_dict = {
            "multibuffer": multibuffer,
            "unit_flag": unit_flag,
            "limit_auto_multi_buffer_only_for_local_buffer": limit_auto_multi_buffer_only_for_local_buffer,
            "limit_auto_multi_buffer_of_local_buffer": limit_auto_multi_buffer_of_local_buffer,
        }

        if limit_auto_multi_buffer_only_for_local_buffer:
            configs.append(triton.Config(base_config_dict))
        else:
            for (
                set_workspace_multibuffer,
                enable_hivm_auto_cv_balance,
                tile_mix_vector_loop,
                tile_mix_cube_loop,
            ) in itertools.product(
                list(set_workspace_multibuffer_list),
                list(enable_hivm_auto_cv_balance_list),
                list(tile_mix_vector_loop_num_list),
                list(tile_mix_cube_loop_num_list),
            ):
                full_config_dict = base_config_dict.copy()
                full_config_dict.update(
                    {
                        "set_workspace_multibuffer": set_workspace_multibuffer,
                        "enable_hivm_auto_cv_balance": enable_hivm_auto_cv_balance,
                        "tile_mix_vector_loop": tile_mix_vector_loop,
                        "tile_mix_cube_loop": tile_mix_cube_loop,
                    }
                )
                configs.append(triton.Config(full_config_dict))
    return configs


def get_npu_properties():
    return driver.active.utils.get_device_properties(torch.npu.current_device())
387
src/llamafactory/third_party/triton/wy_fast.py
vendored
Normal file
@@ -0,0 +1,387 @@
# Copyright 2025 the LlamaFactory team.
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import triton
import triton.language as tl

from .utils import exp, prepare_chunk_indices


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
    k,
    v,
    beta,
    g,
    A,
    dw,
    du,
    dk,
    dv,
    dbeta,
    dg,
    cu_seqlens,
    chunk_indices,
    T,
    B,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    NT: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    core_id = tl.program_id(0)
    total_cores = tl.num_programs(0)
    T_max = T

    base_chunks_per_pid = NT // total_cores
    remainder_chunks = NT % total_cores

    if core_id < remainder_chunks:
        chunks_this_pid = base_chunks_per_pid + 1
        start_idx = core_id * chunks_this_pid
    else:
        chunks_this_pid = base_chunks_per_pid
        start_idx = core_id * chunks_this_pid + remainder_chunks

    for idx in range(start_idx, start_idx + chunks_this_pid):
        for i_b in range(B):
            if IS_VARLEN:
                i_n, i_t = (
                    tl.load(chunk_indices + idx * 2).to(tl.int32),
                    tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
                )
                bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
                T = eos - bos
            else:
                i_t = idx
                bos, eos = i_b * T, i_b * T + T

            o_t = i_t * BT + tl.arange(0, BT)
            m_t = o_t < T
            m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
            for i_h in range(0, H):
                if IS_VARLEN:
                    offset = bos + i_h * T_max
                else:
                    offset = bos * H + i_h * T_max

                p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                p_A = tl.make_block_ptr(
                    A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)
                )

                b_A = tl.load(p_A, boundary_check=(0, 1))
                b_beta = tl.load(p_beta, boundary_check=(0,))
                b_g = tl.load(p_g, boundary_check=(0,))
                b_g_exp = tl.exp(b_g)

                b_dbeta = tl.zeros([BT], dtype=tl.float32)
                b_dA = tl.zeros([BT, BT], dtype=tl.float32)
                b_dg = tl.zeros([BT], dtype=tl.float32)

                for i_k in range(tl.cdiv(K, BK)):
                    p_k = tl.make_block_ptr(
                        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_dk = tl.make_block_ptr(
                        dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_dw = tl.make_block_ptr(
                        dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    b_k = tl.load(p_k, boundary_check=(0, 1))
                    b_k_beta_g = (b_k * b_beta[:, None] * b_g_exp[:, None]).to(b_k.dtype)
                    b_dw = tl.load(p_dw, boundary_check=(0, 1))
                    b_dA += tl.dot(b_dw, tl.trans(b_k_beta_g))
                    b_dk_beta_g = tl.dot(b_A, b_dw)
                    b_dk = b_dk_beta_g * b_beta[:, None] * b_g_exp[:, None]
                    b_dbeta += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None], 1)
                    b_dg += tl.sum(b_dk_beta_g * b_k * b_g_exp[:, None] * b_beta[:, None], 1)
                    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

                for i_v in range(tl.cdiv(V, BV)):
                    p_v = tl.make_block_ptr(
                        v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    p_dv = tl.make_block_ptr(
                        dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    p_du = tl.make_block_ptr(
                        du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    b_v = tl.load(p_v, boundary_check=(0, 1))
                    b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
                    b_du = tl.load(p_du, boundary_check=(0, 1))
                    b_dA += tl.dot(b_du, tl.trans(b_v_beta))
                    b_dv_beta = tl.dot(b_A, b_du)
                    b_dv = b_dv_beta * b_beta[:, None]
                    b_dbeta += tl.sum(b_dv_beta * b_v, 1)
                    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

                b_dA = tl.where(m_A, b_dA, 0)
                b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
                b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
                b_dA = tl.where(m_A, -b_dA * exp(b_g[:, None] - b_g[None, :]), 0)
                b_dA = b_dA.to(k.dtype.element_ty)
                b_A = tl.zeros([BT, BT], dtype=tl.float32)

                for i_k in range(tl.cdiv(K, BK)):
                    p_k = tl.make_block_ptr(
                        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_dk = tl.make_block_ptr(
                        dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    b_k = tl.load(p_k, boundary_check=(0, 1))
                    b_dk = tl.load(p_dk, boundary_check=(0, 1))
                    b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
                    b_A += tl.dot(b_k_beta, tl.trans(b_k))
                    b_dk_beta = tl.dot(b_dA, b_k)
                    b_dbeta += tl.sum(b_dk_beta * b_k, 1)
                    b_dk += tl.dot(tl.trans(b_dA), b_k_beta)
                    b_dk += b_dk_beta * b_beta[:, None]
                    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

                b_dA_A = b_dA * b_A
                b_dg += tl.sum(b_dA_A, axis=1) - tl.sum(b_dA_A, axis=0)
                p_dg = tl.make_block_ptr(dg + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                p_dbeta = tl.make_block_ptr(dbeta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
                tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))


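# Illustrative sketch (not part of the vendored file): the kernels in this
# file split `NT` chunks across a fixed number of programs instead of
# launching one program per chunk. The first `NT % total_cores` programs take
# one extra chunk. The helper name `chunk_range` is hypothetical; it restates
# the `start_idx` / `chunks_this_pid` arithmetic on the host in plain Python.

```python
def chunk_range(core_id, total_cores, num_chunks):
    """Return the half-open range [start, stop) of chunk indices owned by one program."""
    base, remainder = divmod(num_chunks, total_cores)
    if core_id < remainder:
        count = base + 1
        start = core_id * count
    else:
        count = base
        start = core_id * count + remainder
    return start, start + count


# 10 chunks over 4 programs: the first two programs get 3 chunks, the rest get 2.
print([chunk_range(c, 4, 10) for c in range(4)])  # -> [(0, 3), (3, 6), (6, 8), (8, 10)]
```

# The ranges are contiguous and cover every chunk exactly once, so no chunk is
# skipped or processed twice when `NT` is not a multiple of the program count.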
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    gk,
    cu_seqlens,
    chunk_indices,
    T_tmp,
    B,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    NT: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    core_id = tl.program_id(0)
    total_cores = tl.num_programs(0)
    T_max = T_tmp

    base_chunks_per_pid = NT // total_cores
    remainder_chunks = NT % total_cores

    if core_id < remainder_chunks:
        chunks_this_pid = base_chunks_per_pid + 1
        start_idx = core_id * chunks_this_pid
    else:
        chunks_this_pid = base_chunks_per_pid
        start_idx = core_id * chunks_this_pid + remainder_chunks

    for idx in range(start_idx, start_idx + chunks_this_pid):
        for i_b in range(B):
            for i_h in range(0, H):
                if IS_VARLEN:
                    i_n, i_t = (
                        tl.load(chunk_indices + idx * 2).to(tl.int32),
                        tl.load(chunk_indices + idx * 2 + 1).to(tl.int32),
                    )
                    bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
                    offset = bos + i_h * T_max
                    T = eos - bos
                else:
                    T = T_tmp
                    i_t = idx
                    bos, eos = i_b * T, i_b * T + T
                    offset = bos * H + i_h * T_max

                p_beta = tl.make_block_ptr(beta + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                b_beta = tl.load(p_beta, boundary_check=(0,))

                p_A = tl.make_block_ptr(
                    A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
                )
                b_A = tl.load(p_A, boundary_check=(0, 1))

                for i_v in range(tl.cdiv(V, BV)):
                    p_v = tl.make_block_ptr(
                        v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    p_u = tl.make_block_ptr(
                        u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
                    )
                    b_v = tl.load(p_v, boundary_check=(0, 1))
                    b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
                    b_u = tl.dot(b_A, b_vb, allow_tf32=False)
                    tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

                if USE_G:
                    p_g = tl.make_block_ptr(g + offset, (T,), (1,), (i_t * BT,), (BT,), (0,))
                    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

                for i_k in range(tl.cdiv(K, BK)):
                    p_k = tl.make_block_ptr(
                        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    p_w = tl.make_block_ptr(
                        w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                    )
                    b_k = tl.load(p_k, boundary_check=(0, 1))
                    b_kb = b_k * b_beta[:, None]
                    if USE_G:
                        b_kb *= b_g[:, None]
                    if USE_GK:
                        p_gk = tl.make_block_ptr(
                            gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
                        )
                        b_kb *= tl.exp(tl.load(p_gk, boundary_check=(0, 1)))
                    b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
                    tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
gk: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
BT = A.shape[-1]
|
||||
BK = 128
|
||||
BV = 128
|
||||
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g = g.transpose(1, 2).contiguous() if g is not None else None
|
||||
beta = beta.transpose(1, 2).contiguous()
|
||||
|
||||
w = torch.empty_like(k)
|
||||
u = torch.empty_like(v)
|
||||
cv_kernel_num = 24
|
||||
recompute_w_u_fwd_kernel[(cv_kernel_num,)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g,
|
||||
gk=gk,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T_tmp=T,
|
||||
B=B,
|
||||
H=H,
|
||||
K=K,
|
||||
V=V,
|
||||
NT=NT,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
return w, u
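As a reading aid for `recompute_w_u_fwd`, the per-chunk arithmetic can be written without Triton. The following is a minimal sketch, assuming a single chunk and a single head, no `gk` gate, and no variable-length sequences; `matmul` and `recompute_w_u_reference` are hypothetical helpers for illustration, not functions from the repository.

```python
import math


def matmul(a, b):
    """Multiply two dense matrices given as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def recompute_w_u_reference(k, v, beta, A, g=None):
    """Single chunk, single head: k is T x K, v is T x V, A is T x T, beta and g have length T."""
    steps = range(len(beta))
    # u = A @ (v * beta), with beta broadcast over the value dimension.
    v_scaled = [[x * beta[t] for x in v[t]] for t in steps]
    # w = A @ (k * beta * exp(g)); the gate is skipped when g is None.
    gate = [math.exp(g[t]) if g is not None else 1.0 for t in steps]
    k_scaled = [[x * beta[t] * gate[t] for x in k[t]] for t in steps]
    return matmul(A, k_scaled), matmul(A, v_scaled)


if __name__ == "__main__":
    k = [[1.0, 2.0], [3.0, 4.0]]
    v = [[1.0], [2.0]]
    w, u = recompute_w_u_reference(k, v, beta=[0.5, 2.0], A=[[1.0, 0.0], [1.0, 1.0]])
    print(w, u)
```

A reference like this is mainly useful for comparing against kernel output on small inputs; it makes no attempt to match the kernel's dtype casts.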
|
||||
|
||||
|
||||
def prepare_wy_repr_bwd(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
dw: torch.Tensor,
|
||||
du: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
chunk_size: int = 64,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 128
|
||||
BV = 128
|
||||
beta = beta.transpose(1, 2).contiguous()
|
||||
g = g.transpose(1, 2).contiguous()
|
||||
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
dbeta = torch.empty_like(beta)
|
||||
dg = torch.empty_like(g)
|
||||
|
||||
cv_kernel_num = 24
|
||||
prepare_wy_repr_bwd_kernel[(cv_kernel_num,)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
g=g,
|
||||
A=A,
|
||||
dw=dw,
|
||||
du=du,
|
||||
dk=dk,
|
||||
dv=dv,
|
||||
dbeta=dbeta,
|
||||
dg=dg,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
K=K,
|
||||
V=V,
|
||||
NT=NT,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
|
||||
dbeta = dbeta.transpose(1, 2).contiguous()
|
||||
dg = dg.transpose(1, 2).contiguous()
|
||||
|
||||
return dk, dv, dbeta, dg
|
||||
|
||||
|
||||
bwd_prepare_wy_repr = prepare_wy_repr_bwd
|
||||
|
||||
fwd_recompute_w_u = recompute_w_u_fwd
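Both wrappers launch on a fixed grid (`cv_kernel_num = 24`), and each program walks a contiguous range of chunks derived from `NT // total_cores` and `NT % total_cores`. A minimal host-side sketch of that partitioning in plain Python, handy for checking that every chunk index is visited exactly once; `partition_chunks` is a hypothetical helper, not part of the repository.

```python
def partition_chunks(num_chunks: int, num_programs: int) -> list[range]:
    """Mirror the start_idx / chunks_this_pid arithmetic used in the kernel."""
    base, remainder = divmod(num_chunks, num_programs)
    ranges = []
    for pid in range(num_programs):
        if pid < remainder:
            # The first `remainder` programs take one extra chunk each.
            count = base + 1
            start = pid * count
        else:
            count = base
            start = pid * count + remainder
        ranges.append(range(start, start + count))
    return ranges


if __name__ == "__main__":
    ranges = partition_chunks(50, 24)
    covered = [i for r in ranges for i in r]
    assert covered == list(range(50))
    print([len(r) for r in ranges])
```

When there are fewer chunks than programs, the trailing programs receive an empty range and do no work.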
|
||||
@@ -20,7 +20,6 @@ import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback):
         if matched:
             logger.info_rank0(
                 f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
-                + (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
+                + (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
             )
         else:
             logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
|
||||
@@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback):
             bwd = self._backward_times.get(name, [])
             fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
             bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
-            lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
+            lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")

         logger.info_rank0("\n".join(lines))
         self._forward_times.clear()
|
||||
|
||||
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .workflow import run_sft
+from .workflow import run_pt, run_sft


-__all__ = ["run_sft"]
+__all__ = ["run_pt", "run_sft"]
|
||||
|
||||
222
src/llamafactory/train/hyper_parallel/trainer.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HyperParallel distributed trainer for LlamaFactory."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import types
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from hyper_parallel.integration.llamafactory import (
|
||||
HSDPModule,
|
||||
HyperParallelArguments,
|
||||
export_to_hf_format,
|
||||
fsdp2_prepare_model,
|
||||
hsdp_sync_stream,
|
||||
load_hsdp_model,
|
||||
load_hsdp_optimizer_and_scheduler,
|
||||
save_hsdp_checkpoint,
|
||||
wrap_optimizer_with_skip_dtensor_dispatch,
|
||||
)
|
||||
from hyper_parallel.integration.llamafactory import (
|
||||
clip_grad_norm_ as hp_clip_grad_norm_,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from ..sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
|
||||
|
||||
Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
|
||||
prediction, sampler, etc.) and only overrides HSDP-specific behavior.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hp_args: HyperParallelArguments,
|
||||
finetuning_args=None,
|
||||
processor=None,
|
||||
ref_model: Optional[nn.Module] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._hp_args = hp_args
|
||||
|
||||
# Let CustomSeq2SeqTrainer handle everything except ref_model —
|
||||
# Custom would prepare it with accelerate's fsdp2_prepare_model,
|
||||
# but we need HP's version instead.
|
||||
super().__init__(
|
||||
finetuning_args=finetuning_args,
|
||||
processor=processor,
|
||||
ref_model=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not getattr(self.accelerator, "is_fsdp2", False):
|
||||
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
|
||||
|
||||
# Prepare ref_model with HP's fsdp2_prepare_model
|
||||
self.ref_model = ref_model
|
||||
if self.ref_model is not None:
|
||||
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
|
||||
|
||||
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
|
||||
self._orig_fsdp2_prepare_model = None
|
||||
self._accelerator_patches_active = False
|
||||
|
||||
def _activate_accelerator_patches(self) -> None:
|
||||
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
|
||||
if self._accelerator_patches_active:
|
||||
return
|
||||
|
||||
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
||||
|
||||
hp_args = self._hp_args
|
||||
|
||||
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
|
||||
|
||||
def _hp_fsdp2_prepare_model(accelerator, model):
|
||||
return fsdp2_prepare_model(accelerator, model, hp_args)
|
||||
|
||||
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
|
||||
|
||||
        def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
            # Materialize once: `parameters` may be a one-shot generator (e.g. model.parameters()),
            # so the fallback call below must reuse the list rather than the exhausted iterator.
            parameter_list = list(parameters)
            if getattr(accelerator, "is_fsdp2", False):
                accelerator.unscale_gradients()
                parameter_ids = {id(param) for param in parameter_list}
                for model in accelerator._models:  # pylint: disable=protected-access
                    if not isinstance(model, HSDPModule):
                        continue
                    model_param_ids = {id(param) for param in model.parameters()}
                    if parameter_ids and parameter_ids.issubset(model_param_ids):
                        return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)
            return self._orig_accelerator_clip_grad_norm(parameter_list, max_norm, norm_type=norm_type)
|
||||
|
||||
self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
|
||||
self._accelerator_patches_active = True
|
||||
|
||||
def _restore_accelerator_patches(self) -> None:
|
||||
"""Restore original Accelerate methods."""
|
||||
if not self._accelerator_patches_active:
|
||||
return
|
||||
|
||||
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
||||
|
||||
if self._orig_fsdp2_prepare_model is not None:
|
||||
acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
|
||||
self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
|
||||
self._accelerator_patches_active = False
|
||||
|
||||
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
|
||||
"""Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct."""
|
||||
del dataloader
|
||||
if isinstance(model, HSDPModule):
|
||||
return model
|
||||
if training and getattr(self.accelerator, "is_fsdp2", False):
|
||||
return model
|
||||
return super()._wrap_model(model, training=training)
|
||||
|
||||
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
|
||||
"""Skip redundant device moves for HSDP-wrapped models."""
|
||||
if isinstance(model, HSDPModule):
|
||||
return model
|
||||
if device is None:
|
||||
return model
|
||||
return model.to(device)
|
||||
|
||||
def train(self, *args, **kwargs):
|
||||
"""Activate HP patches during training and restore afterwards."""
|
||||
self._activate_accelerator_patches()
|
||||
try:
|
||||
return super().train(*args, **kwargs)
|
||||
finally:
|
||||
self._restore_accelerator_patches()
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Any],
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Standard training step with HSDP gradient synchronization."""
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
|
||||
if isinstance(model, HSDPModule):
|
||||
model.set_is_last_backward(sync_gradients)
|
||||
model.set_requires_gradient_sync(sync_gradients)
|
||||
|
||||
compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
|
||||
with compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean()
|
||||
|
||||
if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if isinstance(model, HSDPModule) and sync_gradients:
|
||||
hsdp_sync_stream()
|
||||
|
||||
return loss.detach()
|
||||
|
||||
def create_optimizer(self):
|
||||
"""Create optimizer and wrap step with SkipDTensorDispatch."""
|
||||
optimizer = super().create_optimizer()
|
||||
wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
|
||||
return optimizer
|
||||
|
||||
def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
|
||||
"""Save model/optimizer shards per-rank and scheduler."""
|
||||
save_hsdp_checkpoint(
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.lr_scheduler,
|
||||
output_dir=output_dir,
|
||||
should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
|
||||
"""Load model from HSDP sharded checkpoint."""
|
||||
target = model if model is not None else self.model
|
||||
loaded = load_hsdp_model(target, resume_from_checkpoint)
|
||||
if not loaded:
|
||||
return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
|
||||
self._pending_hsdp_checkpoint = resume_from_checkpoint
|
||||
return None
|
||||
|
||||
def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
|
||||
"""Load optimizer/scheduler from per-rank checkpoint files."""
|
||||
ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
|
||||
if ckpt_dir is None:
|
||||
return
|
||||
load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)
|
||||
|
||||
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
||||
"""Save model weights in HuggingFace-compatible format."""
|
||||
save_dir = output_dir or self.args.output_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)
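`HyperParallelTrainer.train` swaps in HyperParallel's helpers before training and restores the originals in a `finally` block, so the patch cannot leak past a failed run. The same pattern can be sketched as a context manager; `fake_module`, `patched_attr`, and the lambda replacements below are hypothetical stand-ins, not Accelerate or HyperParallel APIs.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def patched_attr(obj, name, replacement):
    """Temporarily replace obj.<name>, restoring the original even on error."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield original
    finally:
        setattr(obj, name, original)


# Stand-in for a module exposing a function we want to override.
fake_module = SimpleNamespace(prepare=lambda model: f"default({model})")


def run_with_patch() -> tuple[str, str]:
    with patched_attr(fake_module, "prepare", lambda model: f"custom({model})"):
        during = fake_module.prepare("m")
    after = fake_module.prepare("m")
    return during, after


if __name__ == "__main__":
    print(run_with_patch())
```

The trainer keeps an explicit `_accelerator_patches_active` flag instead of a context manager, which also makes repeated activation a no-op.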
|
||||
@@ -12,8 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
|
||||
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
|
||||
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||
from .trainer import HyperParallelTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
+    r"""Load HyperParallel arguments and apply LlamaFactory-side overrides.
+
+    When activation optimization is enabled, skip native gradient checkpointing
+    so HP can install its own via ``setup_activation_optimization``.
+    """
+    if not is_hyper_parallel_available():
+        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
+
+    from hyper_parallel.integration.llamafactory import HyperParallelArguments  # pylint: disable=C0415
+
+    hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
+    if hp_args.activation_mode != "none":
+        model_args.disable_gradient_checkpointing = True
+    return hp_args
+
+
+def run_pt(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
+):
+    hp_args = _prepare_hp_args(finetuning_args, model_args)
+
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    trainer = HyperParallelTrainer(
+        hp_args=hp_args,
+        model=model,
+        args=training_args,
+        finetuning_args=finetuning_args,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **dataset_module,
+        **tokenizer_module,
+    )
+
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            keys = ["loss"]
+            if isinstance(dataset_module.get("eval_dataset"), dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
+
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            for key in dataset_module["eval_dataset"].keys():
+                try:
+                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
+                except OverflowError:
+                    perplexity = float("inf")
+
+                metrics[f"eval_{key}_perplexity"] = perplexity
+        else:
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+
+            metrics["eval_perplexity"] = perplexity
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
+
+
 def run_sft(
     model_args: "ModelArguments",
     data_args: "DataArguments",
|
||||
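The evaluation branch of `run_pt` converts each evaluation loss into a perplexity with `math.exp` and falls back to infinity on overflow. A minimal standalone sketch of that conversion follows; `perplexity_from_loss` and `add_perplexity` are hypothetical helpers, and matching keys by the `eval..._loss` naming pattern is an assumption made here for brevity (the workflow itself iterates over the evaluation dataset names).

```python
import math


def perplexity_from_loss(loss: float) -> float:
    """Return exp(loss), or infinity when exp overflows a float."""
    try:
        return math.exp(loss)
    except OverflowError:
        return float("inf")


def add_perplexity(metrics: dict[str, float]) -> dict[str, float]:
    """Add an `*_perplexity` entry for every `eval*_loss` entry."""
    out = dict(metrics)
    for key, value in metrics.items():
        if key.startswith("eval") and key.endswith("_loss"):
            out[key[: -len("loss")] + "perplexity"] = perplexity_from_loss(value)
    return out


if __name__ == "__main__":
    print(add_perplexity({"eval_loss": 0.0, "eval_wiki_loss": 1000.0, "epoch": 1.0}))
```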
@@ -43,13 +130,7 @@ def run_sft(
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
if not is_hyper_parallel_available():
|
||||
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
||||
|
||||
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
|
||||
HyperParallelArguments,
|
||||
HyperParallelTrainer,
|
||||
)
|
||||
hp_args = _prepare_hp_args(finetuning_args, model_args)
|
||||
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
@@ -94,25 +175,6 @@ def run_sft(
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
|
||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||
|
||||
callbacks = list(callbacks or [])
|
||||
processor = tokenizer_module.get("processor")
|
||||
if processor is not None:
|
||||
callbacks.append(SaveProcessorCallback(processor))
|
||||
|
||||
compute_loss_func = None
|
||||
if finetuning_args.use_dft_loss:
|
||||
compute_loss_func = dft_loss_func
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
elif finetuning_args.use_asft_loss:
|
||||
from functools import partial
|
||||
|
||||
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
|
||||
|
||||
trainer = HyperParallelTrainer(
|
||||
hp_args=hp_args,
|
||||
model=model,
|
||||
@@ -122,20 +184,11 @@ def run_sft(
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
ref_model=ref_model,
|
||||
compute_loss_func=compute_loss_func,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from types import MethodType
|
||||
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
||||
|
||||
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
||||
trainer.add_callback(BAdamCallback)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
|
||||
@@ -80,20 +80,27 @@ def _training_function(config: dict[str, Any]) -> None:
     if finetuning_args.early_stopping_steps is not None:
         callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))

-    if training_args.enable_torch_profiler:
+    if getattr(training_args, "enable_torch_profiler", False):
         callbacks.append(TorchProfilerCallback(training_args))

-    if training_args.profile_modules:
+    if getattr(training_args, "profile_modules", None):
         callbacks.append(ModuleProfilerCallback(training_args.profile_modules))

     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

-    if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+    if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
         if not is_hyper_parallel_available():
-            raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
-        from .hyper_parallel import run_sft as run_sft_hp
+            raise ImportError(
+                "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
+            )
+        if finetuning_args.stage == "pt":
+            from .hyper_parallel import run_pt as run_pt_hp

-        run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+            run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
+        else:
+            from .hyper_parallel import run_sft as run_sft_hp
+
+            run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)

     elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
         if not is_mcore_adapter_available():
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils.types import AttentionFunction
|
||||
from .arg_parser import InputArgument, get_args
|
||||
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
|
||||
from .data_args import DataArguments
|
||||
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionFunction",
|
||||
"BatchingStrategy",
|
||||
"DataArguments",
|
||||
"InputArgument",
|
||||
|
||||
@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

+    model_args, data_args, training_args, sample_args = parsed_args
     # Seed as early as possible after argument parsing so all downstream
     # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
-    for arg in parsed_args:
-        seed = getattr(arg, "seed", None)
-        if seed is not None:
-            set_seed(seed)
-            break
+    set_seed(training_args.seed, full_determinism=training_args.full_determinism)

-    return tuple(parsed_args)
+    return model_args, data_args, training_args, sample_args


 if __name__ == "__main__":
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..utils.types import AttentionFunction
|
||||
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@@ -32,6 +33,12 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code from Hugging Face."},
|
||||
)
|
||||
flash_attn: AttentionFunction = field(
|
||||
default=AttentionFunction.SDPA,
|
||||
metadata={
|
||||
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
|
||||
},
|
||||
)
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
@@ -54,6 +61,12 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
supported_flash_attn = [item.value for item in AttentionFunction]
|
||||
if self.flash_attn not in supported_flash_attn:
|
||||
raise ValueError(
|
||||
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
|
||||
)
|
||||
|
||||
self.init_config = get_plugin_config(self.init_config)
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
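The `__post_init__` check above rejects any `flash_attn` value that is not one of the `AttentionFunction` values. A minimal sketch of that validation, assuming `AttentionFunction` is a string-valued `Enum` whose members match the three names in the help text; the enum definition, the `ModelArgs` class, and the final normalization step are illustrative assumptions, not the repository's code.

```python
from dataclasses import dataclass
from enum import Enum


class AttentionFunction(str, Enum):
    # Assumed members, inferred from the help text; the real enum may differ.
    EAGER = "eager"
    SDPA = "sdpa"
    FLASH_ATTENTION_2 = "flash_attention_2"


@dataclass
class ModelArgs:
    flash_attn: str = AttentionFunction.SDPA.value

    def __post_init__(self) -> None:
        supported = [item.value for item in AttentionFunction]
        # A str-based Enum member compares equal to its value, so both plain
        # strings and enum members pass this membership test.
        if self.flash_attn not in supported:
            raise ValueError(f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported}.")
        # Extra step in this sketch: normalize to the enum for later comparisons.
        self.flash_attn = AttentionFunction(self.flash_attn)


if __name__ == "__main__":
    print(ModelArgs(flash_attn="eager").flash_attn)
```

Validating in `__post_init__` means a misspelled value fails at argument-parsing time rather than deep inside model loading.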
|
||||
|
||||
@@ -85,6 +85,10 @@ class TrainingArguments:
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
full_determinism: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
|
||||
)
|
||||
resume_from_checkpoint: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
|
||||
@@ -116,3 +120,9 @@ class TrainingArguments:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
self.optim_config = get_plugin_config(self.optim_config)
|
||||
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
|
||||
|
||||
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
|
||||
if self.max_steps is None or self.max_steps <= 0:
|
||||
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
|
||||
if self.save_epochs is not None:
|
||||
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.helper import ReduceOp
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import TrainingArguments
|
||||
from ..config import BatchingStrategy, TrainingArguments
|
||||
from ..utils import logging
|
||||
from ..utils.callbacks import (
|
||||
CallbackHandler,
|
||||
@@ -147,13 +147,19 @@ class BaseTrainer:
|
||||
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
|
||||
|
||||
if model.config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_rank0(
|
||||
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
|
||||
raise ValueError(
|
||||
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
|
||||
)
|
||||
model.config._attn_implementation = "flash_attention_2"
|
||||
|
||||
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
if (
|
||||
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
|
||||
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
|
||||
):
|
||||
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
|
||||
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
dataset=self.train_dataset,
|
||||
renderer=self.renderer,
|
||||
@@ -237,6 +243,7 @@ class BaseTrainer:
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||
|
||||
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
|
||||
for micro_batches in self.train_batch_generator:
|
||||
self.global_step += 1
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class ModelEngine:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
|
||||
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
@@ -164,6 +165,7 @@ class ModelEngine:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
attn_implementation=self.args.flash_attn,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
**init_kwargs,
|
||||
)
|
||||
@@ -188,9 +190,12 @@ class ModelEngine:
         if self.args.kernel_config is not None:
             from ..plugins.model_plugins.kernels.interface import KernelPlugin

-            model = KernelPlugin(self.args.kernel_config.name)(
-                model, include_kernels=self.args.kernel_config.get("include_kernels")
-            )
+            kernel_config = self.args.kernel_config
+            kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
+            if kernel_config.name == "liger_kernel":
+                # Fused linear CE omits logits; SFT stage needs logits for loss_weights.
+                kernel_kwargs["require_logits"] = self.is_train
+            model = KernelPlugin(kernel_config.name)(**kernel_kwargs)

         return model

@@ -42,6 +42,8 @@ from .rendering import Renderer

 logger = logging.get_logger(__name__)

+__all__ = ["BatchGenerator"]
+

 def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
     micro_batch_size = batch_info["micro_batch_size"]
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
         if not self.drop_last:
             raise ValueError("Drop last must be True.")

+        self._batch_info: BatchInfo = {
+            "micro_batch_size": self.micro_batch_size,
+            "num_micro_batch": self.num_micro_batch,
+            "cutoff_len": self.cutoff_len,
+        }
+
         self._init_data_provider()

         self._is_resuming: bool = False
         self._data_iter = iter(self._data_provider)
         self._buffer = StatefulBuffer()

-        self._batch_info: BatchInfo = {
-            "micro_batch_size": self.micro_batch_size,
-            "num_micro_batch": self.num_micro_batch,
-            "cutoff_len": self.cutoff_len,
-            "data_iter": self._data_iter,
-        }
-
         logger.info_rank0(
             f"Init unified data loader with global batch size {self.global_batch_size}, "
             f"micro batch size {self.micro_batch_size}, "
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
         else:
             raise NotImplementedError("Iterable dataset is not supported yet.")

+        if self.batching_strategy == BatchingStrategy.NORMAL:
+            batch_size = self.micro_batch_size * self.num_micro_batch
+        else:
+            from ...plugins.trainer_plugins.batching import BatchingPlugin
+
+            batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
+
         generator_seed = torch.Generator()
         generator_seed.manual_seed(self.seed)

         self._data_provider = StatefulDataLoader(
             self.dataset,
-            batch_size=self.micro_batch_size * self.num_micro_batch,
+            batch_size=batch_size,
             sampler=sampler,
             num_workers=self.batching_workers,
             collate_fn=self.renderer.process_samples,
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
         else:
             from ...plugins.trainer_plugins.batching import BatchingPlugin

-            self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
-            raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
+            self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)

     def __len__(self) -> int:
         return self._length
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
         else:
             from ...plugins.trainer_plugins.batching import BatchingPlugin

-            BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
+            BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)

     def _generate_batch(self) -> list[BatchInput] | None:
         if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):

             return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)

+    def _next_samples(self, restart: bool) -> list[ModelInput] | None:
+        try:
+            return next(self._data_iter)
+        except StopIteration:
+            if not restart:
+                return None
+
+            # Dynamic batching may restart the provider to fill one token-budgeted batch.
+            self._data_iter = iter(self._data_provider)
+            try:
+                return next(self._data_iter)
+            except StopIteration:
+                return None
+
     def state_dict(self) -> dict[str, Any]:
         return {
             "buffer": self._buffer.state_dict(),
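Note on `_next_samples`: the `restart` flag decides whether an exhausted provider ends the fill loop (fixed-size strategies) or is re-opened once so a token-budgeted batch can be completed (dynamic strategies). A minimal, self-contained sketch of that contract follows. The class `RestartableSource` is hypothetical and stands in for the `BatchGenerator` state, with a plain list replacing the stateful data loader.

```python
from collections.abc import Iterable
from typing import Any, Optional


class RestartableSource:
    """Hypothetical stand-in for the provider/iterator pair held by the generator."""

    def __init__(self, provider: Iterable[Any]) -> None:
        self._provider = provider
        self._iter = iter(provider)

    def next_samples(self, restart: bool) -> Optional[Any]:
        try:
            return next(self._iter)
        except StopIteration:
            if not restart:
                return None  # caller stops filling at the epoch boundary

        # Re-open the provider once; an empty provider still yields None.
        self._iter = iter(self._provider)
        try:
            return next(self._iter)
        except StopIteration:
            return None


source = RestartableSource([["a"], ["b"]])
print([source.next_samples(False) for _ in range(3)])  # [['a'], ['b'], None]
print(source.next_samples(True))  # ['a']
```

With `restart=True` the epoch boundary becomes soft: the first samples of the provider can be reused to fill the last batch.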
@@ -34,7 +34,7 @@ class BaseKernel(ABC):
     """

     _kernel_id: Any = ""  # kernel ID, any hashable value to identify a kernel implementation
-    _device: DeviceType = DeviceType.CPU  # "cuda", "npu", "cpu", etc.
+    _device: list[DeviceType] = [DeviceType.CPU]  # "cuda", "npu", "cpu", etc.

     @classmethod
     def get_kernel_id(cls) -> str:
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
         return cls._kernel_id

     @classmethod
-    def get_device(cls) -> str:
-        """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
+    def get_device(cls) -> list[DeviceType]:
+        """Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
         return cls._device

     @classmethod
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
         it should raise an error instead of silently switching.
         Kernels can override this method to implement custom dependency checks.
         """
-        if cls._device != get_current_accelerator().type:
+        if get_current_accelerator().type not in cls._device:
             return False
         return True

@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
         apply_kernel(kernel, model=model)

     return model
+
+
+@KernelPlugin("liger_kernel").register()
+def apply_liger_kernels(
+    model: HFModel,
+    include_kernels: str = None,
+    require_logits: bool = False,
+) -> HFModel:
+    """Applies Liger kernel to the model.
+
+    Args:
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with
+            library defaults. If a comma-separated list (e.g.
+            ``rope,rms_norm``), enable only those ops; names match
+            ``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``,
+            ``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``.
+            If ``None`` or ``False``, do nothing. Defaults to ``None``.
+        require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor
+            of non-fused CE so the forward pass returns ``logits``. Needed
+            for trainers that compute weighted loss from logits (e.g. v1
+            SFT with ``loss_weights``). Defaults to ``False`` (fused CE
+            when supported). The v1 ``run_sft`` entrypoint sets
+            ``require_logits`` to true for ``liger_kernel`` when the key
+            is omitted so SFT weighted loss keeps working.
+
+    Returns:
+        HFModel: The model with Liger kernel applied.
+    """
+    if not include_kernels:
+        return model
+    if include_kernels == "auto" or include_kernels is True:
+        use_kernels = "auto"
+    else:
+        use_kernels = [k.strip() for k in include_kernels.split(",") if k.strip()]
+        if not use_kernels:
+            return model
+
+    try:
+        from .liger_kernel_ops import LigerKernel
+    except ImportError as e:
+        logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
+        return model
+
+    return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)
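Note on `include_kernels`: the docstring above describes three outcomes (skip, library defaults, explicit op list). The following sketch isolates just that parsing step so the edge cases are easy to see. `parse_include_kernels` is a hypothetical helper written for illustration and is not part of the plugin.

```python
from typing import Union


def parse_include_kernels(include_kernels: Union[str, bool, None]) -> Union[str, list, None]:
    """Return None (skip), "auto" (library defaults), or a list of op names."""
    if not include_kernels:
        return None
    if include_kernels == "auto" or include_kernels is True:
        return "auto"
    # Tolerate spaces and stray commas in the comma-separated form.
    ops = [name.strip() for name in include_kernels.split(",") if name.strip()]
    return ops or None


print(parse_include_kernels("rope, rms_norm"))  # ['rope', 'rms_norm']
print(parse_include_kernels(" , "))  # None
print(parse_include_kernels(True))  # auto
```

A value that contains only separators is treated like an omitted value, matching the early `return model` in the registered function.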
@@ -0,0 +1,148 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of Liger Kernel.
+
+Init Phase:
+1. Define LigerKernel class.
+2. Register Liger kernel.
+
+"""
+
+import inspect
+
+from ....accelerator.helper import DeviceType, get_current_accelerator
+from ....utils.logging import get_logger
+from ....utils.types import HFModel
+from .base import BaseKernel
+
+
+logger = get_logger(__name__)
+
+_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
+    "qwen3": "apply_liger_kernel_to_qwen3",
+    "qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
+    "qwen3_next": "apply_liger_kernel_to_qwen3_next",
+    "qwen3_5": "apply_liger_kernel_to_qwen3_5",
+    "qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
+    "qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
+    "qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
+}
+
+
+class LigerKernel(BaseKernel):
+    """Liger Kernel for optimized model training."""
+
+    _device = [DeviceType.CUDA, DeviceType.NPU]
+
+    @classmethod
+    def check_deps(cls) -> bool:
+        """Checks if the required dependencies for the kernel are available."""
+        try:
+            import liger_kernel  # noqa: F401
+
+            return super().check_deps()
+        except ImportError:
+            logger.warning_rank0(
+                "Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
+            )
+            return False
+
+    @classmethod
+    def apply(cls, **kwargs) -> "HFModel":
+        """Applies the Liger kernel to the model.
+
+        Args:
+            **kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
+                names to enable exclusively, or the string ``"auto"`` to use each
+                ``apply_liger_kernel_to_*`` function's signature defaults (same as calling
+                upstream with only ``model``). Optional ``require_logits`` forces non-fused
+                cross entropy when supported.
+
+        Returns:
+            HFModel: The model with Liger kernel applied.
+
+        Raises:
+            ValueError: If the model is not provided.
+            RuntimeError: If dependencies are not met.
+        """
+        model = kwargs.get("model")
+        use_kernels = kwargs.get("use_kernels", None)
+        if model is None:
+            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+
+        if not cls.check_deps():
+            raise RuntimeError(
+                f"current device is not supported by liger_kernel. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}"
+            )
+
+        require_logits = kwargs.get("require_logits", False)
+
+        model_type = getattr(model.config, "model_type", None)
+
+        if model_type not in _LIGER_FN_BY_MODEL_TYPE:
+            logger.warning_rank0("Current model does not support liger kernel.")
+            return model
+
+        import liger_kernel.transformers as liger_transformers
+
+        apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])
+
+        sig = inspect.signature(apply_liger_kernel).parameters
+        togglable = [name for name in sig if name != "model"]
+
+        def _normalize_op_name(raw: str) -> str:
+            key = raw.strip().lower().replace("-", "_")
+            aliases = {
+                "rmsnorm": "rms_norm",
+                "flce": "fused_linear_cross_entropy",
+                "lce": "fused_linear_cross_entropy",
+                "fused_ce": "fused_linear_cross_entropy",
+            }
+            return aliases.get(key, key)
+
+        if use_kernels is not None and len(use_kernels) == 0:
+            return model
+
+        if use_kernels != "auto":
+            selected = {_normalize_op_name(k) for k in use_kernels}
+            ops = selected - set(togglable)
+            if ops:
+                raise ValueError(
+                    f"Unknown Liger op(s) {sorted(ops)} for model_type={model_type}. Valid: {sorted(togglable)}"
+                )
+            if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
+                raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")
+            call_kwargs = {name: (name in selected) for name in togglable}
+            call_kwargs["model"] = model
+        else:
+            # Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
+            # and logging reflects enabled ops (omitted kwargs only live in the callee).
+            call_kwargs = {"model": model}
+            for name in togglable:
+                param = sig[name]
+                if param.default is not inspect.Parameter.empty:
+                    call_kwargs[name] = param.default
+
+        if require_logits and "fused_linear_cross_entropy" in sig:
+            logger.warning_rank0("Current training stage does not support chunked cross entropy.")
+            call_kwargs["fused_linear_cross_entropy"] = False
+            call_kwargs["cross_entropy"] = True
+
+        apply_liger_kernel(**call_kwargs)
+
+        applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
+        logger.info_rank0(f"These Liger ops are applied to the model: {applied}")
+
+        return model
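Note on the `"auto"` branch: the keyword arguments are read from the callee's signature instead of being left implicit, so the list of enabled ops can be logged and `require_logits` can override a default. The sketch below reproduces that idea with the standard `inspect` module only. `fake_apply_liger_kernel` and `build_call_kwargs` are hypothetical names, and the defaults in the stand-in function are invented for illustration rather than taken from Liger-Kernel.

```python
import inspect


def fake_apply_liger_kernel(model, rope=True, rms_norm=True, cross_entropy=False, fused_linear_cross_entropy=True):
    """Hypothetical stand-in for an upstream ``apply_liger_kernel_to_*`` function."""
    return None


def build_call_kwargs(fn, model, use_kernels, require_logits=False):
    """Build explicit keyword arguments for ``fn`` from its signature."""
    params = inspect.signature(fn).parameters
    togglable = [name for name in params if name != "model"]

    if use_kernels == "auto":
        # Copy the callee's own defaults so the enabled ops are visible to the caller.
        call_kwargs = {
            name: params[name].default
            for name in togglable
            if params[name].default is not inspect.Parameter.empty
        }
    else:
        selected = set(use_kernels)
        unknown = selected - set(togglable)
        if unknown:
            raise ValueError(f"Unknown op(s) {sorted(unknown)}. Valid: {sorted(togglable)}")
        # Everything not listed is switched off explicitly.
        call_kwargs = {name: name in selected for name in togglable}

    if require_logits and "fused_linear_cross_entropy" in params:
        call_kwargs["fused_linear_cross_entropy"] = False
        call_kwargs["cross_entropy"] = True

    call_kwargs["model"] = model
    return call_kwargs


auto_kwargs = build_call_kwargs(fake_apply_liger_kernel, model="m", use_kernels="auto", require_logits=True)
only_rope = build_call_kwargs(fake_apply_liger_kernel, model="m", use_kernels=["rope"])
print(sorted(k for k, on in auto_kwargs.items() if k != "model" and on))  # ['cross_entropy', 'rms_norm', 'rope']
print(sorted(k for k, on in only_rope.items() if k != "model" and on))  # ['rope']
```

Reading defaults from the signature keeps the wrapper in step with whatever the installed library version declares, at the cost of depending on parameter names staying stable.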
@@ -228,6 +228,30 @@ class NpuMoeFused:
         routed_out = self.experts(hidden_states, routing_weights, router_indices)
         return routed_out

+    @staticmethod
+    def npu_moe_experts_v5_forward(
+        self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass for Transformers v5+ MoE experts using NPU fused operations.
+
+        Transformers v5 stores expert weights in F.linear layout:
+            gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
+            down_proj: [num_experts, hidden_dim, intermediate_dim]
+        The NPU grouped matmul path expects matmul layout, so both weights are transposed.
+        """
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
+            hidden_states, top_k_index.to(torch.int32)
+        )
+        tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
+
+        gate_up_proj = self.gate_up_proj.transpose(1, 2)
+        down_proj = self.down_proj.transpose(1, 2)
+        intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
+        intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
+        output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
+        return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
+

 class Qwen3NpuMoeFused:
     """Container for Qwen3 NPU fused MoE forward functions."""
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:


 # moe patch config mapping
-kernel_moe_mapping = {
-    "Qwen3VLMoeForConditionalGeneration": {
-        "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
-        "Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
+if is_transformers_version_greater_than("5.0.0"):
+    kernel_moe_mapping = {
+        "Qwen3MoeForCausalLM": {
+            "Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
+        },
+        "Qwen3VLMoeForConditionalGeneration": {
+            "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
+        },
+        "Qwen3_5MoeForCausalLM": {
+            "Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
+        },
+        "Qwen3_5MoeForConditionalGeneration": {
+            "Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
+        },
     }
-}
-
-if not is_transformers_version_greater_than("5.0.0"):
-    kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
-        "Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
+else:
+    kernel_moe_mapping = {
+        "Qwen3MoeForCausalLM": {
+            "Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
+        },
+        "Qwen3VLMoeForConditionalGeneration": {
+            "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
+            "Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
+        },
     }


@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
         bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.

     .. note::
-        This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
-        without hardcoding version numbers. Two methods are used: weight value inspection
-        (most reliable) and class name pattern matching (backward compatibility).
+        This must follow the module's forward semantics. Do not infer it from trained
+        weight values because standard RMSNorm weights can also be close to zero.
     """
-    if hasattr(module, "weight") and module.weight is not None:
-        weight_mean = module.weight.data.mean().item()
-        if abs(weight_mean) < 0.3:
-            return True
+    residual_rmsnorm_classes = {
+        "Qwen3_5RMSNorm",
+        "Qwen3_5MoeRMSNorm",
+        "Qwen3NextRMSNorm",
+    }

     class_name = module.__class__.__name__
-    residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
-    for pattern in residual_patterns:
-        if pattern in class_name:
-            return True
-
-    return False
+    return class_name in residual_rmsnorm_classes


 def npu_rms_norm_forward(self, hidden_states):
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
     _eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)

     if hasattr(self, "weight") and self.weight is not None:
-        if _should_use_residual_rmsnorm(self):
+        if getattr(self, "_npu_use_residual_rmsnorm", False):
             effective_weight = 1.0 + self.weight.float()
         else:
             effective_weight = self.weight.float()
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
                 if "Gated" in module.__class__.__name__:
                     module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
                 else:
+                    module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
                     module.forward = types.MethodType(npu_rms_norm_forward, module)

         return model
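Note on the residual flag: the two parameterizations differ only in the effective scale, `weight` versus `1 + weight`, so the same stored values give very different outputs. The plain-Python sketch below shows why guessing the variant from weight magnitudes is fragile: a standard RMSNorm with near-zero weights and a residual RMSNorm at initialization store the same numbers. The function `rms_norm` is a hypothetical reference written for illustration, not the NPU kernel.

```python
import math


def rms_norm(hidden, weight, eps=1e-6, residual=False):
    """Reference RMSNorm over one vector; ``residual=True`` scales by (1 + weight)."""
    rms = math.sqrt(sum(x * x for x in hidden) / len(hidden) + eps)
    scale = [(1.0 + w) if residual else w for w in weight]
    return [x / rms * s for x, s in zip(hidden, scale)]


hidden = [3.0, 4.0]
zero_weight = [0.0, 0.0]

# Same stored weights, different forward semantics.
print(rms_norm(hidden, zero_weight, residual=False))  # all zeros
print(rms_norm(hidden, zero_weight, residual=True))   # plain normalized input
```

Deciding once at patch time from the module class, and caching the result on the module, also avoids re-running the check on every forward pass.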
@@ -58,7 +58,7 @@ class Registry:
         device = kernel_cls.get_device()

         # The device type of the current accelerator does not match the device type required by the kernel, skip registration
-        if device != get_current_accelerator().type:
+        if get_current_accelerator().type not in device:
             return

         if not kernel_id:
@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
         # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
         # in shape : e.g., [s/p:h:]
         # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
-
         # scatter 2, gather 1
         q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
         k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** -0.5

-        if attention_mask is None:
-            if position_ids is not None:
-                attention_mask = torch.ones_like(position_ids).to(torch.int64)
-            else:
-                attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
+        if position_ids is not None:
+            global_position_ids = [
+                torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
+            ]
+            dist.all_gather(global_position_ids, position_ids, group=self.spg)
+            position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
+            attention_mask = None
         else:
-            attention_mask = attention_mask.to(torch.int64)
+            if attention_mask is None:
+                attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
+            else:
+                attention_mask = attention_mask.to(torch.int64)

-        global_attention_mask = [
-            torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
-        ]
-        dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
-        attention_mask = torch.cat(global_attention_mask, dim=1)
+            global_attention_mask = [
+                torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
+            ]
+            dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
+            attention_mask = torch.cat(global_attention_mask, dim=1)

         context_layer = self.attn_fn(
             q,
@@ -12,23 +12,272 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Callable
+from math import ceil
+from typing import Any
+
+import torch
+from torch.utils.data import default_collate
+
+from ...utils.constants import IGNORE_INDEX
+from ...utils.helper import pad_and_truncate
 from ...utils.objects import StatefulBuffer
 from ...utils.plugin import BasePlugin
-from ...utils.types import BatchInfo, BatchInput, DataLoader
+from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput


 class BatchingPlugin(BasePlugin):
-    def compute_length(self, data_provider: DataLoader) -> int:
+    def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
+        """Return the raw data provider batch size for this batching strategy."""
+        return self["get_data_provider_batch_size"](batch_info)
+
+    def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
         """Compute the length of the batch generator.

         The approximate length is used to calculate the lr schedule.
         """
-        raise NotImplementedError()
+        return self["compute_length"](data_provider, batch_info)

-    def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
+    def fill_buffer(
+        self,
+        buffer: StatefulBuffer,
+        batch_info: BatchInfo,
+        next_samples: Callable[[bool], list[ModelInput] | None],
+    ) -> None:
         """Fill the buffer with data."""
-        raise NotImplementedError()
+        return self["fill_buffer"](buffer, batch_info, next_samples)

     def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
         """Generate a batch from the buffer."""
-        raise NotImplementedError()
+        return self["generate_batch"](buffer, batch_info)
+
+
+def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
+    """Return sample counts for micro batches formed by one padded-token budget."""
+    budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
+    cutoff_len = batch_info["cutoff_len"]
+    sizes = []
+    index = 0
+    while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
+        max_sample_len = 0
+        used = 0
+        is_complete = False
+        while index + used < len(samples):
+            sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
+            padded_tokens = max(max_sample_len, sample_len) * (used + 1)
+            if used > 0 and padded_tokens > budget:
+                is_complete = True
+                break
+
+            max_sample_len = max(max_sample_len, sample_len)
+            used += 1
+            if max_sample_len * used >= budget:
+                is_complete = True
+                break
+
+        if used == 0 or not is_complete:
+            break
+
+        sizes.append(used)
+        index += used
+
+    return sizes
+
+
+def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
+    budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
+    cutoff_len = batch_info["cutoff_len"]
+    sizes = []
+    index = 0
+
+    while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
+        current_tokens = 0
+        used = 0
+        is_complete = False
+
+        while index + used < len(samples):
+            sample = samples[index + used]
+            sample_len = min(len(sample["input_ids"]), cutoff_len)
+
+            if current_tokens + sample_len > budget:
+                is_complete = True
+                break
+
+            current_tokens += sample_len
+            used += 1
+
+        if used <= 0 or not is_complete:
+            break
+
+        sizes.append(used)
+        index += used
+
+    return sizes
+
+
+def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
+    """Pack fixed samples into one padding-free sequence without a token budget."""
+    packed: dict[str, list[Any]] = {}
+    position_ids: list[int] = []
+
+    for sample_index, sample in enumerate(samples):
+        # Padding-free still truncates each sample by cutoff_len before packing
+        # all samples into one contiguous sequence.
+        sample_len = min(len(sample["input_ids"]), cutoff_len)
+        if sample_len <= 0:
+            continue
+
+        for key, value in sample.items():
+            if key in ("attention_mask", "position_ids") or isinstance(value, str):
+                continue
+
+            if key not in packed:
+                packed[key] = []
+
+            sliced_value = list(value[:sample_len])
+            if sample_index > 0 and sliced_value:
+                if key == "labels":
+                    sliced_value[0] = IGNORE_INDEX
+                elif key == "loss_weights":
+                    sliced_value[0] = 0.0
+
+            packed[key].extend(sliced_value)
+
+        position_ids.extend(range(sample_len))
+
+    if not position_ids:
+        return None
+
+    packed["position_ids"] = position_ids
+    packed["attention_mask"] = [1] * len(position_ids)
+    return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
+
+
+@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
+def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
+    return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
+
+
+@BatchingPlugin("padding_free").register("compute_length")
+def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
+    return len(data_provider)
+
+
+@BatchingPlugin("padding_free").register("fill_buffer")
+def fill_padding_free_buffer(
+    buffer: StatefulBuffer,
+    batch_info: BatchInfo,
+    next_samples: Callable[[bool], list[ModelInput] | None],
+) -> None:
+    while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]:
+        samples = next_samples(False)
+        if samples is None:
+            break
+
+        buffer.put(samples)
+
+
+@BatchingPlugin("padding_free").register("generate_batch")
+def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
+    micro_batch_size = batch_info["micro_batch_size"]
+    num_micro_batch = batch_info["num_micro_batch"]
+    cutoff_len = batch_info["cutoff_len"]
+    batch_size = micro_batch_size * num_micro_batch
+    if len(buffer) < batch_size:
+        return None
+
+    samples = buffer.get(batch_size)
+    batch = []
+    for i in range(num_micro_batch):
+        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
+        packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
+        if packed_micro_batch is None:
+            return None
+
+        batch.append(packed_micro_batch)
+
+    return batch
+
+
+@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
+def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
+    return 1
+
+
+@BatchingPlugin("dynamic_batching").register("compute_length")
+def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
+    batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
+    return ceil(len(data_provider) / batch_size)
+
+
+@BatchingPlugin("dynamic_batching").register("fill_buffer")
+def fill_dynamic_batching_buffer(
+    buffer: StatefulBuffer,
+    batch_info: BatchInfo,
+    next_samples: Callable[[bool], list[ModelInput] | None],
+) -> None:
+    while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
+        samples = next_samples(True)
+        if samples is None:
+            break
+
+        buffer.put(samples)
+
+
+@BatchingPlugin("dynamic_batching").register("generate_batch")
+def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
+    micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
+    if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
+        return None
+
+    batch = []
+    cutoff_len = batch_info["cutoff_len"]
+    for num_samples in micro_batch_sample_counts:
+        samples = buffer.get(num_samples)
+        batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
+
+    return batch
+
+
+@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
+def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
+    return 1
+
+
+@BatchingPlugin("dynamic_padding_free").register("compute_length")
+def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
+    batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
+    return ceil(len(data_provider) / batch_size)
+
+
+@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
+def fill_dynamic_padding_free_buffer(
+    buffer: StatefulBuffer,
+    batch_info: BatchInfo,
+    next_samples: Callable[[bool], list[ModelInput] | None],
+) -> None:
+    while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
+        samples = next_samples(True)
+        if samples is None:
+            break
+        buffer.put(samples)
+
+
+@BatchingPlugin("dynamic_padding_free").register("generate_batch")
+def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
+    micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
+    if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
+        return None
+
+    batch = []
+    cutoff_len = batch_info["cutoff_len"]
+
+    for num_samples in micro_batch_sample_counts:
+        samples = buffer.get(num_samples)
+        packed_batch = _pack_padding_free_samples(samples, cutoff_len)
+        if packed_batch is None:
+            return None
+
+        batch.append(packed_batch)
+
+    return batch
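Note on the padded-token budget used by `dynamic_batching`: a micro batch is charged `max_len * count` tokens because every sample is padded to the longest one, and the budget is `cutoff_len * micro_batch_size`. The sketch below is a simplified greedy version of that grouping that works on lengths only. `group_by_padded_budget` is a hypothetical helper; unlike the plugin code it keeps a trailing partial group and ignores `num_micro_batch`.

```python
def group_by_padded_budget(lengths: list[int], budget: int) -> list[list[int]]:
    """Greedily group consecutive lengths so that max_len * count stays within budget."""
    groups: list[list[int]] = []
    current: list[int] = []
    max_len = 0
    for length in lengths:
        new_max = max(max_len, length)
        # Adding this sample would push the padded size over budget: close the group.
        if current and new_max * (len(current) + 1) > budget:
            groups.append(current)
            current, new_max = [], length
        current.append(length)
        max_len = new_max
    if current:
        groups.append(current)
    return groups


# Assumed settings for illustration: cutoff_len=8, micro_batch_size=2, so budget=16.
print(group_by_padded_budget([5, 3, 8, 2, 2, 2], budget=16))  # [[5, 3], [8, 2], [2, 2]]
```

The padding-free variant in the hunk above charges the plain sum of lengths instead, since packed samples carry no padding.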
@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt

 @DistributedPlugin("deepspeed").register()
 def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
+    if dist_config.get("cp_size", 1) > 1:
+        raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
+
     from .deepspeed import DeepSpeedEngine

     return DeepSpeedEngine(
@@ -13,22 +13,46 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.helper import is_torch_npu_available
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
def enable_full_determinism(seed: int) -> None:
|
||||
"""Enable full deterministic mode for reproducible distributed training."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.enabled = False
|
||||
if is_torch_npu_available():
|
||||
torch.npu.manual_seed(seed)
|
||||
torch.npu.manual_seed_all(seed)
|
||||
|
||||
|
||||
def set_seed(seed: int, full_determinism: bool = False) -> None:
|
||||
"""Set seed for reproducibility.
|
||||
|
||||
Args:
|
||||
seed: Random seed.
|
||||
full_determinism: Whether to enable full deterministic mode.
|
||||
"""
|
||||
hf_set_seed(seed)
|
||||
if full_determinism:
|
||||
enable_full_determinism(seed)
|
||||
else:
|
||||
hf_set_seed(seed)
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
|
||||
@@ -33,6 +33,10 @@ class StatefulBuffer:
|
||||
def size(self) -> int:
|
||||
return self._buffer_size
|
||||
|
||||
@property
|
||||
def samples(self) -> list[ModelInput]:
|
||||
return self._buffer
|
||||
|
||||
def put(self, samples: list[ModelInput]) -> None:
|
||||
"""Add samples to the buffer."""
|
||||
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from enum import StrEnum, unique
|
||||
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
@@ -54,6 +54,13 @@ else:
|
||||
ProcessGroup = None
|
||||
|
||||
|
||||
@unique
|
||||
class AttentionFunction(StrEnum):
|
||||
EAGER = "eager"
|
||||
SDPA = "sdpa"
|
||||
FLASH_ATTENTION_2 = "flash_attention_2"
|
||||
|
||||
|
||||
class DatasetInfo(TypedDict, total=False):
|
||||
path: str
|
||||
"""Local file path."""
|
||||
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
|
||||
"""Number of micro batches."""
|
||||
cutoff_len: int
|
||||
"""Cutoff length."""
|
||||
data_iter: Iterator[list[ModelInput]]
|
||||
"""Data iterator."""
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
|
||||
149
tests/model/model_utils/test_embedding.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from llamafactory.model.model_utils.embedding import (
|
||||
_description_based_initialization,
|
||||
_existing_embeddings,
|
||||
_noisy_mean_initialization,
|
||||
_resolve_new_token_ids,
|
||||
)
|
||||
|
||||
|
||||
class _StubTokenizer:
|
||||
"""Minimal tokenizer stub mapping token strings to fixed IDs."""
|
||||
|
||||
unk_token_id = 0
|
||||
|
||||
def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None):
|
||||
self._mapping = mapping
|
||||
self._desc_ids = desc_ids or []
|
||||
|
||||
def convert_tokens_to_ids(self, token: str) -> int:
|
||||
return self._mapping.get(token, self.unk_token_id)
|
||||
|
||||
def __call__(self, desc, return_tensors=None, add_special_tokens=False):
|
||||
return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)}
|
||||
|
||||
|
||||
class _StubModel:
|
||||
"""Wraps an embedding matrix so ``get_input_embeddings()`` is a usable lookup."""
|
||||
|
||||
def __init__(self, embed_weight: "torch.Tensor"):
|
||||
self._emb = torch.nn.Embedding.from_pretrained(embed_weight.clone(), freeze=True)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self._emb
|
||||
|
||||
|
||||
def test_resolve_new_token_ids_returns_none_without_config():
|
||||
tokenizer = _StubTokenizer({})
|
||||
assert _resolve_new_token_ids(None, tokenizer, embed_size=100) is None
|
||||
assert _resolve_new_token_ids([], tokenizer, embed_size=100) is None
|
||||
|
||||
|
||||
def test_resolve_new_token_ids_filters_invalid_and_dedups():
|
||||
# "<a>" valid, "<unk_like>" maps to unk_token_id (skipped), "<oob>" out of range (skipped)
|
||||
tokenizer = _StubTokenizer({"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5})
|
||||
# duplicates and unsorted input -> sorted unique in-range IDs
|
||||
tokens = ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"]
|
||||
assert _resolve_new_token_ids(tokens, tokenizer, embed_size=100) == [5, 10]
|
||||
# passing a dict iterates its keys (config compatibility)
|
||||
assert _resolve_new_token_ids({"<a>": "desc"}, tokenizer, embed_size=100) == [10]
|
||||
|
||||
|
||||
def test_existing_embeddings_excludes_new_token_ids():
|
||||
embed_weight = torch.arange(10 * 2, dtype=torch.float32).reshape(10, 2)
|
||||
# explicit ids take precedence and drop exactly those rows
|
||||
existing = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5])
|
||||
assert existing.size(0) == 8
|
||||
# tail fallback when no explicit ids
|
||||
tail = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None)
|
||||
assert torch.allclose(tail, embed_weight[:-3])
|
||||
# no resize and no ids -> use everything
|
||||
everything = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None)
|
||||
assert torch.allclose(everything, embed_weight)
|
||||
|
||||
|
||||
def test_noisy_mean_initialization_with_token_ids_targets_exact_rows():
|
||||
"""New tokens placed by explicit IDs must hit those rows, even inside the padding zone."""
|
||||
torch.manual_seed(0)
|
||||
vocab_size, embedding_dim = 20, 8
|
||||
embed_weight = torch.zeros(vocab_size, embedding_dim)
|
||||
# existing rows carry a constant so the mean is well-defined and non-zero
|
||||
embed_weight[:16] = 1.0
|
||||
|
||||
# num_new_tokens reflects the embedding resize delta (4 padded rows),
|
||||
# but the real new tokens sit at IDs 16 and 17; a tail slice of the last 4 rows would also overwrite padding rows 18 and 19.
|
||||
target_ids = [16, 17]
|
||||
_noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids)
|
||||
|
||||
# targeted rows are initialized around the mean (~1.0) and not left at zero
|
||||
for tid in target_ids:
|
||||
assert not torch.allclose(embed_weight[tid], torch.zeros(embedding_dim))
|
||||
assert abs(embed_weight[tid].mean().item() - 1.0) < 0.5
|
||||
|
||||
# untouched padding rows (18, 19) must remain zero
|
||||
assert torch.allclose(embed_weight[18], torch.zeros(embedding_dim))
|
||||
assert torch.allclose(embed_weight[19], torch.zeros(embedding_dim))
|
||||
|
||||
|
||||
def test_noisy_mean_initialization_tail_fallback():
|
||||
"""Without token_ids, falls back to the last num_new_tokens rows."""
|
||||
torch.manual_seed(0)
|
||||
vocab_size, embedding_dim = 12, 8
|
||||
embed_weight = torch.zeros(vocab_size, embedding_dim)
|
||||
embed_weight[:10] = 1.0
|
||||
|
||||
_noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None)
|
||||
|
||||
# last two rows initialized, earlier rows untouched
|
||||
assert not torch.allclose(embed_weight[-1], torch.zeros(embedding_dim))
|
||||
assert not torch.allclose(embed_weight[-2], torch.zeros(embedding_dim))
|
||||
assert torch.allclose(embed_weight[0], torch.ones(embedding_dim))
|
||||
|
||||
|
||||
def test_description_init_excludes_new_token_ids_from_average():
|
||||
"""Description tokens that are themselves new (uninitialized) must be excluded.
|
||||
|
||||
Reproduces the padding-zone bug: id 17 is a new token and must not pollute the
|
||||
semantic average for id 16; only the valid existing token (id 5) should be used.
|
||||
"""
|
||||
vocab_size, embedding_dim = 20, 4
|
||||
embed_weight = torch.zeros(vocab_size, embedding_dim)
|
||||
embed_weight[5] = 3.0 # the only valid description token
|
||||
|
||||
# description for "<x>" tokenizes to [5 (existing), 17 (new -> must be skipped)]
|
||||
tokenizer = _StubTokenizer({"<x>": 16}, desc_ids=[5, 17])
|
||||
model = _StubModel(embed_weight)
|
||||
|
||||
_description_based_initialization(
|
||||
embed_weight,
|
||||
num_new_tokens=4,
|
||||
descriptions={"<x>": "ignored, ids come from the stub"},
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
new_token_ids=[16, 17],
|
||||
add_noise=False,
|
||||
)
|
||||
|
||||
# row 16 must equal embedding of id 5 only (3.0), not the (5,17) average (1.5)
|
||||
assert torch.allclose(embed_weight[16], torch.full((embedding_dim,), 3.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
pytest.main([__file__])
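The filtering behavior that `test_resolve_new_token_ids_filters_invalid_and_dedups` checks can be sketched in plain Python. This is an illustration of the tested contract only, not the code in `llamafactory.model.model_utils.embedding`. The function name `resolve_new_token_ids` and its dict-based signature are hypothetical, and returning an empty list when every token is filtered out is an assumption the tests do not cover.

```python
def resolve_new_token_ids(new_tokens, token_to_id: dict[str, int], unk_token_id: int, embed_size: int):
    """Map new token strings to sorted, unique, in-range IDs (hypothetical sketch)."""
    if not new_tokens:  # None, empty list, or empty dict
        return None

    token_ids = set()
    for token in new_tokens:  # iterating a dict yields its keys
        token_id = token_to_id.get(token, unk_token_id)
        if token_id == unk_token_id:
            continue  # the tokenizer does not actually know this token
        if not 0 <= token_id < embed_size:
            continue  # no embedding row exists for this ID
        token_ids.add(token_id)

    return sorted(token_ids)


mapping = {"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5}
tokens = ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"]
print(resolve_new_token_ids(tokens, mapping, unk_token_id=0, embed_size=100))  # [5, 10]
```

Resolving explicit IDs lets the initializers target exactly the rows of the new tokens, rather than assuming they occupy the last `num_new_tokens` rows of a padded embedding matrix.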
|
||||
@@ -58,10 +58,3 @@ def test_multi_device():
|
||||
master_port = find_available_port()
|
||||
world_size = 2
|
||||
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python tests_v1/accelerator/test_interface.py
|
||||
"""
|
||||
test_all_device()
|
||||
|
||||
@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
|
||||
assert training_args.bf16 is False
|
||||
assert training_args.dist_config is None
|
||||
assert sample_args.sample_backend == "hf"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.config.test_args_parser
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_get_args_from_yaml(tmp_path=Path(tmp_dir))
|
||||
|
||||
@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
|
||||
for index in indexes:
|
||||
print(data_engine[index])
|
||||
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.test_data_engine
|
||||
"""
|
||||
test_map_dataset(1)
|
||||
|
||||
@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
|
||||
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||
|
||||
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.test_model_loader
|
||||
"""
|
||||
test_tiny_qwen()
|
||||
test_tiny_qwen_with_kernel_plugin()
|
||||
|
||||
@@ -16,6 +16,164 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import (
|
||||
BatchingPlugin,
|
||||
_get_dynamic_micro_batch_sizes,
|
||||
_get_dynamic_padding_free_micro_batch_sizes,
|
||||
)
|
||||
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
||||
from llamafactory.v1.utils.objects import StatefulBuffer
|
||||
|
||||
|
||||
def _make_model_input(length: int, start: int = 0):
|
||||
input_ids = list(range(start, start + length))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": [1] * length,
|
||||
"labels": input_ids.copy(),
|
||||
"loss_weights": [1.0] * length,
|
||||
}
|
||||
|
||||
|
||||
class _RestartableDataProvider:
|
||||
def __init__(self, batches):
|
||||
self.batches = batches
|
||||
self.num_iters = 0
|
||||
|
||||
def __iter__(self):
|
||||
self.num_iters += 1
|
||||
return iter(self.batches)
|
||||
|
||||
|
||||
def test_padding_free():
|
||||
buffer = StatefulBuffer()
|
||||
# Input samples:
|
||||
# sample 0 input_ids: [0, 1]
|
||||
# sample 1 input_ids: [10, 11, 12, 13]
|
||||
buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}
|
||||
|
||||
batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)
|
||||
|
||||
# Output batch:
|
||||
# sample 1 is truncated to [10, 11, 12]
|
||||
# both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (1, 5)
|
||||
assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
|
||||
assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
|
||||
assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
|
||||
assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
|
||||
assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
|
||||
assert len(buffer) == 0
|
||||
|
||||
|
||||
def test_batching_plugin_data_provider_batch_sizes():
|
||||
batch_info = {
|
||||
"micro_batch_size": 2,
|
||||
"num_micro_batch": 3,
|
||||
"cutoff_len": 10,
|
||||
}
|
||||
|
||||
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
|
||||
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
|
||||
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching():
|
||||
# Input samples:
|
||||
# sample lengths: [3, 4, 6, 2, 8, 9]
|
||||
# input_ids:
|
||||
# [0, 1, 2]
|
||||
# [10, 11, 12, 13]
|
||||
# [20, 21, 22, 23, 24, 25]
|
||||
# [30, 31]
|
||||
# [40, 41, 42, 43, 44, 45, 46, 47]
|
||||
# [50, 51, 52, 53, 54, 55, 56, 57, 58]
|
||||
samples = [
|
||||
_make_model_input(3, 0),
|
||||
_make_model_input(4, 10),
|
||||
_make_model_input(6, 20),
|
||||
_make_model_input(2, 30),
|
||||
_make_model_input(8, 40),
|
||||
_make_model_input(9, 50),
|
||||
]
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
# Dynamic batching output plan:
|
||||
# dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
|
||||
# as the padded-token budget for one training micro batch.
|
||||
# [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
|
||||
assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]
|
||||
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put(samples)
|
||||
batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)
|
||||
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (3, 6)
|
||||
assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
|
||||
buffer = StatefulBuffer()
|
||||
# Input buffer:
|
||||
# only one sample with length [6].
|
||||
# cutoff_len * micro_batch_size gives a padded-token budget of 20.
|
||||
# this buffer has not filled the budget and has no next sample to prove overflow,
|
||||
# so dynamic batching cannot produce a batch yet.
|
||||
buffer.put([_make_model_input(6, 0)])
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
|
||||
assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
|
||||
# Batch generation does not read from the data iterator. It only returns None and keeps
|
||||
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
|
||||
assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
|
||||
# Input data provider:
|
||||
# each iterator pass yields one sample with length [6].
|
||||
# each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
|
||||
# _fill_buffer keeps restarting the iterator until the next appended sample
|
||||
# proves that the previous dynamic micro batch has reached its budget boundary.
|
||||
samples = [_make_model_input(6, 0)]
|
||||
data_provider = _RestartableDataProvider([[sample] for sample in samples])
|
||||
|
||||
batch_generator = BatchGenerator.__new__(BatchGenerator)
|
||||
batch_generator.batching_strategy = "dynamic_batching"
|
||||
batch_generator.micro_batch_size = 2
|
||||
batch_generator.num_micro_batch = 1
|
||||
batch_generator._buffer = StatefulBuffer()
|
||||
batch_generator._data_provider = data_provider
|
||||
batch_generator._data_iter = iter(data_provider)
|
||||
batch_generator._batch_info = {
|
||||
"micro_batch_size": 2,
|
||||
"num_micro_batch": 1,
|
||||
"cutoff_len": 10,
|
||||
}
|
||||
|
||||
batch_generator._fill_buffer()
|
||||
|
||||
# Filled buffer after restart:
|
||||
# existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
|
||||
# because adding it to the first dynamic micro batch would exceed the budget.
|
||||
assert data_provider.num_iters == 4
|
||||
assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]
|
||||
|
||||
batch = batch_generator._generate_batch()
|
||||
|
||||
# Output batch:
|
||||
# dynamic batching returns [micro_batch_0]
|
||||
# micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (3, 6)
|
||||
assert len(batch_generator._buffer) == 1
|
||||
|
||||
|
||||
def test_normal_batching():
|
||||
@@ -45,8 +203,166 @@ def test_normal_batching():
|
||||
assert batch[0]["input_ids"].shape == (4, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test_dynamic_padding_free():
|
||||
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
|
||||
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
|
||||
# input_ids breakdown:
|
||||
# sample 0: [0,1,2] (length=3)
|
||||
# sample 1: [10,11,12,13] (length=4)
|
||||
# sample 2: [20,21,22,23,24,25] (length=6)
|
||||
# sample 3: [30,31] (length=2)
|
||||
# sample 4: [40-47] (length=8)
|
||||
# sample 5: [50-58] (length=9)
|
||||
samples = [
|
||||
_make_model_input(3, 0),
|
||||
_make_model_input(4, 10),
|
||||
_make_model_input(6, 20),
|
||||
_make_model_input(2, 30),
|
||||
_make_model_input(8, 40),
|
||||
_make_model_input(9, 50),
|
||||
]
|
||||
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
|
||||
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put(samples)
|
||||
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
|
||||
|
||||
assert batch is not None
|
||||
assert len(batch) == 1 # num_micro_batch=1
|
||||
packed_batch = batch[0]
|
||||
|
||||
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
|
||||
assert packed_batch["input_ids"].shape == (1, 15)
|
||||
|
||||
# Verify input_ids concatenation (samples are joined in order, without padding)
|
||||
assert packed_batch["input_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
30,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
|
||||
assert packed_batch["labels"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
IGNORE_INDEX,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
IGNORE_INDEX,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
IGNORE_INDEX,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify attention_mask
|
||||
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
|
||||
|
||||
# Verify position_ids
|
||||
assert packed_batch["position_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3, # Sample 1
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5, # Sample 2
|
||||
0,
|
||||
1,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
|
||||
assert len(buffer) == 2
|
||||
|
||||
|
||||
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put([_make_model_input(6, 0)])
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info) == []
|
||||
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
|
||||
# Batch generation does not read from the data iterator. It only returns None and keeps
|
||||
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
|
||||
assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
|
||||
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
|
||||
|
||||
Data provider yields one sample of length 6 per iteration.
|
||||
_fill_buffer keeps restarting iterator until next sample exceeds budget.
|
||||
Budget = 2 * 10 = 20 tokens.
|
||||
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
|
||||
So buffer will have 4 samples after fill_buffer.
|
||||
"""
|
||||
python -m tests_v1.core.utils.test_batching
|
||||
"""
|
||||
test_normal_batching()
|
||||
samples = [_make_model_input(6, 0)]
|
||||
data_provider = _RestartableDataProvider([[sample] for sample in samples])
|
||||
|
||||
batch_generator = BatchGenerator.__new__(BatchGenerator)
|
||||
batch_generator.batching_strategy = "dynamic_padding_free"
|
||||
batch_generator.micro_batch_size = 2
|
||||
batch_generator.num_micro_batch = 1
|
||||
batch_generator._buffer = StatefulBuffer()
|
||||
batch_generator._data_provider = data_provider
|
||||
batch_generator._data_iter = iter(data_provider)
|
||||
batch_generator._batch_info = {
|
||||
"micro_batch_size": 2,
|
||||
"num_micro_batch": 1,
|
||||
"cutoff_len": 10,
|
||||
}
|
||||
|
||||
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
|
||||
batch_generator._fill_buffer()
|
||||
|
||||
# Buffer after restarts:
|
||||
# 3 samples can fit (18 tokens)
|
||||
# 4th sample is kept in buffer for next batch
|
||||
# => num_iters = 4
|
||||
assert data_provider.num_iters == 4
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(
|
||||
batch_generator._buffer.samples, batch_generator._batch_info
|
||||
) == [3]
|
||||
|
||||
batch = batch_generator._generate_batch()
|
||||
|
||||
# Output batch:
|
||||
# dynamic_padding_free returns [micro_batch_0]
|
||||
# 3 samples packed into shape [1, 18]
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (1, 18)
|
||||
assert len(batch_generator._buffer) == 1
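The budget rule exercised by the dynamic batching tests above can be sketched as a single greedy planner. This is a minimal reconstruction from the test comments and assertions, not the implementation of `_get_dynamic_micro_batch_sizes` or `_get_dynamic_padding_free_micro_batch_sizes`. The function `plan_micro_batches`, its `padded` flag, the per-sample clamp to `cutoff_len`, and the early return once `num_micro_batch` groups are planned are all assumptions.

```python
def plan_micro_batches(
    lengths: list[int],
    micro_batch_size: int,
    num_micro_batch: int,
    cutoff_len: int,
    padded: bool,
) -> list[int]:
    """Return the sample count of each completed micro batch (hypothetical sketch)."""
    budget = cutoff_len * micro_batch_size  # token budget per micro batch
    sizes: list[int] = []
    count, max_len, total = 0, 0, 0
    for length in lengths:
        length = min(length, cutoff_len)  # assumption: samples are truncated to cutoff_len
        new_count, new_max, new_total = count + 1, max(max_len, length), total + length
        # Padded batching pays for rows * longest row; padding-free pays for real tokens only.
        cost = new_count * new_max if padded else new_total
        if count > 0 and cost > budget:
            # This sample overflows the budget, which proves the current micro batch is complete.
            sizes.append(count)
            if len(sizes) == num_micro_batch:
                return sizes
            count, max_len, total = 1, length, length
        else:
            count, max_len, total = new_count, new_max, new_total

    # The trailing group is not reported: no later sample has shown that it is full.
    return sizes


lengths = [3, 4, 6, 2, 8, 9]
print(plan_micro_batches(lengths, 2, 1, 10, padded=True))   # [3]
print(plan_micro_batches(lengths, 2, 1, 10, padded=False))  # [4]
```

With a budget of 20, the padded plan stops at three samples because a fourth row would cost 4 * 6 = 24 tokens, while the padding-free plan fits four samples at 3 + 4 + 6 + 2 = 15 tokens. A lone sample of length 6 yields an empty plan in both modes, which is why the fill-buffer tests need a fourth sample before a batch of three can be emitted.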
|
||||
|
||||
@@ -227,17 +227,3 @@ def test_process_dpo_samples():
|
||||
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
|
||||
assert model_inputs[0]["extra_info"] == "test"
|
||||
assert model_inputs[0]["_dataset_name"] == "default"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.utils.test_rendering
|
||||
"""
|
||||
test_chatml_rendering()
|
||||
test_chatml_parse()
|
||||
test_chatml_rendering_remote(16)
|
||||
test_qwen3_nothink_rendering()
|
||||
test_qwen3_nothink_parse()
|
||||
test_qwen3_nothink_rendering_remote(16)
|
||||
test_process_sft_samples()
|
||||
test_process_dpo_samples()
|
||||
|
||||
@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
|
||||
],
|
||||
}
|
||||
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.plugins.data_plugins.test_converter
|
||||
"""
|
||||
test_alpaca_converter(1)
|
||||
test_sharegpt_converter()
|
||||
test_pair_converter(1)
|
||||
|
||||
@@ -52,12 +52,3 @@ def test_init_on_default():
|
||||
)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
assert model_engine.model.device == DistributedInterface().current_device
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python tests_v1/plugins/model_plugins/test_init_plugin.py
|
||||
"""
|
||||
test_init_on_meta()
|
||||
test_init_on_rank0()
|
||||
test_init_on_default()
|
||||
|
||||
@@ -35,10 +35,3 @@ def test_sync_sampler():
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "This is a test."}],
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python tests_v1/sampler/test_cli_sampler.py
|
||||
"""
|
||||
test_sync_sampler()
|
||||
|
||||