mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 12:48:55 +08:00
Compare commits
11 Commits
40e786d016
...
v0.9.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7af909522a | ||
|
|
e016d2480e | ||
|
|
7d719182c9 | ||
|
|
01398eb18d | ||
|
|
8e68764b65 | ||
|
|
16ff5a23cb | ||
|
|
bdcb92d035 | ||
|
|
7e20db5735 | ||
|
|
2322bf1cc2 | ||
|
|
368c48968f | ||
|
|
8b5ea65770 |
16
README.md
16
README.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
@@ -52,14 +50,11 @@ Start local training:
|
||||
Start cloud training:
|
||||
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
Read technical notes:
|
||||
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
|
||||
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
|
||||
- **Official Blog**: https://blog.llamafactory.net/en/
|
||||
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> Except for the above links, all other websites are unauthorized third-party websites. Please use them with caution.
|
||||
@@ -78,7 +73,6 @@ Read technical notes:
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Quickstart](#quickstart)
|
||||
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
||||
- [LLaMA Factory Online](#llama-factory-online)
|
||||
- [Build Docker](#build-docker)
|
||||
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
|
||||
- [Download from ModelScope Hub](#download-from-modelscope-hub)
|
||||
@@ -117,15 +111,11 @@ Read technical notes:
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
|
||||
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
|
||||
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
|
||||
|
||||
<details><summary>All Blogs</summary>
|
||||
|
||||
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
|
||||
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
|
||||
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
|
||||
@@ -661,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online
|
||||
|
||||
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
|
||||
|
||||
### Build Docker
|
||||
|
||||
For CUDA users:
|
||||
|
||||
16
README_zh.md
16
README_zh.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)、[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)。
|
||||
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)和 [NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
开始云端训练:
|
||||
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
阅读技术文档:
|
||||
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
|
||||
- **官方博客**:https://blog.llamafactory.net/
|
||||
- **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- [数据准备](#数据准备)
|
||||
- [快速开始](#快速开始)
|
||||
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
|
||||
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
|
||||
- [构建 Docker](#构建-docker)
|
||||
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
|
||||
- [从魔搭社区下载](#从魔搭社区下载)
|
||||
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
|
||||
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
|
||||
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
|
||||
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
|
||||
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
|
||||
|
||||
<details><summary>全部博客</summary>
|
||||
|
||||
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
|
||||
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
|
||||
- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
|
||||
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
|
||||
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
|
||||
@@ -662,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online 在线微调
|
||||
|
||||
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
|
||||
|
||||
### 构建 Docker
|
||||
|
||||
CUDA 用户:
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# FSDP2 training example: batching_strategy = normal.
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink


kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal

cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
|
||||
@@ -0,0 +1,30 @@
|
||||
# FSDP2 training example: batching_strategy = dynamic_batching.
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml


### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
|
||||
@@ -0,0 +1,30 @@
|
||||
# FSDP2 training example: batching_strategy = dynamic_padding_free.
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
|
||||
@@ -0,0 +1,30 @@
|
||||
# FSDP2 training example: batching_strategy = padding_free.
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
|
||||
28
examples/v1/train_full/train_full_liger_kernel.yaml
Normal file
28
examples/v1/train_full/train_full_liger_kernel.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
# FSDP2 training example with kernel_config.name = liger_kernel.
model: Qwen/Qwen3-0.6B
model_class: llm

template: qwen3_nothink

kernel_config:
  name: liger_kernel
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
|
||||
@@ -19,7 +19,7 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
VERSION = "0.9.5.dev0"
|
||||
VERSION = "0.9.5"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
TrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
|
||||
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||
|
||||
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_MCA_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_MCA_CLS = tuple[
|
||||
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
else:
|
||||
_TRAIN_MCA_ARGS = []
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -63,6 +64,58 @@ class RayArguments:
|
||||
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
||||
|
||||
|
||||
@dataclass
class ProfilerArguments:
    r"""Arguments for torch profiler configuration.

    The ``profiler_*`` fields map onto the ``wait`` / ``warmup`` / ``active`` /
    ``repeat`` parameters of ``torch.profiler.schedule`` and onto the
    ``record_shapes`` / ``profile_memory`` / ``with_stack`` switches of
    ``torch.profiler.profile``. ``profile_modules`` is independent of the torch
    profiler: it selects modules for per-module event-based timing.
    """

    # Master switch for trace collection; nothing is profiled unless this is set.
    enable_torch_profiler: bool = field(
        default=False,
        metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
    )
    # ``None`` means "derive the directory from the training output directory".
    profiler_output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
    )
    profiler_wait_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
    )
    profiler_warmup_steps: int = field(
        default=1,
        metadata={"help": "Number of profiler warm-up steps per cycle."},
    )
    profiler_active_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to actively record per cycle."},
    )
    profiler_repeat: int = field(
        default=1,
        metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
    )
    profiler_record_shapes: bool = field(
        default=True,
        metadata={"help": "Whether to record tensor shapes during profiling."},
    )
    profiler_profile_memory: bool = field(
        default=True,
        metadata={"help": "Whether to profile memory usage."},
    )
    profiler_with_stack: bool = field(
        default=True,
        metadata={"help": "Whether to record stack traces during profiling."},
    )
    # Comma-separated fnmatch patterns; ``None`` disables per-module timing.
    profile_modules: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Comma-separated list of module name patterns to profile with CUDA events. "
                "Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
                "Reports per-module forward/backward timing statistics at each logging step."
            )
        },
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fp8Arguments:
|
||||
r"""Arguments pertaining to the FP8 training."""
|
||||
@@ -87,7 +140,7 @@ class Fp8Arguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
|
||||
@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
||||
if position_ids is not None and position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
|
||||
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
|
||||
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
|
||||
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
|
||||
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
|
||||
if position_ids is not None and batch_size == 1:
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
|
||||
else:
|
||||
cu_seqlens = None
|
||||
|
||||
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
|
||||
# standard causal-conv1d path that the upstream forward uses.
|
||||
|
||||
@@ -12,11 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
@@ -31,7 +33,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
|
||||
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
|
||||
from ..extras.packages import is_safetensors_available
|
||||
|
||||
|
||||
@@ -338,6 +340,96 @@ class LogCallback(TrainerCallback):
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
|
||||
class TorchProfilerCallback(TrainerCallback):
    r"""A callback for collecting torch.profiler traces during training.

    Activated by setting ``enable_torch_profiler: true`` in the YAML config.

    Configuration fields (in YAML):
        profiler_output_dir     – where to write traces (default: <output_dir>/profiler)
        profiler_wait_steps     – steps to skip at start of each cycle (default: 1)
        profiler_warmup_steps   – profiler warm-up steps per cycle (default: 1)
        profiler_active_steps   – steps to record per cycle (default: 1)
        profiler_repeat         – number of cycles; 0 = forever (default: 1)
        profiler_record_shapes  – record tensor shapes (default: true)
        profiler_profile_memory – profile memory usage (default: true)
        profiler_with_stack     – record stack traces (default: true)

    Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
    ``<profiler_output_dir>/rank_<N>/``.
    """

    def __init__(self, training_args: "TrainingArguments") -> None:
        r"""Store the arguments object that carries the ``profiler_*`` settings.

        Args:
            training_args: Object exposing the ``profiler_*`` attributes read in
                ``on_train_begin``. The profiler itself is created lazily there.
        """
        self.profiler = None
        self.profiler_args = training_args

    @staticmethod
    def _get_rank() -> int:
        r"""Return the ``torch.distributed`` rank, or 0 when no process group is initialized."""
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    @staticmethod
    def _get_activities() -> list:
        r"""Build the list of profiler activities for the available accelerators.

        CPU is always profiled. CUDA and NPU are added on a best-effort basis:
        if the availability probe fails or ``torch.profiler.ProfilerActivity``
        has no member of that name, a warning is logged and profiling continues
        without that activity instead of silently dropping it.
        """
        activities = [torch.profiler.ProfilerActivity.CPU]
        backends = (("CUDA", is_torch_cuda_available), ("NPU", is_torch_npu_available))
        for activity_name, is_available in backends:
            # Each backend is probed independently so that a failure for one
            # does not prevent the other from being enabled.
            try:
                if is_available():
                    activities.append(getattr(torch.profiler.ProfilerActivity, activity_name))
            except Exception as exc:  # deliberate best-effort: never abort training over profiling
                logger.warning_rank0(
                    f"TorchProfiler: cannot enable {activity_name} activity ({exc}); continuing without it."
                )
        return activities

    @override
    def on_train_begin(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        r"""Create and start a fresh profiler, writing traces to a per-rank directory."""
        # A profiler left over from a previous run is stopped before a new one is built.
        if self.profiler is not None:
            self.profiler.stop()
            self.profiler = None

        pa = self.profiler_args
        output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
        rank = self._get_rank()
        trace_dir = os.path.join(output_dir, f"rank_{rank}")
        os.makedirs(trace_dir, exist_ok=True)

        self.profiler = torch.profiler.profile(
            activities=self._get_activities(),
            schedule=torch.profiler.schedule(
                wait=pa.profiler_wait_steps,
                warmup=pa.profiler_warmup_steps,
                active=pa.profiler_active_steps,
                repeat=pa.profiler_repeat,
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=pa.profiler_record_shapes,
            profile_memory=pa.profiler_profile_memory,
            with_stack=pa.profiler_with_stack,
        )
        self.profiler.start()
        logger.info_rank0(
            f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
            f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
        )

    @override
    def on_step_end(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        r"""Advance the profiler schedule by one step."""
        if self.profiler is not None:
            self.profiler.step()

    @override
    def on_train_end(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        r"""Stop the profiler, if one is running, and drop the reference to it."""
        if self.profiler is not None:
            self.profiler.stop()
            self.profiler = None
            logger.info_rank0("TorchProfiler stopped.")
|
||||
|
||||
|
||||
class ReporterCallback(TrainerCallback):
|
||||
r"""A callback for reporting training status to external logger."""
|
||||
|
||||
@@ -394,3 +486,143 @@ class ReporterCallback(TrainerCallback):
|
||||
"generating_args": self.generating_args.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ModuleProfilerCallback(TrainerCallback):
    r"""Profile forward/backward time of specified modules using accelerator events.

    Hooks are registered on modules matching the user-provided name patterns.
    Timing statistics are logged at each trainer logging step.

    Every invocation of a matched module is measured separately, so a module
    that runs several times between two ``on_step_end`` calls contributes one
    sample per invocation; the logged value is the mean over those samples.

    Usage in YAML config:
        profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"

    Supports fnmatch wildcards:
        profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
    """

    @staticmethod
    def _get_accelerator():
        """Detect available accelerator and return (event_factory, synchronize_fn)."""
        if is_torch_cuda_available():
            return torch.cuda.Event, torch.cuda.synchronize
        if is_torch_npu_available():
            return torch.npu.Event, torch.npu.synchronize
        return None, None

    def __init__(self, profile_modules: str) -> None:
        r"""Parse the pattern list and set up empty timing state.

        Args:
            profile_modules: Comma-separated fnmatch patterns; blank entries are ignored.
        """
        self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
        self._create_event, self._synchronize = self._get_accelerator()
        self._handles: list[Any] = []
        # Elapsed milliseconds per module name, accumulated until the next log.
        self._forward_times: dict[str, list[float]] = defaultdict(list)
        self._backward_times: dict[str, list[float]] = defaultdict(list)
        # Completed (start, end) event pairs waiting for the next synchronize.
        self._pending_forward: dict[str, list[tuple]] = defaultdict(list)
        self._pending_backward: dict[str, list[tuple]] = defaultdict(list)
        # Start events whose matching end has not been recorded yet (LIFO per module).
        self._open_forward: dict[str, list[Any]] = defaultdict(list)
        self._open_backward: dict[str, list[Any]] = defaultdict(list)

    @property
    def enabled(self) -> bool:
        r"""Whether a supported accelerator (CUDA or NPU) was detected."""
        return self._create_event is not None

    def _match(self, name: str) -> bool:
        r"""Return True if ``name`` matches any configured fnmatch pattern."""
        return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)

    def _begin(self, open_events: dict, name: str) -> None:
        r"""Record a start event for ``name`` and push it on its open stack."""
        start = self._create_event(enable_timing=True)
        start.record()
        open_events[name].append(start)

    def _finish(self, open_events: dict, pending: dict, name: str) -> None:
        r"""Record an end event for the most recent open start of ``name``.

        Only pairs whose end event was actually recorded reach ``pending``, so
        ``elapsed_time`` is never queried on an unrecorded event.
        """
        stack = open_events.get(name)
        if not stack:
            return
        start = stack.pop()
        end = self._create_event(enable_timing=True)
        end.record()
        pending[name].append((start, end))

    def _make_forward_pre_hook(self, name: str):
        r"""Build the forward pre-hook that marks the start of a forward call."""

        def hook(module, input):
            self._begin(self._open_forward, name)

        return hook

    def _make_forward_hook(self, name: str):
        r"""Build the forward hook that marks the end of a forward call."""

        def hook(module, input, output):
            self._finish(self._open_forward, self._pending_forward, name)

        return hook

    def _make_backward_pre_hook(self, name: str):
        r"""Build the backward pre-hook that marks the start of a backward pass."""

        def hook(module, grad_output):
            self._begin(self._open_backward, name)

        return hook

    def _make_backward_hook(self, name: str):
        r"""Build the backward hook that marks the end of a backward pass."""

        def hook(module, grad_input, grad_output):
            self._finish(self._open_backward, self._pending_backward, name)

        return hook

    def _remove_hooks(self) -> None:
        r"""Detach every hook registered by this callback."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""Register timing hooks on every named sub-module that matches a pattern."""
        if not self.enabled:
            logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
            return

        model = kwargs.get("model")
        if model is None:
            return

        # Drop hooks from a previous run so modules are never measured twice.
        self._remove_hooks()

        matched = []
        for name, module in model.named_modules():
            # The root module has an empty name and is never profiled.
            if not name or not self._match(name):
                continue
            self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
            self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
            self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
            self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
            matched.append(name)

        if matched:
            logger.info_rank0(
                f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
                + (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
            )
        else:
            logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""Synchronize the accelerator and convert completed event pairs into milliseconds."""
        if not self.enabled:
            return

        # Events must have completed on the device before elapsed_time can be read.
        self._synchronize()

        collectors = (
            (self._pending_forward, self._forward_times),
            (self._pending_backward, self._backward_times),
        )
        for pending, times in collectors:
            for name, pairs in pending.items():
                times[name].extend(start.elapsed_time(end) for start, end in pairs)
            pending.clear()

        # Starts that never saw a matching end are discarded rather than kept alive.
        self._open_forward.clear()
        self._open_backward.clear()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""Log the mean forward/backward time per module, then reset the accumulators."""
        if not self._forward_times and not self._backward_times:
            return

        lines = ["[ModuleProfiler] Timing (ms):"]
        all_names = sorted(self._forward_times.keys() | self._backward_times.keys())
        for name in all_names:
            fwd = self._forward_times.get(name, [])
            bwd = self._backward_times.get(name, [])
            fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
            bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
            lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")

        logger.info_rank0("\n".join(lines))
        self._forward_times.clear()
        self._backward_times.clear()

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""Remove all registered hooks so the model is left unmodified."""
        self._remove_hooks()
|
||||
|
||||
@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.running = RunningMoments(self.accelerator)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
|
||||
@@ -32,7 +32,13 @@ from ..extras.packages import (
|
||||
)
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .callbacks import (
|
||||
LogCallback,
|
||||
ModuleProfilerCallback,
|
||||
PissaConvertCallback,
|
||||
ReporterCallback,
|
||||
TorchProfilerCallback,
|
||||
)
|
||||
from .dpo import run_dpo
|
||||
from .kto import run_kto
|
||||
from .ppo import run_ppo
|
||||
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
if finetuning_args.early_stopping_steps is not None:
|
||||
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
||||
|
||||
if getattr(training_args, "enable_torch_profiler", False):
|
||||
callbacks.append(TorchProfilerCallback(training_args))
|
||||
|
||||
if getattr(training_args, "profile_modules", None):
|
||||
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
|
||||
|
||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||
|
||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils.types import AttentionFunction
|
||||
from .arg_parser import InputArgument, get_args
|
||||
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
|
||||
from .data_args import DataArguments
|
||||
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionFunction",
|
||||
"BatchingStrategy",
|
||||
"DataArguments",
|
||||
"InputArgument",
|
||||
|
||||
@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
model_args, data_args, training_args, sample_args = parsed_args
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
|
||||
|
||||
return tuple(parsed_args)
|
||||
return model_args, data_args, training_args, sample_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..utils.types import AttentionFunction
|
||||
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@@ -32,6 +33,12 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code from Hugging Face."},
|
||||
)
|
||||
flash_attn: AttentionFunction = field(
|
||||
default=AttentionFunction.SDPA,
|
||||
metadata={
|
||||
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
|
||||
},
|
||||
)
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
@@ -54,6 +61,12 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
supported_flash_attn = [item.value for item in AttentionFunction]
|
||||
if self.flash_attn not in supported_flash_attn:
|
||||
raise ValueError(
|
||||
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
|
||||
)
|
||||
|
||||
self.init_config = get_plugin_config(self.init_config)
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||
|
||||
@@ -85,6 +85,10 @@ class TrainingArguments:
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
full_determinism: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
|
||||
)
|
||||
resume_from_checkpoint: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
|
||||
@@ -116,3 +120,9 @@ class TrainingArguments:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
self.optim_config = get_plugin_config(self.optim_config)
|
||||
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
|
||||
|
||||
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
|
||||
if self.max_steps is None or self.max_steps <= 0:
|
||||
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
|
||||
if self.save_epochs is not None:
|
||||
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.helper import ReduceOp
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import TrainingArguments
|
||||
from ..config import BatchingStrategy, TrainingArguments
|
||||
from ..utils import logging
|
||||
from ..utils.callbacks import (
|
||||
CallbackHandler,
|
||||
@@ -134,6 +134,9 @@ class BaseTrainer:
|
||||
global_step=self.global_step,
|
||||
epoch=self._resume_epoch,
|
||||
)
|
||||
# Keep callback state aligned with checkpoint-resumed trainer counters.
|
||||
self.state.global_step = self.global_step
|
||||
self.state.epoch = self._resume_epoch
|
||||
|
||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
# Qwen3.5 is not supported yet because its attention implementation differs; support is planned.
|
||||
@@ -144,13 +147,19 @@ class BaseTrainer:
|
||||
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
|
||||
|
||||
if model.config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_rank0(
|
||||
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
|
||||
raise ValueError(
|
||||
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
|
||||
)
|
||||
model.config._attn_implementation = "flash_attention_2"
|
||||
|
||||
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
if (
|
||||
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
|
||||
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
|
||||
):
|
||||
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
|
||||
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
dataset=self.train_dataset,
|
||||
renderer=self.renderer,
|
||||
@@ -234,6 +243,7 @@ class BaseTrainer:
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||
|
||||
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
|
||||
for micro_batches in self.train_batch_generator:
|
||||
self.global_step += 1
|
||||
|
||||
@@ -303,7 +313,7 @@ class BaseTrainer:
|
||||
if self.global_step % self.args.logging_steps == 0:
|
||||
logs = {
|
||||
"epoch": epoch,
|
||||
"step": self.global_step,
|
||||
"step": self.state.global_step,
|
||||
"loss": step_loss,
|
||||
"grad_norm": grad_norm,
|
||||
"learning_rate": current_lr,
|
||||
@@ -335,7 +345,9 @@ class BaseTrainer:
|
||||
)
|
||||
else:
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
model_to_save.save_pretrained(
|
||||
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
|
||||
)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class ModelEngine:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
|
||||
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
@@ -143,6 +144,12 @@ class ModelEngine:
|
||||
elif self.args.model_class == ModelClass.CLS:
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
self.model_config.num_labels = 1
|
||||
self.model_config.classifier_dropout = 0.0
|
||||
text_config = getattr(self.model_config, "text_config", None)
|
||||
if text_config is not None:
|
||||
text_config.num_labels = 1
|
||||
text_config.classifier_dropout = 0.0
|
||||
AutoClass = AutoModelForTokenClassification
|
||||
else:
|
||||
from transformers import AutoModel
|
||||
@@ -158,6 +165,7 @@ class ModelEngine:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
attn_implementation=self.args.flash_attn,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
**init_kwargs,
|
||||
)
|
||||
@@ -182,9 +190,12 @@ class ModelEngine:
|
||||
if self.args.kernel_config is not None:
|
||||
from ..plugins.model_plugins.kernels.interface import KernelPlugin
|
||||
|
||||
model = KernelPlugin(self.args.kernel_config.name)(
|
||||
model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
)
|
||||
kernel_config = self.args.kernel_config
|
||||
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
|
||||
if kernel_config.name == "liger_kernel":
|
||||
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
|
||||
kernel_kwargs["require_logits"] = self.is_train
|
||||
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -42,6 +42,8 @@ from .rendering import Renderer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
__all__ = ["BatchGenerator"]
|
||||
|
||||
|
||||
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_size = batch_info["micro_batch_size"]
|
||||
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
|
||||
if not self.drop_last:
|
||||
raise ValueError("Drop last must be True.")
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
}
|
||||
|
||||
self._init_data_provider()
|
||||
|
||||
self._is_resuming: bool = False
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._buffer = StatefulBuffer()
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
"data_iter": self._data_iter,
|
||||
}
|
||||
|
||||
logger.info_rank0(
|
||||
f"Init unified data loader with global batch size {self.global_batch_size}, "
|
||||
f"micro batch size {self.micro_batch_size}, "
|
||||
@@ -137,27 +138,33 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
generato_seed = torch.Generator()
|
||||
generato_seed.manual_seed(self.seed)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
batch_size = self.micro_batch_size * self.num_micro_batch
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
|
||||
|
||||
generator_seed = torch.Generator()
|
||||
generator_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=self.batching_workers,
|
||||
collate_fn=self.renderer.process_samples,
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
generator=generato_seed,
|
||||
generator=generator_seed,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
|
||||
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
|
||||
|
||||
def _generate_batch(self) -> list[BatchInput] | None:
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
|
||||
|
||||
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
|
||||
|
||||
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
    """Return the next item from the data iterator, optionally restarting it.

    Args:
        restart: When true and the current iterator is exhausted, a new
            iterator is created from ``self._data_provider`` and tried once.

    Returns:
        The next item, or ``None`` when the iterator is exhausted and either
        ``restart`` is false or the fresh iterator yields nothing.
    """
    exhausted = object()  # private sentinel so any yielded value passes through unchanged

    samples = next(self._data_iter, exhausted)
    if samples is not exhausted:
        return samples

    if not restart:
        return None

    # Dynamic batching may restart the provider to fill one token-budgeted batch.
    self._data_iter = iter(self._data_provider)
    samples = next(self._data_iter, exhausted)
    return None if samples is exhausted else samples
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"buffer": self._buffer.state_dict(),
|
||||
|
||||
@@ -172,7 +172,7 @@ def _save_standard_training_states(
|
||||
if rank == 0:
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_dir = os.path.join(ckpt_dir, "model")
|
||||
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
|
||||
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
|
||||
processor.save_pretrained(model_dir)
|
||||
|
||||
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
||||
@@ -212,7 +212,11 @@ def _load_standard_training_states(
|
||||
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
||||
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
||||
if state_dict:
|
||||
model_to_load.load_state_dict(state_dict)
|
||||
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
|
||||
if incompatible_keys.missing_keys:
|
||||
raise RuntimeError(
|
||||
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
|
||||
)
|
||||
else:
|
||||
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
||||
|
||||
|
||||
@@ -148,7 +148,9 @@ def launch():
|
||||
elif command == "dpo":
|
||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||
elif command == "rm":
|
||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
||||
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
|
||||
run_rm()
|
||||
|
||||
else:
|
||||
print(f"Unknown command: {command}.\n{USAGE}")
|
||||
@@ -175,9 +177,9 @@ def main():
|
||||
# run_dpo()
|
||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||
elif command == "rm":
|
||||
# from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
# run_rm()
|
||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
||||
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
|
||||
run_rm()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -34,7 +34,7 @@ class BaseKernel(ABC):
|
||||
"""
|
||||
|
||||
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
|
||||
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
|
||||
_device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
|
||||
|
||||
@classmethod
|
||||
def get_kernel_id(cls) -> str:
|
||||
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
|
||||
return cls._kernel_id
|
||||
|
||||
@classmethod
|
||||
def get_device(cls) -> str:
|
||||
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
def get_device(cls) -> list[DeviceType]:
|
||||
"""Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
|
||||
return cls._device
|
||||
|
||||
@classmethod
|
||||
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
|
||||
it should raise an error instead of silently switching.
|
||||
Kernels can override this method to implement custom dependency checks.
|
||||
"""
|
||||
if cls._device != get_current_accelerator().type:
|
||||
if get_current_accelerator().type not in cls._device:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
|
||||
apply_kernel(kernel, model=model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
    model: HFModel,
    include_kernels: str | bool | list[str] | None = None,
    require_logits: bool = False,
) -> HFModel:
    """Applies Liger kernel to the model.

    Args:
        model (HFModel): The model instance to apply kernels to.
        include_kernels (str | bool | list[str], optional): Selects which Liger ops to enable.
            ``"auto"`` or ``True`` applies Liger with the library defaults. A
            comma-separated string (e.g. ``rope,rms_norm``) or a list of names
            enables only those ops; the names are validated by ``LigerKernel.apply``.
            ``None``, ``False`` or an empty selection leaves the model untouched.
            Defaults to ``None``.
        require_logits (bool, optional): Forwarded to ``LigerKernel.apply``; when true,
            ``fused_linear_cross_entropy`` is disabled in favor of non-fused CE so the
            forward pass returns ``logits``. Defaults to ``False``.

    Returns:
        HFModel: The model with Liger kernel applied, or the unchanged model when
        nothing was selected or the Liger ops module could not be imported.
    """
    if not include_kernels:
        return model

    is_auto = include_kernels is True or (
        isinstance(include_kernels, str) and include_kernels.strip() == "auto"
    )
    if is_auto:
        use_kernels = "auto"
    else:
        if isinstance(include_kernels, str):
            raw_names = include_kernels.split(",")
        else:
            # A list/tuple of names (e.g. a YAML sequence) is accepted as-is.
            raw_names = [str(name) for name in include_kernels]

        use_kernels = [name.strip() for name in raw_names if name.strip()]
        if not use_kernels:
            return model

    try:
        from .liger_kernel_ops import LigerKernel
    except ImportError as e:
        # Best effort by design: a missing optional dependency must not abort model loading.
        logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
        return model

    return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of Liger Kernel.
|
||||
|
||||
Init Phase:
|
||||
1. Define LigerKernel class.
|
||||
2. Register Liger kernel.
|
||||
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.types import HFModel
|
||||
from .base import BaseKernel
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
|
||||
"qwen3": "apply_liger_kernel_to_qwen3",
|
||||
"qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
|
||||
"qwen3_next": "apply_liger_kernel_to_qwen3_next",
|
||||
"qwen3_5": "apply_liger_kernel_to_qwen3_5",
|
||||
"qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
|
||||
"qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
|
||||
"qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
|
||||
}
|
||||
|
||||
|
||||
class LigerKernel(BaseKernel):
    """Liger Kernel for optimized model training.

    Looks up the ``apply_liger_kernel_to_*`` function for the model's
    ``config.model_type`` in ``_LIGER_FN_BY_MODEL_TYPE`` and calls it with a
    set of boolean op toggles derived from the caller's selection.
    """

    # Accelerator types on which this kernel may be applied.
    _device = [DeviceType.CUDA, DeviceType.NPU]

    # Short spellings accepted in ``use_kernels``, mapped to the upstream kwarg name.
    _OP_ALIASES = {
        "rmsnorm": "rms_norm",
        "flce": "fused_linear_cross_entropy",
        "lce": "fused_linear_cross_entropy",
        "fused_ce": "fused_linear_cross_entropy",
    }

    @classmethod
    def check_deps(cls) -> bool:
        """Checks if the required dependencies for the kernel are available.

        Returns:
            bool: ``False`` (after logging a warning) when ``liger_kernel`` cannot be
            imported; otherwise the result of the base-class dependency check.
        """
        try:
            import liger_kernel  # noqa: F401
        except ImportError:
            logger.warning_rank0(
                "Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
            )
            return False

        return super().check_deps()

    @staticmethod
    def _normalize_op_name(raw: str) -> str:
        """Lower-cases an op name, maps ``-`` to ``_`` and resolves known aliases."""
        key = raw.strip().lower().replace("-", "_")
        return LigerKernel._OP_ALIASES.get(key, key)

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        """Applies the Liger kernel to the model.

        Args:
            **kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
                names (or a comma-separated string of them) to enable exclusively, or the
                string ``"auto"`` to use each ``apply_liger_kernel_to_*`` function's
                signature defaults. Omitting it (or passing ``None``) is treated as
                ``"auto"``. Optional ``require_logits`` forces non-fused cross entropy
                when the patch function supports ``fused_linear_cross_entropy``.

        Returns:
            HFModel: The model with Liger kernel applied. The model is returned unchanged
            when its ``model_type`` has no registered patch function or when the
            selection is empty.

        Raises:
            ValueError: If the model is not provided, if an unknown op name is selected,
                or if both ``cross_entropy`` and ``fused_linear_cross_entropy`` are selected.
            RuntimeError: If dependencies are not met.
        """
        model = kwargs.get("model")
        use_kernels = kwargs.get("use_kernels", None)
        require_logits = kwargs.get("require_logits", False)
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            # check_deps fails for a missing package as well as for an unsupported device,
            # so the message must not blame the device alone.
            raise RuntimeError(
                "liger_kernel cannot be applied: the package is not installed or the current device "
                f"({get_current_accelerator().type}) is not among the supported devices {cls.get_device()}."
            )

        model_type = getattr(model.config, "model_type", None)
        if model_type not in _LIGER_FN_BY_MODEL_TYPE:
            logger.warning_rank0("Current model does not support liger kernel.")
            return model

        import liger_kernel.transformers as liger_transformers

        apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])

        sig = inspect.signature(apply_liger_kernel).parameters
        togglable = [name for name in sig if name != "model"]

        if use_kernels is None:
            # The argument is documented as optional; omitted means library defaults.
            use_kernels = "auto"
        elif isinstance(use_kernels, str) and use_kernels != "auto":
            # A bare string would otherwise be iterated character by character.
            use_kernels = [name for name in use_kernels.split(",") if name.strip()]

        if len(use_kernels) == 0:
            return model

        if use_kernels == "auto":
            # Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
            # and logging reflects enabled ops (omitted kwargs only live in the callee).
            call_kwargs = {"model": model}
            for name in togglable:
                default = sig[name].default
                if default is not inspect.Parameter.empty:
                    call_kwargs[name] = default
        else:
            selected = {cls._normalize_op_name(k) for k in use_kernels}
            unknown = selected - set(togglable)
            if unknown:
                raise ValueError(
                    f"Unknown Liger op(s) {sorted(unknown)} for model_type={model_type}. Valid: {sorted(togglable)}"
                )
            if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
                raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")

            # Every togglable op is passed explicitly so unselected ops are switched off.
            call_kwargs = {name: (name in selected) for name in togglable}
            call_kwargs["model"] = model

        if require_logits and "fused_linear_cross_entropy" in sig:
            logger.warning_rank0("Current training stage does not support chunked cross entropy.")
            call_kwargs["fused_linear_cross_entropy"] = False
            # Only pass the replacement toggle when the patch function accepts it.
            if "cross_entropy" in sig:
                call_kwargs["cross_entropy"] = True

        apply_liger_kernel(**call_kwargs)

        applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
        logger.info_rank0(f"These Liger ops are applied to the model: {applied}")

        return model
|
||||
@@ -0,0 +1,429 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pure-Triton Fused MoE Kernel for NVIDIA GPUs.
|
||||
|
||||
Replaces the HuggingFace per-expert Python loop with a fully fused Triton pipeline:
|
||||
- Forward: scatter → grouped GEMM fc1 → SiLU·gate → apply routing → grouped GEMM fc2 → gather
|
||||
- Backward: all dX via grouped GEMM, all dW via grouped GEMM (no Python loops)
|
||||
|
||||
Supported models: Mixtral, Qwen3-MoE, Qwen3.5-MoE.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
from .triton_grouped_gemm import (
|
||||
group_gemm_same_mn,
|
||||
group_gemm_same_nk,
|
||||
moe_gather,
|
||||
moe_scatter,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Autograd Function: Full Triton MoE forward + backward
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TritonFusedMoeFunction(torch.autograd.Function):
|
||||
"""Fused MoE expert computation using Triton grouped GEMMs.
|
||||
|
||||
Forward: scatter → fc1 (group GEMM) → SiLU·gate → weight → fc2 (group GEMM) → gather
|
||||
Backward: all gradients computed via grouped GEMMs in single kernel launches.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
num_experts,
|
||||
gate_weights,
|
||||
expert_index,
|
||||
hidden_states,
|
||||
fc1_weight,
|
||||
fc2_weight,
|
||||
):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
ctx: autograd context
|
||||
num_experts: int
|
||||
gate_weights: (num_tokens, top_k) routing weights
|
||||
expert_index: (num_tokens, top_k) expert assignments
|
||||
hidden_states: (num_tokens, hidden_dim)
|
||||
fc1_weight: (E, 2*inter, hidden) merged gate+up weight
|
||||
fc2_weight: (E, hidden, inter) down projection weight
|
||||
"""
|
||||
# Compute scatter index: maps (token, topk) → position in sorted buffer
|
||||
scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)
|
||||
|
||||
# Token counts per expert and cumulative boundaries
|
||||
splits = torch.zeros(num_experts, dtype=torch.int32, device=hidden_states.device)
|
||||
flat_experts = expert_index.flatten().int()
|
||||
splits.scatter_add_(0, flat_experts.long(), torch.ones_like(flat_experts))
|
||||
cumsum_t = torch.cumsum(splits, dim=0)
|
||||
|
||||
# Scatter hidden states to sorted expert buffer
|
||||
scatter_output = moe_scatter(hidden_states, scatter_index)
|
||||
|
||||
# FC1: grouped GEMM (scatter_output @ fc1_weight.T)
|
||||
max_M = int(splits.max().item())
|
||||
fc1_output = group_gemm_same_nk(
|
||||
a=scatter_output,
|
||||
b=fc1_weight,
|
||||
cumsum_M=cumsum_t,
|
||||
max_M=max_M,
|
||||
transpose_b=True,
|
||||
)
|
||||
|
||||
# SiLU gate activation
|
||||
fc1_1_output, fc1_2_output = fc1_output.chunk(2, dim=-1)
|
||||
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
|
||||
fc1_activation = fc1_1_activation * fc1_2_output
|
||||
|
||||
# Apply routing weights before fc2 (mathematically equivalent to after)
|
||||
reshaped_gate_weight = gate_weights.reshape(-1, 1)
|
||||
scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
|
||||
scattered_gate_weight[scatter_index.flatten().long()] = reshaped_gate_weight
|
||||
fc1_weighted_output = fc1_activation * scattered_gate_weight
|
||||
|
||||
# FC2: grouped GEMM (fc1_weighted @ fc2_weight.T)
|
||||
fc2_output = group_gemm_same_nk(
|
||||
a=fc1_weighted_output,
|
||||
b=fc2_weight,
|
||||
cumsum_M=cumsum_t,
|
||||
max_M=max_M,
|
||||
transpose_b=True,
|
||||
)
|
||||
|
||||
# Gather back to original token positions (sum over topk)
|
||||
expert_output = moe_gather(fc2_output, scatter_index)
|
||||
|
||||
ctx.num_experts = num_experts
|
||||
ctx.save_for_backward(
|
||||
gate_weights,
|
||||
fc1_weight,
|
||||
fc2_weight,
|
||||
hidden_states,
|
||||
scatter_index,
|
||||
scatter_output,
|
||||
cumsum_t,
|
||||
fc1_1_output,
|
||||
fc1_2_output,
|
||||
fc1_activation,
|
||||
scattered_gate_weight,
|
||||
fc1_weighted_output,
|
||||
)
|
||||
|
||||
return expert_output
|
||||
|
||||
@staticmethod
def backward(ctx, grad_output):
    """Backward pass of the fused MoE pipeline.

    Replays the forward stages in reverse order (gather -> fc2 -> routing
    weights -> SiLU gate -> fc1 -> scatter) using the tensors stored on
    ``ctx`` by ``forward``.

    Args:
        ctx: Autograd context holding ``num_experts`` and the saved tensors.
        grad_output: Gradient of the loss w.r.t. the expert output. It is
            flattened to ``(-1, hidden_dim)`` before use.

    Returns:
        A 6-tuple ``(None, grad_gate_weights, None, grad_hidden_states,
        grad_fc1_weight, grad_fc2_weight)``. The two ``None`` entries are
        for ``num_experts`` and ``expert_index``; a weight gradient is
        ``None`` when that weight does not require grad.
    """
    (
        gate_weights,
        fc1_weight,
        fc2_weight,
        hidden_states,
        scatter_index,
        scatter_output,
        cumsum_t,
        fc1_1_output,
        fc1_2_output,
        fc1_activation,
        scattered_gate_weight,
        fc1_weighted_output,
    ) = ctx.saved_tensors
    num_experts = ctx.num_experts

    grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

    # Per-expert token counts are the first differences of the cumulative
    # counts; the largest one is passed to the grouped GEMMs as max_M.
    tokens_per_expert = torch.zeros(num_experts, dtype=cumsum_t.dtype, device=cumsum_t.device)
    tokens_per_expert[0] = cumsum_t[0]
    tokens_per_expert[1:] = cumsum_t[1:] - cumsum_t[:-1]
    max_M = int(tokens_per_expert.max().item())

    def grouped_dx(grad, weight):
        # Input gradient of a grouped linear layer: grad @ weight (no transpose).
        return group_gemm_same_nk(a=grad, b=weight, cumsum_M=cumsum_t, max_M=max_M, transpose_b=False)

    def grouped_dw(grad, layer_input, weight):
        # Weight gradient of a grouped linear layer: grad.T @ layer_input per expert.
        if not weight.requires_grad:
            return None
        weight_grad = torch.empty_like(weight)
        group_gemm_same_mn(a=grad, b=layer_input, c=weight_grad, cumsum_K=cumsum_t)
        return weight_grad

    # Step 1: scatter grad_output into the expert-sorted buffer.
    grad_fc2_output = moe_scatter(grad_output, scatter_index)

    # Step 2: fc2 backward (fc2 weight is (E, hidden, inter)).
    grad_fc1_weighted_output = grouped_dx(grad_fc2_output, fc2_weight)
    grad_fc2_weight = grouped_dw(grad_fc2_output, fc1_weighted_output, fc2_weight)

    # Step 3: routing-weight backward. The per-row gradient lives in sorted
    # order and is read back through scatter_index into gate_weights' layout.
    grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
    grad_scattered_gate_weight = (fc1_activation * grad_fc1_weighted_output).sum(dim=-1)
    grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten().long()].reshape(gate_weights.shape)

    # Step 4: SiLU gate backward. silu(fc1_1_output) is recomputed here.
    fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
    grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
    grad_fc1_2_output = fc1_1_activation * grad_fc1_activation
    # d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
    grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)

    # Re-join the two halves into (total_M, 2*inter).
    grad_fc1_output = torch.cat([grad_fc1_1_output, grad_fc1_2_output], dim=-1)

    # Step 5: fc1 backward.
    grad_scatter_output = grouped_dx(grad_fc1_output, fc1_weight)
    grad_fc1_weight = grouped_dw(grad_fc1_output, scatter_output, fc1_weight)

    # Step 6: gather gradients back to the original token positions.
    grad_hidden_states = moe_gather(grad_scatter_output, scatter_index).reshape(hidden_states.shape)

    return (
        None,  # num_experts
        grad_gate_weight,  # gate_weights
        None,  # expert_index
        grad_hidden_states,  # hidden_states
        grad_fc1_weight,  # fc1_weight
        grad_fc2_weight,  # fc2_weight
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patched forward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _triton_moe_experts_forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Replacement forward for v5+ MoE expert modules with stacked 3D weights.

    Args:
        self: The patched experts module; must expose ``num_experts``,
            ``gate_up_proj`` and ``down_proj``.
        hidden_states: Token activations handed to the experts.
        top_k_index: Selected expert ids per token.
        top_k_weights: Routing weights per token; cast to the dtype of
            ``hidden_states`` before use.

    Returns:
        The output of ``TritonFusedMoeFunction.apply``.
    """
    num_experts = self.num_experts
    routing_weights = top_k_weights.to(hidden_states.dtype)
    return TritonFusedMoeFunction.apply(
        num_experts,
        routing_weights,
        top_k_index,
        hidden_states,
        self.gate_up_proj,
        self.down_proj,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy (transformers < 5.0) support: weight stacking + SparseMoeBlock patch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StackedExpertWeights(torch.nn.Module):
|
||||
"""Lightweight container holding stacked 3D expert weight tensors."""
|
||||
|
||||
def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, num_experts: int):
|
||||
super().__init__()
|
||||
self.gate_up_proj = torch.nn.Parameter(gate_up_proj)
|
||||
self.down_proj = torch.nn.Parameter(down_proj)
|
||||
self.num_experts = num_experts
|
||||
|
||||
|
||||
def _stack_expert_weights(module: torch.nn.Module) -> None:
    """Replace ``module.experts`` (one sub-module per expert) with stacked 3D parameters.

    Each expert must expose ``gate_proj``, ``up_proj`` and ``down_proj`` layers
    with a ``weight`` tensor. After the call ``module.experts`` is a
    ``_StackedExpertWeights`` holding:

    - ``gate_up_proj``: ``(E, 2*inter, hidden)``, gate rows followed by up rows
    - ``down_proj``: ``(E, hidden, inter)``

    The trainability of the source weights is carried over: a stacked parameter
    requires grad only if at least one of the weights it was built from did.
    Without this, wrapping in a fresh ``nn.Parameter`` would make frozen expert
    weights trainable again.

    Args:
        module: The sparse MoE block whose ``experts`` attribute is replaced in place.
    """
    experts = module.experts
    num_experts = len(experts)

    # detach() gives grad-free views; cat/stack then copy them into new storage.
    gate_up_proj = torch.stack(
        [
            torch.cat([expert.gate_proj.weight.detach(), expert.up_proj.weight.detach()], dim=0)  # (2*inter, hidden)
            for expert in experts
        ],
        dim=0,
    )  # (E, 2*inter, hidden)
    down_proj = torch.stack([expert.down_proj.weight.detach() for expert in experts], dim=0)  # (E, hidden, inter)

    # Record trainability before the originals are dropped.
    gate_up_trainable = any(
        expert.gate_proj.weight.requires_grad or expert.up_proj.weight.requires_grad for expert in experts
    )
    down_trainable = any(expert.down_proj.weight.requires_grad for expert in experts)

    stacked = _StackedExpertWeights(gate_up_proj, down_proj, num_experts)
    stacked.gate_up_proj.requires_grad_(gate_up_trainable)
    stacked.down_proj.requires_grad_(down_trainable)

    module.experts = stacked
    logger.info(
        f"cuda_fused_moe: Stacked {num_experts} expert weights into "
        f"gate_up_proj {tuple(gate_up_proj.shape)}, down_proj {tuple(down_proj.shape)}"
    )
|
||||
|
||||
|
||||
def _triton_moe_sparse_block_forward(self, hidden_states: torch.Tensor):
    """Replacement forward for legacy SparseMoeBlock with inline routing.

    Routing (softmax over router logits in float32, top-k selection, optional
    renormalisation) is done here; the expert computation is delegated to
    ``TritonFusedMoeFunction``.

    Args:
        self: The patched block; reads ``gate``, ``top_k``, ``norm_topk_prob``,
            ``num_experts`` and ``experts.gate_up_proj`` / ``experts.down_proj``.
        hidden_states: 3D input unpacked as ``(batch, seq, hidden)``.

    Returns:
        ``(final_hidden_states, router_logits)`` where the first has the input's
        3D shape and the second is the raw output of ``self.gate`` on the
        flattened tokens.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    flat_states = hidden_states.view(-1, hidden_dim)

    router_logits = self.gate(flat_states)
    probs = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weights, topk_experts = torch.topk(probs, self.top_k, dim=-1)
    if self.norm_topk_prob:
        # Renormalise so the selected weights sum to 1 per token.
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(flat_states.dtype)

    expert_out = TritonFusedMoeFunction.apply(
        self.num_experts,
        topk_weights,
        topk_experts,
        flat_states,
        self.experts.gate_up_proj,
        self.experts.down_proj,
    )

    return expert_out.reshape(batch_size, sequence_length, hidden_dim), router_logits
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# architecture name (from ``model.config.architectures``) -> {module class name -> replacement forward}
_TRITON_MOE_MAPPING: dict[str, dict[str, object]] = {
    "MixtralForCausalLM": {
        "MixtralExperts": _triton_moe_experts_forward,
    },
    "Qwen3MoeForCausalLM": {
        # Stacked-weight experts module.
        "Qwen3MoeExperts": _triton_moe_experts_forward,
        # Legacy block that still does its own routing.
        "Qwen3MoeSparseMoeBlock": _triton_moe_sparse_block_forward,
    },
    "Qwen3_5MoeForCausalLM": {
        "Qwen3_5MoeExperts": _triton_moe_experts_forward,
    },
    "Qwen3_5MoeForConditionalGeneration": {
        "Qwen3_5MoeExperts": _triton_moe_experts_forward,
    },
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Kernel registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@register_kernel
class CudaFusedMoEKernel(BaseKernel):
    """Pure-Triton fused MoE kernel for NVIDIA CUDA GPUs.

    Replaces HuggingFace per-expert Python loops with a fully fused Triton pipeline:
    - Forward: scatter + grouped GEMMs + gather (single kernel per GEMM)
    - Backward: all dX and dW via grouped GEMMs (no Python loops)

    Requires: CUDA GPU + Triton
    """

    _kernel_id = "cuda_fused_moe"
    _device = DeviceType.CUDA

    @classmethod
    def check_deps(cls) -> bool:
        """Return ``True`` when the base-class checks pass and ``triton`` is importable."""
        if not super().check_deps():
            return False

        try:
            import triton  # noqa: F401
        except ImportError:
            logger.info("cuda_fused_moe: Triton not available, kernel disabled.")
            return False

        return True

    @classmethod
    def apply(cls, **kwargs) -> HFModel:
        """Patch the MoE modules of ``kwargs["model"]`` with the Triton forward functions.

        The model is returned unchanged when dependencies are missing or none of
        its architectures appear in ``_TRITON_MOE_MAPPING``.

        Raises:
            ValueError: If no ``model`` keyword argument is supplied.
        """
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            logger.warning("cuda_fused_moe: Dependencies not met. Skipping kernel application.")
            return model

        archs = getattr(model.config, "architectures", None) or []
        # The first listed architecture with a mapping entry wins.
        target_mapping = next(
            (_TRITON_MOE_MAPPING[arch] for arch in archs if arch in _TRITON_MOE_MAPPING),
            None,
        )
        if target_mapping is None:
            logger.info(
                f"cuda_fused_moe: Model architecture {archs} not supported. "
                f"Supported: {list(_TRITON_MOE_MAPPING.keys())}"
            )
            return model

        patched_count = 0
        for module in model.modules():
            class_name = module.__class__.__name__
            if class_name not in target_mapping:
                continue

            # Weights already stacked on the module itself.
            already_stacked = hasattr(module, "gate_up_proj") and hasattr(module, "down_proj")
            # Legacy block: a ModuleList of experts plus a router; needs stacking first.
            legacy_block = (
                not already_stacked
                and hasattr(module, "experts")
                and isinstance(module.experts, torch.nn.ModuleList)
                and hasattr(module, "gate")
            )
            if not (already_stacked or legacy_block):
                continue

            if legacy_block:
                _stack_expert_weights(module)
            module.forward = types.MethodType(target_mapping[class_name], module)
            patched_count += 1

        if patched_count > 0:
            logger.info(f"cuda_fused_moe: Patched {patched_count} MoE expert modules with pure Triton pipeline.")
        else:
            logger.warning("cuda_fused_moe: No MoE expert modules found to patch.")

        return model
|
||||
@@ -228,6 +228,30 @@ class NpuMoeFused:
|
||||
routed_out = self.experts(hidden_states, routing_weights, router_indices)
|
||||
return routed_out
|
||||
|
||||
@staticmethod
def npu_moe_experts_v5_forward(
    self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
    """Forward pass for Transformers v5+ MoE experts using NPU fused operations.

    Transformers v5 stores expert weights in F.linear layout:
        gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
        down_proj: [num_experts, hidden_dim, intermediate_dim]
    The NPU grouped matmul path expects matmul layout, so both weights are
    transposed on their last two dims before use.

    Args:
        self: The patched experts module; reads ``hidden_dim``, ``num_experts``,
            ``gate_up_proj`` and ``down_proj``.
        hidden_states: Token activations, flattened to ``(-1, hidden_dim)``.
        top_k_index: Selected expert ids per token.
        top_k_weights: Routing weights, applied during un-permutation.

    Returns:
        The result of ``torch_npu.npu_moe_token_unpermute`` on the expert outputs.
    """
    flat_states = hidden_states.reshape(-1, self.hidden_dim)
    permuted_states, row_ids_map = torch_npu.npu_moe_token_permute(flat_states, top_k_index.to(torch.int32))
    # One unit-width histogram bin per expert id gives the per-expert token counts.
    tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()

    gate_up_weight = self.gate_up_proj.transpose(1, 2)
    down_weight = self.down_proj.transpose(1, 2)

    gate_up_out = GmmFunction.apply(permuted_states, gate_up_weight, tokens_per_expert)
    activated = torch_npu.npu_swiglu(gate_up_out, dim=-1)
    expert_out = GmmFunction.apply(activated, down_weight, tokens_per_expert)
    return torch_npu.npu_moe_token_unpermute(expert_out, row_ids_map, probs=top_k_weights)
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
|
||||
|
||||
|
||||
# moe patch config mapping
|
||||
kernel_moe_mapping = {
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3MoeForCausalLM": {
|
||||
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3_5MoeForCausalLM": {
|
||||
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3_5MoeForConditionalGeneration": {
|
||||
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
}
|
||||
else:
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3MoeForCausalLM": {
|
||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
|
||||
},
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
||||
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
||||
}
|
||||
}
|
||||
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
|
||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Pure-Triton grouped GEMM and MoE scatter/gather kernels.
|
||||
# Design adapted from VeOmni (ByteDance-Seed/VeOmni) group_gemm kernels.
|
||||
|
||||
"""Pure-Triton MoE kernels: grouped GEMM, scatter, and gather.
|
||||
|
||||
Provides four kernel types for fused MoE forward+backward without Python loops:
|
||||
- group_gemm_same_nk: Variable-M grouped GEMM (forward & backward dX)
|
||||
- group_gemm_same_mn: Variable-K grouped GEMM (backward dW)
|
||||
- moe_scatter: Token dispatch to sorted expert buffers
|
||||
- moe_gather: Token reduction from expert buffers
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Triton helper: grouped tile indexing with L2 cache-friendly swizzle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.jit
def _get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
    """Map a linear program id to a (row-tile, column-tile) pair.

    Tiles are visited in bands of up to GROUP_SIZE row-tiles; inside a band the
    row-tile index varies fastest. The last band may hold fewer row-tiles.
    """
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    tiles_per_band = GROUP_SIZE * tiles_n
    band_id = pid // tiles_per_band
    band_first_m = band_id * GROUP_SIZE
    # Clamp the band height for the final, possibly partial, band.
    band_rows = min(tiles_m - band_first_m, GROUP_SIZE)
    pid_m = band_first_m + (pid % band_rows)
    pid_n = (pid % tiles_per_band) // band_rows
    return pid_m, pid_n
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# group_gemm_same_nk: All experts share same N, K; variable M per expert
|
||||
# Used for: forward (x @ W.T) and backward dX (grad @ W)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": 64, "GROUP": 8},
            num_warps=warps,
            num_stages=3,
        )
        for block_m, block_n, warps in (
            (32, 64, 4),
            (64, 64, 4),
            (64, 128, 8),
            (128, 128, 8),
            (128, 256, 8),
        )
    ],
    key=["N", "K"],
)
@triton.jit
def _group_gemm_same_nk_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    cumsum_M,
    max_M,
    G: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    TRANSPOSE_B: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP: tl.constexpr,
):
    """Grouped GEMM tile kernel: one (BLOCK_M, BLOCK_N) output tile of one group per program.

    Program axis 0 enumerates tiles of a (max_M, N) problem; axis 1 is the group id.
    Groups with fewer rows than max_M exit early for tiles past their row count.
    """
    pid_m, pid_n = _get_pid_mn(tl.program_id(0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
    gid = tl.program_id(1).to(tl.int64)

    # Rows [row_begin, row_end) of a and c belong to this group.
    row_begin = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0).to(tl.int64)
    row_end = tl.load(cumsum_M + gid).to(tl.int64)
    m_size = row_end - row_begin

    if pid_m * BLOCK_M >= m_size:
        return

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    row_mask = offs_m[:, None] < m_size

    # a is (total_M, K) row-major, offset by the group's first row.
    a_ptrs = a_ptr + row_begin * K + offs_m[:, None] * K + offs_k[None, :]

    # b is (G, N, K) if TRANSPOSE_B else (G, K, N).
    b_base = b_ptr + gid * K * N
    if TRANSPOSE_B:
        # Tile is loaded as (BLOCK_N, BLOCK_K) and transposed before the dot.
        b_ptrs = b_base + offs_n[:, None] * K + offs_k[None, :]
    else:
        # Tile is loaded directly as (BLOCK_K, BLOCK_N).
        b_ptrs = b_base + offs_k[:, None] * N + offs_n[None, :]

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_mask = (k_start + offs_k) < K

        a_block = tl.load(a_ptrs, mask=row_mask & k_mask[None, :], other=0.0)

        if TRANSPOSE_B:
            b_block = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
            acc += tl.dot(a_block, tl.trans(b_block))
            b_ptrs += BLOCK_K
        else:
            b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
            acc += tl.dot(a_block, b_block)
            b_ptrs += BLOCK_K * N

        a_ptrs += BLOCK_K

    # c is (total_M, N) row-major.
    c_ptrs = c_ptr + row_begin * N + offs_m[:, None] * N + offs_n[None, :]
    c_mask = row_mask & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
|
||||
|
||||
|
||||
def group_gemm_same_nk(
    a: torch.Tensor,
    b: torch.Tensor,
    cumsum_M: torch.Tensor,
    max_M: int,
    transpose_b: bool = False,
) -> torch.Tensor:
    """Grouped GEMM where all groups share same N, K dimensions but variable M.

    Inputs are made contiguous before launch because the kernel addresses them
    with dense row-major pointer arithmetic. This is a no-op for tensors that
    are already contiguous; a strided ``b`` is copied on every call.

    Args:
        a: (total_M, K) input tensor, rows grouped by expert
        b: (G, N, K) if transpose_b else (G, K, N) weight tensor
        cumsum_M: (G,) cumulative token counts per expert
        max_M: maximum tokens any single expert has
        transpose_b: if True, compute a @ b.T; else compute a @ b

    Returns:
        c: (total_M, N) output tensor
    """
    if transpose_b:
        G, N, K = b.shape
    else:
        G, K, N = b.shape

    # Strided inputs would otherwise be read at the wrong addresses without any error.
    a = a.contiguous()
    b = b.contiguous()
    cumsum_M = cumsum_M.contiguous()

    c = torch.empty((a.shape[0], N), dtype=a.dtype, device=a.device)

    def grid(meta):
        # Axis 0: tiles of the largest group's (max_M, N) problem; axis 1: one slice per group.
        return (triton.cdiv(max_M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G)

    _group_gemm_same_nk_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        cumsum_M=cumsum_M,
        max_M=max_M,
        G=G,
        N=N,
        K=K,
        TRANSPOSE_B=transpose_b,
    )
    return c
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# group_gemm_same_mn: All experts share same M, N (weight dims); variable K
|
||||
# Used for: backward dW (grad.T @ input)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "GROUP": 8},
            num_warps=warps,
            num_stages=3,
        )
        for block_m, block_n, block_k, warps in (
            (64, 64, 32, 4),
            (64, 128, 32, 4),
            (128, 128, 32, 8),
            (128, 128, 64, 8),
        )
    ],
    key=["M", "N"],
)
@triton.jit
def _group_gemm_same_mn_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    cumsum_K,
    G: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP: tl.constexpr,
):
    """Grouped GEMM tile kernel computing c[g] = a[s:e].T @ b[s:e] for one output tile.

    Program axis 0 enumerates (M, N) tiles; axis 1 is the group id. A group with
    no rows writes zeros to its tile.
    """
    pid_m, pid_n = _get_pid_mn(tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP)
    gid = tl.program_id(1).to(tl.int64)

    # Rows [k_begin, k_end) of a and b form this group's reduction range.
    k_begin = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0).to(tl.int64)
    k_end = tl.load(cumsum_K + gid).to(tl.int64)
    k_size = k_end - k_begin

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # c is (G, M, N); the output tile location is the same for both exits below.
    c_ptrs = c_ptr + gid * M * N + offs_m[:, None] * N + offs_n[None, :]
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    if k_size == 0:
        tl.store(c_ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.dtype.element_ty), mask=c_mask)
        return

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offs_k = tl.arange(0, BLOCK_K)

    # a is (total_K, M) and b is (total_K, N); both start at the group's first row.
    a_base = a_ptr + k_begin * M
    b_base = b_ptr + k_begin * N

    for k_start in range(0, k_size, BLOCK_K):
        k_offs = k_start + offs_k
        k_mask = k_offs < k_size

        # Load a as (BLOCK_K, BLOCK_M) and transpose to (BLOCK_M, BLOCK_K).
        a_ptrs = a_base + k_offs[:, None] * M + offs_m[None, :]
        a_block_t = tl.trans(tl.load(a_ptrs, mask=k_mask[:, None] & (offs_m[None, :] < M), other=0.0))

        # Load b as (BLOCK_K, BLOCK_N).
        b_ptrs = b_base + k_offs[:, None] * N + offs_n[None, :]
        b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)

        acc += tl.dot(a_block_t, b_block)

    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
|
||||
|
||||
|
||||
def group_gemm_same_mn(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    cumsum_K: torch.Tensor,
) -> None:
    """Grouped GEMM where all groups produce same (M, N) output; variable K reduction.

    Computes: c[g] = a[s:e].T @ b[s:e] for each group g,
    where s, e are defined by cumsum_K boundaries.

    The kernel uses dense row-major pointer arithmetic, so ``a``, ``b`` and
    ``cumsum_K`` are made contiguous first (a no-op when they already are).
    ``c`` is written in place; if it is not contiguous (for example when it was
    allocated with ``empty_like`` from a strided weight), the result is computed
    into a dense staging buffer and copied into ``c``.

    Args:
        a: (total_K, M) input tensor grouped by expert
        b: (total_K, N) input tensor grouped by expert
        c: (G, M, N) output tensor (pre-allocated)
        cumsum_K: (G,) cumulative token counts per expert
    """
    G, M, N = c.shape

    a = a.contiguous()
    b = b.contiguous()
    cumsum_K = cumsum_K.contiguous()
    out = c if c.is_contiguous() else torch.empty((G, M, N), dtype=c.dtype, device=c.device)

    def grid(meta):
        # Axis 0: tiles of one (M, N) output; axis 1: one slice per group.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G)

    _group_gemm_same_mn_kernel[grid](
        a_ptr=a,
        b_ptr=b,
        c_ptr=out,
        cumsum_K=cumsum_K,
        G=G,
        M=M,
        N=N,
    )

    if out is not c:
        c.copy_(out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# moe_scatter: Dispatch tokens to sorted expert buffer positions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.jit
def _moe_scatter_kernel(
    x_ptr,
    out_ptr,
    index_ptr,
    M,
    N: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Scatter: for each token i, copy x[i] to out[index[i, k]] for k in 0..topk-1.

    Program axis 0 is the token row; axis 1 is the BLOCK_N-wide column chunk.
    """
    row = tl.program_id(0).to(tl.int64)
    col_block = tl.program_id(1)

    if row >= M:
        return

    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    col_mask = cols < N

    # Read this chunk of the source row once and reuse it for every destination.
    row_vals = tl.load(x_ptr + row * N + cols, mask=col_mask, other=0.0)

    for k in tl.static_range(TOPK):
        dst_row = tl.load(index_ptr + row * TOPK + k).to(tl.int64)
        tl.store(out_ptr + dst_row * N + cols, row_vals, mask=col_mask)
|
||||
|
||||
|
||||
def moe_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Scatter tokens to sorted expert buffer.

    For each token i and topk slot k, copies x[i] to output[index[i, k]].

    ``x`` and ``index`` are made contiguous before launch because the kernel
    addresses them as dense row-major buffers (``row * N + col`` and
    ``row * topk + k``). This is a no-op for already-contiguous tensors.

    Args:
        x: (M, N) input hidden states
        index: (M, topk) scatter indices

    Returns:
        out: (M * topk, N) scattered output
    """
    M, N = x.shape
    topk = index.shape[1]

    # Strided inputs would otherwise be read at the wrong addresses without any error.
    x = x.contiguous()
    index = index.contiguous()

    out = torch.empty(M * topk, N, dtype=x.dtype, device=x.device)

    # One program per (token row, column chunk); chunks are capped at 1024 columns.
    BLOCK_N = min(triton.next_power_of_2(N), 1024)
    grid = (M, triton.cdiv(N, BLOCK_N))

    _moe_scatter_kernel[grid](
        x_ptr=x,
        out_ptr=out,
        index_ptr=index,
        M=M,
        N=N,
        TOPK=topk,
        BLOCK_N=BLOCK_N,
    )
    return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# moe_gather: Reduce expert outputs back to token positions (sum over topk)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.jit
def _moe_gather_kernel(
    x_ptr,
    out_ptr,
    index_ptr,
    M,
    N: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Gather: for each token i, out[i] = sum_k(x[index[i, k]]) over topk.

    Program axis 0 is the token row; axis 1 is the BLOCK_N-wide column chunk.
    The sum is accumulated in float32 and cast to the output dtype on store.
    """
    row = tl.program_id(0).to(tl.int64)
    col_block = tl.program_id(1)

    if row >= M:
        return

    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    col_mask = cols < N

    total = tl.zeros([BLOCK_N], dtype=tl.float32)

    for k in tl.static_range(TOPK):
        src_row = tl.load(index_ptr + row * TOPK + k).to(tl.int64)
        total += tl.load(x_ptr + src_row * N + cols, mask=col_mask, other=0.0).to(tl.float32)

    tl.store(out_ptr + row * N + cols, total.to(out_ptr.dtype.element_ty), mask=col_mask)
|
||||
|
||||
|
||||
def moe_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather and reduce expert outputs back to original token positions.

    For each token i, sums x[index[i, k]] over all topk slots.

    ``x`` and ``index`` are made contiguous before launch because the kernel
    addresses them as dense row-major buffers (``row * N + col`` and
    ``row * topk + k``). This is a no-op for already-contiguous tensors.

    Args:
        x: (M * topk, N) expert outputs in sorted buffer
        index: (M, topk) scatter indices (same as used in moe_scatter)

    Returns:
        out: (M, N) gathered output
    """
    M, topk = index.shape
    N = x.shape[1]

    # Strided inputs would otherwise be read at the wrong addresses without any error.
    x = x.contiguous()
    index = index.contiguous()

    out = torch.empty(M, N, dtype=x.dtype, device=x.device)

    # One program per (token row, column chunk); chunks are capped at 1024 columns.
    BLOCK_N = min(triton.next_power_of_2(N), 1024)
    grid = (M, triton.cdiv(N, BLOCK_N))

    _moe_gather_kernel[grid](
        x_ptr=x,
        out_ptr=out,
        index_ptr=index,
        M=M,
        N=N,
        TOPK=topk,
        BLOCK_N=BLOCK_N,
    )
    return out
|
||||
@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
|
||||
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
|
||||
|
||||
.. note::
|
||||
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
|
||||
without hardcoding version numbers. Two methods are used: weight value inspection
|
||||
(most reliable) and class name pattern matching (backward compatibility).
|
||||
This must follow the module's forward semantics. Do not infer it from trained
|
||||
weight values because standard RMSNorm weights can also be close to zero.
|
||||
"""
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
weight_mean = module.weight.data.mean().item()
|
||||
if abs(weight_mean) < 0.3:
|
||||
return True
|
||||
residual_rmsnorm_classes = {
|
||||
"Qwen3_5RMSNorm",
|
||||
"Qwen3_5MoeRMSNorm",
|
||||
"Qwen3NextRMSNorm",
|
||||
}
|
||||
|
||||
class_name = module.__class__.__name__
|
||||
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
|
||||
for pattern in residual_patterns:
|
||||
if pattern in class_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
return class_name in residual_rmsnorm_classes
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
|
||||
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
|
||||
|
||||
if hasattr(self, "weight") and self.weight is not None:
|
||||
if _should_use_residual_rmsnorm(self):
|
||||
if getattr(self, "_npu_use_residual_rmsnorm", False):
|
||||
effective_weight = 1.0 + self.weight.float()
|
||||
else:
|
||||
effective_weight = self.weight.float()
|
||||
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
if "Gated" in module.__class__.__name__:
|
||||
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
||||
else:
|
||||
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
|
||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||
|
||||
return model
|
||||
|
||||
@@ -58,7 +58,7 @@ class Registry:
|
||||
device = kernel_cls.get_device()
|
||||
|
||||
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
|
||||
if device != get_current_accelerator().type:
|
||||
if get_current_accelerator().type not in device:
|
||||
return
|
||||
|
||||
if not kernel_id:
|
||||
|
||||
@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
|
||||
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
|
||||
# in shape : e.g., [s/p:h:]
|
||||
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
|
||||
|
||||
# scatter 2, gather 1
|
||||
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
|
||||
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
|
||||
@@ -123,10 +122,15 @@ class UlyssesAttention(torch.nn.Module):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** -0.5
|
||||
|
||||
if attention_mask is None:
|
||||
if position_ids is not None:
|
||||
attention_mask = torch.ones_like(position_ids).to(torch.int64)
|
||||
global_position_ids = [
|
||||
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_position_ids, position_ids, group=self.spg)
|
||||
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
|
||||
attention_mask = None
|
||||
else:
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(torch.int64)
|
||||
|
||||
@@ -12,23 +12,272 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from math import ceil
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils.data import default_collate
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import pad_and_truncate
|
||||
from ...utils.objects import StatefulBuffer
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import BatchInfo, BatchInput, DataLoader
|
||||
from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
|
||||
|
||||
|
||||
class BatchingPlugin(BasePlugin):
    """Dispatch batching operations to the callables registered for a strategy.

    Every method looks up the callable stored under its own name (``self[<name>]``)
    and forwards its arguments unchanged. Strategies provide these callables with
    ``BatchingPlugin("<strategy>").register("<name>")``.
    """

    def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
        """Return the raw data provider batch size for this batching strategy."""
        return self["get_data_provider_batch_size"](batch_info)

    def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
        """Compute the length of the batch generator.

        The approximate length is used to calculate the lr schedule.
        """
        return self["compute_length"](data_provider, batch_info)

    def fill_buffer(
        self,
        buffer: StatefulBuffer,
        batch_info: BatchInfo,
        next_samples: Callable[[bool], list[ModelInput] | None],
    ) -> None:
        """Fill the buffer with data.

        ``next_samples`` is called by the strategy to obtain more samples; it returns
        ``None`` when no more samples are available.
        """
        return self["fill_buffer"](buffer, batch_info, next_samples)

    def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
        """Generate a batch from the buffer, or ``None`` if the strategy cannot form one yet."""
        return self["generate_batch"](buffer, batch_info)
|
||||
|
||||
|
||||
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
|
||||
"""Return sample counts for micro batches formed by one padded-token budget."""
|
||||
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
sizes = []
|
||||
index = 0
|
||||
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
|
||||
max_sample_len = 0
|
||||
used = 0
|
||||
is_complete = False
|
||||
while index + used < len(samples):
|
||||
sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
|
||||
padded_tokens = max(max_sample_len, sample_len) * (used + 1)
|
||||
if used > 0 and padded_tokens > budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
max_sample_len = max(max_sample_len, sample_len)
|
||||
used += 1
|
||||
if max_sample_len * used >= budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
if used == 0 or not is_complete:
|
||||
break
|
||||
|
||||
sizes.append(used)
|
||||
index += used
|
||||
|
||||
return sizes
|
||||
|
||||
|
||||
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
|
||||
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
sizes = []
|
||||
index = 0
|
||||
|
||||
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
|
||||
current_tokens = 0
|
||||
used = 0
|
||||
is_complete = False
|
||||
|
||||
while index + used < len(samples):
|
||||
sample = samples[index + used]
|
||||
sample_len = min(len(sample["input_ids"]), cutoff_len)
|
||||
|
||||
if current_tokens + sample_len > budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
current_tokens += sample_len
|
||||
used += 1
|
||||
|
||||
if used <= 0 or not is_complete:
|
||||
break
|
||||
|
||||
sizes.append(used)
|
||||
index += used
|
||||
|
||||
return sizes
|
||||
|
||||
|
||||
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
|
||||
"""Pack fixed samples into one padding-free sequence without a token budget."""
|
||||
packed: dict[str, list[Any]] = {}
|
||||
position_ids: list[int] = []
|
||||
|
||||
for sample_index, sample in enumerate(samples):
|
||||
# Padding-free still truncates each sample by cutoff_len before packing
|
||||
# all samples into one contiguous sequence.
|
||||
sample_len = min(len(sample["input_ids"]), cutoff_len)
|
||||
if sample_len <= 0:
|
||||
continue
|
||||
|
||||
for key, value in sample.items():
|
||||
if key in ("attention_mask", "position_ids") or isinstance(value, str):
|
||||
continue
|
||||
|
||||
if key not in packed:
|
||||
packed[key] = []
|
||||
|
||||
sliced_value = list(value[:sample_len])
|
||||
if sample_index > 0 and sliced_value:
|
||||
if key == "labels":
|
||||
sliced_value[0] = IGNORE_INDEX
|
||||
elif key == "loss_weights":
|
||||
sliced_value[0] = 0.0
|
||||
|
||||
packed[key].extend(sliced_value)
|
||||
|
||||
position_ids.extend(range(sample_len))
|
||||
|
||||
if not position_ids:
|
||||
return None
|
||||
|
||||
packed["position_ids"] = position_ids
|
||||
packed["attention_mask"] = [1] * len(position_ids)
|
||||
return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
    """Return ``micro_batch_size * num_micro_batch``, the samples needed for one full batch."""
    return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("compute_length")
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
    """Return ``len(data_provider)``; ``batch_info`` is accepted for interface parity and unused."""
    return len(data_provider)
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("fill_buffer")
def fill_padding_free_buffer(
    buffer: StatefulBuffer,
    batch_info: BatchInfo,
    next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
    """Pull samples until the buffer holds one full batch or the source runs dry.

    A full batch is ``micro_batch_size * num_micro_batch`` samples. ``next_samples``
    is called with ``False``.
    """
    # NOTE(review): the boolean presumably controls whether the data iterator may be
    # restarted when exhausted; confirm against BatchGenerator._next_samples.
    target_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
    while len(buffer) < target_size:
        samples = next_samples(False)
        if samples is None:
            break

        buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("generate_batch")
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Take one full batch from the buffer and pack each micro batch into one sequence.

    Returns:
        ``num_micro_batch`` packed micro batches, or ``None`` when the buffer holds
        fewer than ``micro_batch_size * num_micro_batch`` samples or when a micro
        batch contains no tokens at all.
    """
    micro_batch_size = batch_info["micro_batch_size"]
    num_micro_batch = batch_info["num_micro_batch"]
    cutoff_len = batch_info["cutoff_len"]
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        return None

    samples = buffer.get(batch_size)
    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
        if packed_micro_batch is None:
            # NOTE(review): the samples were already taken out of the buffer, so
            # returning here discards the whole batch.
            return None

        batch.append(packed_micro_batch)

    return batch
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
    """Return 1: dynamic batching reads one sample at a time; ``batch_info`` is unused."""
    return 1
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("compute_length")
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
    """Approximate the batch count as ``ceil(len(data_provider) / (micro_batch_size * num_micro_batch))``.

    The real count depends on sample lengths, so this is only an estimate.
    """
    batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
    return ceil(len(data_provider) / batch_size)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("fill_buffer")
def fill_dynamic_batching_buffer(
    buffer: StatefulBuffer,
    batch_info: BatchInfo,
    next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
    """Pull samples until the buffer yields ``num_micro_batch`` complete micro batches.

    Completeness is re-evaluated with ``_get_dynamic_micro_batch_sizes`` after every
    fetch. ``next_samples`` is called with ``True``; filling stops early when it
    returns ``None``.
    """
    # NOTE(review): the boolean presumably lets the data iterator restart when it is
    # exhausted; confirm against BatchGenerator._next_samples.
    while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
        samples = next_samples(True)
        if samples is None:
            break

        buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("generate_batch")
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Build ``num_micro_batch`` padded micro batches from the buffer.

    Returns:
        One collated micro batch per planned sample count, or ``None`` (leaving the
        buffer untouched) when fewer than ``num_micro_batch`` micro batches are
        complete.
    """
    micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
    if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
        return None

    batch = []
    cutoff_len = batch_info["cutoff_len"]
    for num_samples in micro_batch_sample_counts:
        samples = buffer.get(num_samples)
        batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))

    return batch
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
    """Return 1: samples are read one at a time; ``batch_info`` is unused."""
    return 1
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
    """Approximate the batch count as ``ceil(len(data_provider) / (micro_batch_size * num_micro_batch))``.

    The real count depends on sample lengths, so this is only an estimate.
    """
    batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
    return ceil(len(data_provider) / batch_size)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
    buffer: StatefulBuffer,
    batch_info: BatchInfo,
    next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
    """Pull samples until the buffer yields ``num_micro_batch`` complete micro batches.

    Completeness is re-evaluated with ``_get_dynamic_padding_free_micro_batch_sizes``
    after every fetch. ``next_samples`` is called with ``True``; filling stops early
    when it returns ``None``.
    """
    # NOTE(review): the boolean presumably lets the data iterator restart when it is
    # exhausted; confirm against BatchGenerator._next_samples.
    while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
        samples = next_samples(True)
        if samples is None:
            break

        buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Build ``num_micro_batch`` packed (padding-free) micro batches from the buffer.

    Returns:
        One packed micro batch per planned sample count, or ``None`` when fewer than
        ``num_micro_batch`` micro batches are complete or a micro batch contains no
        tokens.
    """
    micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
    if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
        return None

    batch = []
    cutoff_len = batch_info["cutoff_len"]

    for num_samples in micro_batch_sample_counts:
        samples = buffer.get(num_samples)
        packed_batch = _pack_padding_free_samples(samples, cutoff_len)
        if packed_batch is None:
            # NOTE(review): samples taken from the buffer so far are discarded here.
            return None

        batch.append(packed_batch)

    return batch
|
||||
|
||||
@@ -381,7 +381,7 @@ class FSDP2Engine:
|
||||
|
||||
with torch.no_grad():
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
|
||||
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
|
||||
grad_norm = grad_norm.full_tensor()
|
||||
|
||||
for param in model.parameters():
|
||||
|
||||
@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
|
||||
|
||||
@DistributedPlugin("deepspeed").register()
|
||||
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||
if dist_config.get("cp_size", 1) > 1:
|
||||
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
|
||||
|
||||
from .deepspeed import DeepSpeedEngine
|
||||
|
||||
return DeepSpeedEngine(
|
||||
@@ -78,14 +81,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Save a DeepSpeed checkpoint, forwarding extra keyword arguments to the backend helper."""
    # Imported lazily so the DeepSpeed backend is only loaded when it is used.
    from .deepspeed import save_checkpoint

    return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Load a DeepSpeed checkpoint, forwarding extra keyword arguments to the backend helper."""
    # Imported lazily so the DeepSpeed backend is only loaded when it is used.
    from .deepspeed import load_checkpoint

    return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import InputArgument, TrainingArguments, get_args
|
||||
from ..config.arg_utils import ModelClass
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
from ..utils import logging
|
||||
from ..utils.types import BatchInput, HFModel, Tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
|
||||
"""Validate RM dataset format early for clearer error messages."""
|
||||
if len(train_dataset) == 0:
|
||||
raise ValueError(f"RM training dataset is empty: {dataset_path}")
|
||||
|
||||
sample = train_dataset[0]
|
||||
if "chosen_messages" in sample and "rejected_messages" in sample:
|
||||
return
|
||||
|
||||
dataset_name = sample.get("_dataset_name", "unknown")
|
||||
sample_keys = sorted(sample.keys())
|
||||
raise ValueError(
|
||||
"RM training requires pair-format samples containing chosen/rejected responses. "
|
||||
f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
|
||||
"Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
|
||||
"or set converter='pair' for raw chosen/rejected fields)."
|
||||
)
|
||||
|
||||
|
||||
def _init_score_head(model: HFModel) -> None:
|
||||
"""Initialize the score head for RM training with small Gaussian weights.
|
||||
|
||||
Uses Gaussian initialization so that different parameters have distinct values,
|
||||
providing better gradient flow than zero initialization while keeping initial
|
||||
scores small enough that the starting loss is close to ln(2).
|
||||
"""
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
score = getattr(unwrapped, "score", None)
|
||||
if score is not None and hasattr(score, "weight"):
|
||||
hidden_size = score.weight.shape[-1]
|
||||
std = 1.0 / (hidden_size * 10)
|
||||
with torch.no_grad():
|
||||
score.weight.normal_(mean=0.0, std=std)
|
||||
if score.bias is not None:
|
||||
score.bias.zero_()
|
||||
logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}")
|
||||
|
||||
|
||||
class RMTrainer(BaseTrainer):
    """Trainer for pairwise reward modelling.

    Every row of a batch carries a chosen and a rejected response, marked by
    ``token_type_ids`` (1 = chosen, 2 = rejected, 0 = padding). The loss is
    ``-logsigmoid(chosen_score - rejected_score)`` averaged over the valid pairs.
    """

    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        renderer,
        train_dataset,
        callbacks=None,
    ) -> None:
        """Reject unsupported context-parallel settings, then defer to ``BaseTrainer``.

        Raises:
            NotImplementedError: If ``args.dist_config`` requests ``cp_size > 1``.
        """
        cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1
        if cp_size > 1:
            raise NotImplementedError("RM trainer currently only supports cp_size == 1.")

        super().__init__(args, model, renderer, train_dataset, callbacks)

    def _shard_model(self) -> None:
        """Wrap the model in DDP when no ``dist_config`` is given; otherwise use the base sharding."""
        if self.args.dist_config is None:
            if DistributedInterface().get_world_size(Dim.DP) > 1:
                from torch.nn.parallel import DistributedDataParallel as DDP

                device_ids = None if self.device.type == "cpu" else [self.device.index]
                self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)
        else:
            super()._shard_model()

    @property
    def _unwrapped_model(self):
        """Access the underlying model, unwrapping DDP/FSDP wrappers if present."""
        model = self.model
        if hasattr(model, "module"):
            model = model.module
        return model

    def compute_loss(self, batch: BatchInput) -> Tensor:
        """Compute the pairwise ranking loss for one micro batch.

        The score of a response is the model output at the last position of that
        response. Rows that lack chosen or rejected tokens are ignored.

        Raises:
            ValueError: If ``token_type_ids`` is missing or no row holds a full pair.
        """
        input_ids = batch["input_ids"].to(self.device, non_blocking=True)

        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is None:
            raise ValueError(
                "RM training requires pair data with token_type_ids. "
                "Ensure the dataset has chosen_messages/rejected_messages."
            )
        token_type_ids = token_type_ids.to(self.device, non_blocking=True)

        chosen_mask = token_type_ids == 1
        rejected_mask = token_type_ids == 2

        # Fail before the forward pass when there is nothing to train on.
        valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1)
        if not torch.any(valid_pair_mask):
            raise ValueError(
                "No valid RM pairs found in this micro-batch. "
                "This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
            )

        # Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding).
        # Transformers v5 models natively support this format in _update_causal_mask,
        # constructing the correct block-diagonal causal mask internally for all attention backends.
        model_attention_mask = token_type_ids

        # Build position_ids that reset at each document boundary.
        # NOTE(review): this assumes the chosen tokens start at index 0 and the rejected
        # tokens follow them directly; confirm against the pair collation.
        batch_size, seq_len = token_type_ids.shape
        arange = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
        chosen_lens = chosen_mask.sum(dim=1, keepdim=True)
        position_ids = torch.zeros_like(token_type_ids)
        position_ids[chosen_mask] = arange[chosen_mask]
        position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask]

        model_output = self.model(
            input_ids=input_ids,
            attention_mask=model_attention_mask,
            position_ids=position_ids,
            use_cache=False,
            return_dict=True,
        )

        rewards = model_output.logits.float().squeeze(-1)

        rewards = rewards[valid_pair_mask]
        chosen_mask = chosen_mask[valid_pair_mask]
        rejected_mask = rejected_mask[valid_pair_mask]

        # The largest index covered by a mask is the last token of that response.
        seq_len = rewards.size(-1)
        position_index = torch.arange(seq_len, device=self.device).unsqueeze(0)
        chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values
        rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values

        chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1)
        rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1)
        return -F.logsigmoid(chosen_scores - rejected_scores).mean()
|
||||
|
||||
|
||||
def run_rm(args: InputArgument = None):
    """Run reward-model training end to end.

    Parses the arguments, forces the classification model class, validates the
    dataset format, initializes the score head, trains and saves the model. The
    distributed interface is destroyed even when a step fails.
    """
    model_args, data_args, training_args, _ = get_args(args)
    model_args.model_class = ModelClass.CLS
    DistributedInterface(training_args.dist_config)
    try:
        train_dataset = DataEngine(data_args.train_dataset)
        _validate_rm_dataset_format(train_dataset, data_args.train_dataset)
        model_engine = ModelEngine(model_args, is_train=True)
        _init_score_head(model_engine.model)
        trainer = RMTrainer(
            args=training_args,
            model=model_engine.model,
            renderer=model_engine.renderer,
            train_dataset=train_dataset,
        )
        trainer.fit()
        trainer.save_model()
    finally:
        # Release the distributed state on failure as well as on success.
        DistributedInterface().destroy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_rm()
|
||||
|
||||
@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
|
||||
return
|
||||
|
||||
# Human-readable output to stdout
|
||||
display_logs = {**logs, "total_steps": state.num_training_steps}
|
||||
display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
|
||||
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
||||
logger.info_rank0(parts)
|
||||
|
||||
|
||||
@@ -13,21 +13,45 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.helper import is_torch_npu_available
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def enable_full_determinism(seed: int) -> None:
    """Enable full deterministic mode for reproducible distributed training.

    Seeds Python, NumPy and PyTorch (CPU, CUDA and, when available, NPU), asks
    PyTorch for deterministic algorithms (warning instead of failing when an op has
    none) and disables cuDNN along with its autotuner.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    if is_torch_npu_available():
        torch.npu.manual_seed(seed)
        torch.npu.manual_seed_all(seed)
|
||||
|
||||
|
||||
def set_seed(seed: int, full_determinism: bool = False) -> None:
    """Set seed for reproducibility.

    Args:
        seed: Random seed.
        full_determinism: Whether to enable full deterministic mode. When false, the
            seed is applied through ``transformers.set_seed`` only.
    """
    if full_determinism:
        enable_full_determinism(seed)
    else:
        hf_set_seed(seed)
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,10 @@ class StatefulBuffer:
|
||||
def size(self) -> int:
|
||||
return self._buffer_size
|
||||
|
||||
@property
|
||||
def samples(self) -> list[ModelInput]:
|
||||
return self._buffer
|
||||
|
||||
def put(self, samples: list[ModelInput]) -> None:
|
||||
"""Add samples to the buffer."""
|
||||
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from enum import StrEnum, unique
|
||||
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
@@ -54,6 +54,13 @@ else:
|
||||
ProcessGroup = None
|
||||
|
||||
|
||||
@unique
|
||||
class AttentionFunction(StrEnum):
|
||||
EAGER = "eager"
|
||||
SDPA = "sdpa"
|
||||
FLASH_ATTENTION_2 = "flash_attention_2"
|
||||
|
||||
|
||||
class DatasetInfo(TypedDict, total=False):
|
||||
path: str
|
||||
"""Local file path."""
|
||||
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
|
||||
"""Number of micro batches."""
|
||||
cutoff_len: int
|
||||
"""Cutoff length."""
|
||||
data_iter: Iterator[list[ModelInput]]
|
||||
"""Data iterator."""
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
|
||||
@@ -58,10 +58,3 @@ def test_multi_device():
|
||||
master_port = find_available_port()
|
||||
world_size = 2
|
||||
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python tests_v1/accelerator/test_interface.py
|
||||
"""
|
||||
test_all_device()
|
||||
|
||||
@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
|
||||
assert training_args.bf16 is False
|
||||
assert training_args.dist_config is None
|
||||
assert sample_args.sample_backend == "hf"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.config.test_args_parser
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_get_args_from_yaml(tmp_path=Path(tmp_dir))
|
||||
|
||||
@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
|
||||
for index in indexes:
|
||||
print(data_engine[index])
|
||||
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.test_data_engine
|
||||
"""
|
||||
test_map_dataset(1)
|
||||
|
||||
@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
|
||||
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||
|
||||
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.test_model_loader
|
||||
"""
|
||||
test_tiny_qwen()
|
||||
test_tiny_qwen_with_kernel_plugin()
|
||||
|
||||
@@ -16,6 +16,164 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import (
|
||||
BatchingPlugin,
|
||||
_get_dynamic_micro_batch_sizes,
|
||||
_get_dynamic_padding_free_micro_batch_sizes,
|
||||
)
|
||||
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
||||
from llamafactory.v1.utils.objects import StatefulBuffer
|
||||
|
||||
|
||||
def _make_model_input(length: int, start: int = 0):
|
||||
input_ids = list(range(start, start + length))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": [1] * length,
|
||||
"labels": input_ids.copy(),
|
||||
"loss_weights": [1.0] * length,
|
||||
}
|
||||
|
||||
|
||||
class _RestartableDataProvider:
|
||||
def __init__(self, batches):
|
||||
self.batches = batches
|
||||
self.num_iters = 0
|
||||
|
||||
def __iter__(self):
|
||||
self.num_iters += 1
|
||||
return iter(self.batches)
|
||||
|
||||
|
||||
def test_padding_free():
    """Padding-free batching truncates each sample and packs a micro batch into one sequence."""
    buffer = StatefulBuffer()
    # Input samples:
    # sample 0 input_ids: [0, 1]
    # sample 1 input_ids: [10, 11, 12, 13]
    buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}

    batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)

    # Output batch:
    # sample 1 is truncated to [10, 11, 12]
    # both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (1, 5)
    assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
    assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
    assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
    assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
    assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
    assert len(buffer) == 0
|
||||
|
||||
|
||||
def test_batching_plugin_data_provider_batch_sizes():
    """Each strategy reports its raw data provider batch size."""
    batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 3,
        "cutoff_len": 10,
    }

    # padding_free reads a whole batch at once (2 * 3); the dynamic strategies read one sample.
    assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
    assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
    assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching():
    """Dynamic batching groups samples under a padded-token budget."""
    # Input samples:
    # sample lengths: [3, 4, 6, 2, 8, 9]
    # input_ids:
    # [0, 1, 2]
    # [10, 11, 12, 13]
    # [20, 21, 22, 23, 24, 25]
    # [30, 31]
    # [40, 41, 42, 43, 44, 45, 46, 47]
    # [50, 51, 52, 53, 54, 55, 56, 57, 58]
    samples = [
        _make_model_input(3, 0),
        _make_model_input(4, 10),
        _make_model_input(6, 20),
        _make_model_input(2, 30),
        _make_model_input(8, 40),
        _make_model_input(9, 50),
    ]
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    # Dynamic batching output plan:
    # dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
    # as the padded-token budget for one training micro batch.
    # [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
    assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]

    buffer = StatefulBuffer()
    buffer.put(samples)
    batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)

    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
    assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
    """No batch is produced while the only micro batch is still below its budget."""
    buffer = StatefulBuffer()
    # Input buffer:
    # only one sample with length [6].
    # cutoff_len * micro_batch_size gives a padded-token budget of 20.
    # this buffer has not filled the budget and has no next sample to prove overflow,
    # so dynamic batching cannot produce a batch yet.
    buffer.put([_make_model_input(6, 0)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
    assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
    # Batch generation does not read from the data iterator. It only returns None and keeps
    # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
    assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
    """Filling restarts the data iterator until one dynamic micro batch is complete."""
    # Input data provider:
    # each iterator pass yields one sample with length [6].
    # each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
    # _fill_buffer keeps restarting the iterator until the next appended sample
    # proves that the previous dynamic micro batch has reached its budget boundary.
    samples = [_make_model_input(6, 0)]
    data_provider = _RestartableDataProvider([[sample] for sample in samples])

    # Bypass __init__ and set only the attributes needed by _fill_buffer / _generate_batch.
    batch_generator = BatchGenerator.__new__(BatchGenerator)
    batch_generator.batching_strategy = "dynamic_batching"
    batch_generator.micro_batch_size = 2
    batch_generator.num_micro_batch = 1
    batch_generator._buffer = StatefulBuffer()
    batch_generator._data_provider = data_provider
    batch_generator._data_iter = iter(data_provider)
    batch_generator._batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 1,
        "cutoff_len": 10,
    }

    batch_generator._fill_buffer()

    # Filled buffer after restart:
    # existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
    # because adding it to the first dynamic micro batch would exceed the budget.
    assert data_provider.num_iters == 4
    assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]

    batch = batch_generator._generate_batch()

    # Output batch:
    # dynamic batching returns [micro_batch_0]
    # micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    assert len(batch_generator._buffer) == 1
|
||||
|
||||
|
||||
def test_normal_batching():
|
||||
@@ -45,8 +203,166 @@ def test_normal_batching():
|
||||
assert batch[0]["input_ids"].shape == (4, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test_dynamic_padding_free():
|
||||
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
|
||||
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
|
||||
# input_ids breakdown:
|
||||
# sample 0: [0,1,2] (length=3)
|
||||
# sample 1: [10,11,12,13] (length=4)
|
||||
# sample 2: [20,21,22,23,24,25] (length=6)
|
||||
# sample 3: [30,31] (length=2)
|
||||
# sample 4: [40-47] (length=8)
|
||||
# sample 5: [50-58] (length=9)
|
||||
samples = [
|
||||
_make_model_input(3, 0),
|
||||
_make_model_input(4, 10),
|
||||
_make_model_input(6, 20),
|
||||
_make_model_input(2, 30),
|
||||
_make_model_input(8, 40),
|
||||
_make_model_input(9, 50),
|
||||
]
|
||||
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
|
||||
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put(samples)
|
||||
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
|
||||
|
||||
assert batch is not None
|
||||
assert len(batch) == 1 # num_micro_batch=1
|
||||
packed_batch = batch[0]
|
||||
|
||||
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
|
||||
assert packed_batch["input_ids"].shape == (1, 15)
|
||||
|
||||
# Verify input_ids concatenation (token ids are kept as-is; only labels get IGNORE_INDEX at sample boundaries)
|
||||
assert packed_batch["input_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
30,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
|
||||
assert packed_batch["labels"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
IGNORE_INDEX,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
IGNORE_INDEX,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
IGNORE_INDEX,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify attention_mask
|
||||
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
|
||||
|
||||
# Verify position_ids
|
||||
assert packed_batch["position_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3, # Sample 1
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5, # Sample 2
|
||||
0,
|
||||
1,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
|
||||
assert len(buffer) == 2
|
||||
|
||||
|
||||
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put([_make_model_input(6, 0)])
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
|
||||
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
|
||||
# Batch generation does not read from the data iterator. It only returns None and keeps
|
||||
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
|
||||
assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
|
||||
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
|
||||
|
||||
Data provider yields one sample of length 6 per iteration.
|
||||
_fill_buffer keeps restarting the iterator until the next sample would exceed the budget.
|
||||
Budget = 2 * 10 = 20 tokens.
|
||||
3 samples (6*3=18) fit; adding a 4th sample (6*4=24) would exceed the budget.
|
||||
So the buffer holds 4 samples after _fill_buffer: 3 for the current micro batch plus 1 left over for the next.
|
||||
"""
|
||||
python -m tests_v1.core.utils.test_batching
|
||||
"""
|
||||
test_normal_batching()
|
||||
samples = [_make_model_input(6, 0)]
|
||||
data_provider = _RestartableDataProvider([[sample] for sample in samples])
|
||||
|
||||
batch_generator = BatchGenerator.__new__(BatchGenerator)
|
||||
batch_generator.batching_strategy = "dynamic_padding_free"
|
||||
batch_generator.micro_batch_size = 2
|
||||
batch_generator.num_micro_batch = 1
|
||||
batch_generator._buffer = StatefulBuffer()
|
||||
batch_generator._data_provider = data_provider
|
||||
batch_generator._data_iter = iter(data_provider)
|
||||
batch_generator._batch_info = {
|
||||
"micro_batch_size": 2,
|
||||
"num_micro_batch": 1,
|
||||
"cutoff_len": 10,
|
||||
}
|
||||
|
||||
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
|
||||
batch_generator._fill_buffer()
|
||||
|
||||
# Buffer after restarts:
|
||||
# 3 samples can fit (18 tokens)
|
||||
# 4th sample is kept in buffer for next batch
|
||||
# => num_iters = 4
|
||||
assert data_provider.num_iters == 4
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(
|
||||
batch_generator._buffer.samples, batch_generator._batch_info
|
||||
) == [3]
|
||||
|
||||
batch = batch_generator._generate_batch()
|
||||
|
||||
# Output batch:
|
||||
# dynamic_padding_free returns [micro_batch_0]
|
||||
# 3 samples packed into shape [1, 18]
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (1, 18)
|
||||
assert len(batch_generator._buffer) == 1
|
||||
|
||||
@@ -227,17 +227,3 @@ def test_process_dpo_samples():
|
||||
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
|
||||
assert model_inputs[0]["extra_info"] == "test"
|
||||
assert model_inputs[0]["_dataset_name"] == "default"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.utils.test_rendering
|
||||
"""
|
||||
test_chatml_rendering()
|
||||
test_chatml_parse()
|
||||
test_chatml_rendering_remote(16)
|
||||
test_qwen3_nothink_rendering()
|
||||
test_qwen3_nothink_parse()
|
||||
test_qwen3_nothink_rendering_remote(16)
|
||||
test_process_sft_samples()
|
||||
test_process_dpo_samples()
|
||||
|
||||
@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
|
||||
],
|
||||
}
|
||||
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.plugins.data_plugins.test_converter
|
||||
"""
|
||||
test_alpaca_converter(1)
|
||||
test_sharegpt_converter()
|
||||
test_pair_converter(1)
|
||||
|
||||
@@ -52,12 +52,3 @@ def test_init_on_default():
|
||||
)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
assert model_engine.model.device == DistributedInterface().current_device
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python tests_v1/plugins/model_plugins/test_init_plugin.py
|
||||
"""
|
||||
test_init_on_meta()
|
||||
test_init_on_rank0()
|
||||
test_init_on_default()
|
||||
|
||||
@@ -35,10 +35,3 @@ def test_sync_sampler():
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "This is a test."}],
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python tests_v1/sampler/test_cli_sampler.py
|
||||
"""
|
||||
test_sync_sampler()
|
||||
|
||||
Reference in New Issue
Block a user