11 Commits

Author SHA1 Message Date
Yaowei Zheng
7af909522a [version] release v0.9.5 (#10532) 2026-05-30 23:57:09 +08:00
xvxuopop
e016d2480e [fix] Fix NPU FusedMoE and RMSNorm (#10512) 2026-05-30 21:42:54 +08:00
jiaqiw09
7d719182c9 [model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529) 2026-05-30 21:41:41 +08:00
jiaqiw09
01398eb18d [v1] fix padding free with sp (#10513) 2026-05-26 23:49:21 +08:00
cxy
8e68764b65 [v1] Implement dynamic padding-free stretrgy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
2026-05-25 20:40:21 +08:00
Copilot
16ff5a23cb [fix] use getattr for profiler attrs to support MCA TrainingArguments (#10506)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2026-05-21 17:26:29 +08:00
jiaqiw09
bdcb92d035 [v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469) 2026-05-21 17:14:19 +08:00
sunyi0505
7e20db5735 [v1] support liger_kernel (#10493) 2026-05-21 11:44:56 +08:00
浮梦
2322bf1cc2 [v1] add cuda fused moe kernel, implementing with triton (#10481) 2026-05-20 20:49:42 +08:00
浮梦
368c48968f [callback] add torch profiler callback (#10463) 2026-05-20 20:47:52 +08:00
浮梦
8b5ea65770 [v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-05-20 20:46:52 +08:00
53 changed files with 2518 additions and 216 deletions

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
\[ English | [中文](README_zh.md) \]
@@ -52,14 +50,11 @@ Start local training:
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Official Blog**: https://blog.llamafactory.net/en/
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -78,7 +73,6 @@ Read technical notes:
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [LLaMA Factory Online](#llama-factory-online)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
@@ -117,15 +111,11 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -661,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
llamafactory-cli webui
```
### LLaMA Factory Online
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
开始云端训练:
- **Colab免费**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW免费试用**https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online在线微调**https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **九章智算云(算力优惠活动)**https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
阅读技术文档:
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
- **框架文档(昇腾 NPU**https://ascend.github.io/docs/sources/llamafactory/
- **官方博客**https://blog.llamafactory.net/
- **官方课程**https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [数据准备](#数据准备)
- [快速开始](#快速开始)
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
- [构建 Docker](#构建-docker)
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
- [从魔搭社区下载](#从魔搭社区下载)
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
<details><summary>全部博客</summary>
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
- [LLaMA Factory微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
@@ -662,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
llamafactory-cli webui
```
### LLaMA Factory Online 在线微调
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
### 构建 Docker
CUDA 用户:

View File

@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: liger_kernel
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.5.dev0"
VERSION = "0.9.5"
def print_env() -> None:

View File

@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [
ModelArguments,
DataArguments,
TrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_MCA_ARGS = [
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
else:
_TRAIN_MCA_ARGS = []

View File

@@ -14,6 +14,7 @@
import json
from dataclasses import dataclass, field
from typing import Optional
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -63,6 +64,58 @@ class RayArguments:
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
@dataclass
class ProfilerArguments:
r"""Arguments for torch profiler configuration."""
enable_torch_profiler: bool = field(
default=False,
metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
)
profiler_output_dir: Optional[str] = field(
default=None,
metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
)
profiler_wait_steps: int = field(
default=1,
metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
)
profiler_warmup_steps: int = field(
default=1,
metadata={"help": "Number of profiler warm-up steps per cycle."},
)
profiler_active_steps: int = field(
default=1,
metadata={"help": "Number of steps to actively record per cycle."},
)
profiler_repeat: int = field(
default=1,
metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
)
profiler_record_shapes: bool = field(
default=True,
metadata={"help": "Whether to record tensor shapes during profiling."},
)
profiler_profile_memory: bool = field(
default=True,
metadata={"help": "Whether to profile memory usage."},
)
profiler_with_stack: bool = field(
default=True,
metadata={"help": "Whether to record stack traces during profiling."},
)
profile_modules: Optional[str] = field(
default=None,
metadata={
"help": (
"Comma-separated list of module name patterns to profile with CUDA events. "
"Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
"Reports per-module forward/backward timing statistics at each logging step."
)
},
)
@dataclass
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
@@ -87,7 +140,7 @@ class Fp8Arguments:
@dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(

View File

@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.

View File

@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import fnmatch
import json
import os
import signal
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
@@ -31,7 +33,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
from ..extras.packages import is_safetensors_available
@@ -338,6 +340,96 @@ class LogCallback(TrainerCallback):
self.thread_pool.submit(self._write_log, args.output_dir, logs)
class TorchProfilerCallback(TrainerCallback):
r"""A callback for collecting torch.profiler traces during training.
Activated by setting ``enable_torch_profiler: true`` in the YAML config.
Configuration fields (in YAML):
profiler_output_dir where to write traces (default: <output_dir>/profiler)
profiler_wait_steps steps to skip at start of each cycle (default: 1)
profiler_warmup_steps profiler warm-up steps per cycle (default: 1)
profiler_active_steps steps to record per cycle (default: 1)
profiler_repeat number of cycles; 0 = forever (default: 1)
profiler_record_shapes record tensor shapes (default: true)
profiler_profile_memory profile memory usage (default: true)
profiler_with_stack record stack traces (default: true)
Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
``<profiler_output_dir>/rank_<N>/``.
"""
def __init__(self, training_args: "TrainingArguments") -> None:
self.profiler = None
self.profiler_args = training_args
@staticmethod
def _get_rank() -> int:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
return 0
@override
def on_train_begin(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.stop()
self.profiler = None
pa = self.profiler_args
output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
rank = self._get_rank()
trace_dir = os.path.join(output_dir, f"rank_{rank}")
os.makedirs(trace_dir, exist_ok=True)
activities = [torch.profiler.ProfilerActivity.CPU]
try:
if is_torch_cuda_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
if is_torch_npu_available():
activities.append(torch.profiler.ProfilerActivity.NPU)
except Exception:
pass
self.profiler = torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
wait=pa.profiler_wait_steps,
warmup=pa.profiler_warmup_steps,
active=pa.profiler_active_steps,
repeat=pa.profiler_repeat,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
record_shapes=pa.profiler_record_shapes,
profile_memory=pa.profiler_profile_memory,
with_stack=pa.profiler_with_stack,
)
self.profiler.start()
logger.info_rank0(
f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
)
@override
def on_step_end(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.step()
@override
def on_train_end(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
) -> None:
if self.profiler is not None:
self.profiler.stop()
self.profiler = None
logger.info_rank0("TorchProfiler stopped.")
class ReporterCallback(TrainerCallback):
r"""A callback for reporting training status to external logger."""
@@ -394,3 +486,143 @@ class ReporterCallback(TrainerCallback):
"generating_args": self.generating_args.to_dict(),
}
)
class ModuleProfilerCallback(TrainerCallback):
r"""Profile forward/backward time of specified modules using accelerator events.
Hooks are registered on modules matching the user-provided name patterns.
Timing statistics are logged at each trainer logging step.
Usage in YAML config:
profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
Supports fnmatch wildcards:
profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
"""
@staticmethod
def _get_accelerator():
"""Detect available accelerator and return (event_factory, synchronize_fn)."""
if is_torch_cuda_available():
return torch.cuda.Event, torch.cuda.synchronize
if is_torch_npu_available():
return torch.npu.Event, torch.npu.synchronize
return None, None
def __init__(self, profile_modules: str) -> None:
self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
self._create_event, self._synchronize = self._get_accelerator()
self._handles: list[Any] = []
self._forward_times: dict[str, list[float]] = defaultdict(list)
self._backward_times: dict[str, list[float]] = defaultdict(list)
self._pending_forward: dict[str, tuple] = {}
self._pending_backward: dict[str, tuple] = {}
@property
def enabled(self) -> bool:
return self._create_event is not None
def _match(self, name: str) -> bool:
return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
def _make_forward_pre_hook(self, name: str):
def hook(module, input):
start = self._create_event(enable_timing=True)
end = self._create_event(enable_timing=True)
start.record()
self._pending_forward[name] = (start, end)
return hook
def _make_forward_hook(self, name: str):
def hook(module, input, output):
pair = self._pending_forward.get(name)
if pair is not None:
pair[1].record()
return hook
def _make_backward_pre_hook(self, name: str):
def hook(module, grad_output):
start = self._create_event(enable_timing=True)
end = self._create_event(enable_timing=True)
start.record()
self._pending_backward[name] = (start, end)
return hook
def _make_backward_hook(self, name: str):
def hook(module, grad_input, grad_output):
pair = self._pending_backward.get(name)
if pair is not None:
pair[1].record()
return hook
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self.enabled:
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
return
model = kwargs.get("model")
if model is None:
return
matched = []
for name, module in model.named_modules():
if not name or not self._match(name):
continue
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
matched.append(name)
if matched:
logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
)
else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self.enabled:
return
self._synchronize()
for name, (start, end) in self._pending_forward.items():
self._forward_times[name].append(start.elapsed_time(end))
self._pending_forward.clear()
for name, (start, end) in self._pending_backward.items():
self._backward_times[name].append(start.elapsed_time(end))
self._pending_backward.clear()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not self._forward_times and not self._backward_times:
return
lines = ["[ModuleProfiler] Timing (ms):"]
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
for name in all_names:
fwd = self._forward_times.get(name, [])
bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines))
self._forward_times.clear()
self._backward_times.clear()
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
for handle in self._handles:
handle.remove()
self._handles.clear()

View File

@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
self.running = RunningMoments(self.accelerator)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
return super().create_optimizer(*args, **kwargs)
@override
def create_scheduler(

View File

@@ -32,7 +32,13 @@ from ..extras.packages import (
)
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .callbacks import (
LogCallback,
ModuleProfilerCallback,
PissaConvertCallback,
ReporterCallback,
TorchProfilerCallback,
)
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
if getattr(training_args, "enable_torch_profiler", False):
callbacks.append(TorchProfilerCallback(training_args))
if getattr(training_args, "profile_modules", None):
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [
"AttentionFunction",
"BatchingStrategy",
"DataArguments",
"InputArgument",

View File

@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args:
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
return tuple(parsed_args)
return model_args, data_args, training_args, sample_args
if __name__ == "__main__":

View File

@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.SDPA,
metadata={
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
)
def __post_init__(self) -> None:
supported_flash_attn = [item.value for item in AttentionFunction]
if self.flash_attn not in supported_flash_attn:
raise ValueError(
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
)
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
@@ -116,3 +120,9 @@ class TrainingArguments:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
if self.max_steps is None or self.max_steps <= 0:
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
if self.save_epochs is not None:
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")

View File

@@ -34,7 +34,7 @@ import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..config import BatchingStrategy, TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
@@ -134,6 +134,9 @@ class BaseTrainer:
global_step=self.global_step,
epoch=self._resume_epoch,
)
# Keep callback state aligned with checkpoint-resumed trainer counters.
self.state.global_step = self.global_step
self.state.epoch = self._resume_epoch
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
@@ -144,13 +147,19 @@ class BaseTrainer:
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
raise ValueError(
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
)
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None:
if (
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
):
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
@@ -234,6 +243,7 @@ class BaseTrainer:
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
for micro_batches in self.train_batch_generator:
self.global_step += 1
@@ -303,7 +313,7 @@ class BaseTrainer:
if self.global_step % self.args.logging_steps == 0:
logs = {
"epoch": epoch,
"step": self.global_step,
"step": self.state.global_step,
"loss": step_loss,
"grad_norm": grad_norm,
"learning_rate": current_lr,
@@ -335,7 +345,9 @@ class BaseTrainer:
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
model_to_save.save_pretrained(
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
)
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -120,6 +120,7 @@ class ModelEngine:
init_device = DistributedInterface().current_device
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
@@ -143,6 +144,12 @@ class ModelEngine:
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
self.model_config.num_labels = 1
self.model_config.classifier_dropout = 0.0
text_config = getattr(self.model_config, "text_config", None)
if text_config is not None:
text_config.num_labels = 1
text_config.classifier_dropout = 0.0
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
@@ -158,6 +165,7 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
attn_implementation=self.args.flash_attn,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
@@ -182,9 +190,12 @@ class ModelEngine:
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model, include_kernels=self.args.kernel_config.get("include_kernels")
)
kernel_config = self.args.kernel_config
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
if kernel_config.name == "liger_kernel":
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
kernel_kwargs["require_logits"] = self.is_train
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
return model

View File

@@ -42,6 +42,8 @@ from .rendering import Renderer
logger = logging.get_logger(__name__)
__all__ = ["BatchGenerator"]
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
}
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
@@ -137,27 +138,33 @@ class BatchGenerator(Iterator):
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
generato_seed = torch.Generator()
generato_seed.manual_seed(self.seed)
if self.batching_strategy == BatchingStrategy.NORMAL:
batch_size = self.micro_batch_size * self.num_micro_batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
batch_size=batch_size,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
generator=generato_seed,
generator=generator_seed,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
def __len__(self) -> int:
return self._length
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
try:
return next(self._data_iter)
except StopIteration:
if not restart:
return None
# Dynamic batching may restart the provider to fill one token-budgeted batch.
self._data_iter = iter(self._data_provider)
try:
return next(self._data_iter)
except StopIteration:
return None
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer.state_dict(),

View File

@@ -172,7 +172,7 @@ def _save_standard_training_states(
if rank == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_dir = os.path.join(ckpt_dir, "model")
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
processor.save_pretrained(model_dir)
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
@@ -212,7 +212,11 @@ def _load_standard_training_states(
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
if state_dict:
model_to_load.load_state_dict(state_dict)
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
if incompatible_keys.missing_keys:
raise RuntimeError(
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
)
else:
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")

View File

@@ -148,7 +148,9 @@ def launch():
elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.")
from llamafactory.v1.trainers.rm_trainer import run_rm
run_rm()
else:
print(f"Unknown command: {command}.\n{USAGE}")
@@ -175,9 +177,9 @@ def main():
# run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.")
from llamafactory.v1.trainers.rm_trainer import run_rm
run_rm()
if __name__ == "__main__":

View File

@@ -34,7 +34,7 @@ class BaseKernel(ABC):
"""
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
_device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
@classmethod
def get_kernel_id(cls) -> str:
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
def get_device(cls) -> list[DeviceType]:
"""Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
return cls._device
@classmethod
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks.
"""
if cls._device != get_current_accelerator().type:
if get_current_accelerator().type not in cls._device:
return False
return True

View File

@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
apply_kernel(kernel, model=model)
return model
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
model: HFModel,
include_kernels: str = None,
require_logits: bool = False,
) -> HFModel:
"""Applies Liger kernel to the model.
Args:
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with
library defaults. If a comma-separated list (e.g.
``rope,rms_norm``), enable only those ops; names match
``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``,
``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``.
If ``None`` or ``False``, do nothing. Defaults to ``None``.
require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor
of non-fused CE so the forward pass returns ``logits``. Needed
for trainers that compute weighted loss from logits (e.g. v1
SFT with ``loss_weights``). Defaults to ``False`` (fused CE
when supported). The v1 ``run_sft`` entrypoint sets
``require_logits`` to true for ``liger_kernel`` when the key
is omitted so SFT weighted loss keeps working.
Returns:
HFModel: The model with Liger kernel applied.
"""
if not include_kernels:
return model
if include_kernels == "auto" or include_kernels is True:
use_kernels = "auto"
else:
use_kernels = [k.strip() for k in include_kernels.split(",") if k.strip()]
if not use_kernels:
return model
try:
from .liger_kernel_ops import LigerKernel
except ImportError as e:
logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
return model
return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)

View File

@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of Liger Kernel.
Init Phase:
1. Define LigerKernel class.
2. Register Liger kernel.
"""
import inspect
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.logging import get_logger
from ....utils.types import HFModel
from .base import BaseKernel
logger = get_logger(__name__)
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
"qwen3": "apply_liger_kernel_to_qwen3",
"qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
"qwen3_next": "apply_liger_kernel_to_qwen3_next",
"qwen3_5": "apply_liger_kernel_to_qwen3_5",
"qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
"qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
"qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
}
class LigerKernel(BaseKernel):
"""Liger Kernel for optimized model training."""
_device = [DeviceType.CUDA, DeviceType.NPU]
@classmethod
def check_deps(cls) -> bool:
"""Checks if the required dependencies for the kernel are available."""
try:
import liger_kernel # noqa: F401
return super().check_deps()
except ImportError:
logger.warning_rank0(
"Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
)
return False
@classmethod
def apply(cls, **kwargs) -> "HFModel":
"""Applies the Liger kernel to the model.
Args:
**kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
names to enable exclusively, or the string ``"auto"`` to use each
``apply_liger_kernel_to_*`` function's signature defaults (same as calling
upstream with only ``model``). Optional ``require_logits`` forces non-fused
cross entropy when supported.
Returns:
HFModel: The model with Liger kernel applied.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model")
use_kernels = kwargs.get("use_kernels", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(
f"current device is not supported by liger_kernel. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}"
)
require_logits = kwargs.get("require_logits", False)
model_type = getattr(model.config, "model_type", None)
if model_type not in _LIGER_FN_BY_MODEL_TYPE:
logger.warning_rank0("Current model does not support liger kernel.")
return model
import liger_kernel.transformers as liger_transformers
apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])
sig = inspect.signature(apply_liger_kernel).parameters
togglable = [name for name in sig if name != "model"]
def _normalize_op_name(raw: str) -> str:
key = raw.strip().lower().replace("-", "_")
aliases = {
"rmsnorm": "rms_norm",
"flce": "fused_linear_cross_entropy",
"lce": "fused_linear_cross_entropy",
"fused_ce": "fused_linear_cross_entropy",
}
return aliases.get(key, key)
if use_kernels is not None and len(use_kernels) == 0:
return model
if use_kernels != "auto":
selected = {_normalize_op_name(k) for k in use_kernels}
ops = selected - set(togglable)
if ops:
raise ValueError(
f"Unknown Liger op(s) {sorted(ops)} for model_type={model_type}. Valid: {sorted(togglable)}"
)
if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")
call_kwargs = {name: (name in selected) for name in togglable}
call_kwargs["model"] = model
else:
# Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
# and logging reflects enabled ops (omitted kwargs only live in the callee).
call_kwargs = {"model": model}
for name in togglable:
param = sig[name]
if param.default is not inspect.Parameter.empty:
call_kwargs[name] = param.default
if require_logits and "fused_linear_cross_entropy" in sig:
logger.warning_rank0("Current training stage does not support chunked cross entropy.")
call_kwargs["fused_linear_cross_entropy"] = False
call_kwargs["cross_entropy"] = True
apply_liger_kernel(**call_kwargs)
applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
logger.info_rank0(f"These Liger ops are applied to the model: {applied}")
return model

View File

@@ -0,0 +1,429 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pure-Triton Fused MoE Kernel for NVIDIA GPUs.
Replaces the HuggingFace per-expert Python loop with a fully fused Triton pipeline:
- Forward: scatter → grouped GEMM fc1 → SiLU·gate → apply routing → grouped GEMM fc2 → gather
- Backward: all dX via grouped GEMM, all dW via grouped GEMM (no Python loops)
Supported models: Mixtral, Qwen3-MoE, Qwen3.5-MoE.
"""
import logging
import types
import torch
import torch.nn.functional as F
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
from .triton_grouped_gemm import (
group_gemm_same_mn,
group_gemm_same_nk,
moe_gather,
moe_scatter,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Autograd Function: Full Triton MoE forward + backward
# ---------------------------------------------------------------------------
class TritonFusedMoeFunction(torch.autograd.Function):
"""Fused MoE expert computation using Triton grouped GEMMs.
Forward: scatter → fc1 (group GEMM) → SiLU·gate → weight → fc2 (group GEMM) → gather
Backward: all gradients computed via grouped GEMMs in single kernel launches.
"""
@staticmethod
def forward(
ctx,
num_experts,
gate_weights,
expert_index,
hidden_states,
fc1_weight,
fc2_weight,
):
"""Forward pass.
Args:
ctx: autograd context
num_experts: int
gate_weights: (num_tokens, top_k) routing weights
expert_index: (num_tokens, top_k) expert assignments
hidden_states: (num_tokens, hidden_dim)
fc1_weight: (E, 2*inter, hidden) merged gate+up weight
fc2_weight: (E, hidden, inter) down projection weight
"""
# Compute scatter index: maps (token, topk) → position in sorted buffer
scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)
# Token counts per expert and cumulative boundaries
splits = torch.zeros(num_experts, dtype=torch.int32, device=hidden_states.device)
flat_experts = expert_index.flatten().int()
splits.scatter_add_(0, flat_experts.long(), torch.ones_like(flat_experts))
cumsum_t = torch.cumsum(splits, dim=0)
# Scatter hidden states to sorted expert buffer
scatter_output = moe_scatter(hidden_states, scatter_index)
# FC1: grouped GEMM (scatter_output @ fc1_weight.T)
max_M = int(splits.max().item())
fc1_output = group_gemm_same_nk(
a=scatter_output,
b=fc1_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=True,
)
# SiLU gate activation
fc1_1_output, fc1_2_output = fc1_output.chunk(2, dim=-1)
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
fc1_activation = fc1_1_activation * fc1_2_output
# Apply routing weights before fc2 (mathematically equivalent to after)
reshaped_gate_weight = gate_weights.reshape(-1, 1)
scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
scattered_gate_weight[scatter_index.flatten().long()] = reshaped_gate_weight
fc1_weighted_output = fc1_activation * scattered_gate_weight
# FC2: grouped GEMM (fc1_weighted @ fc2_weight.T)
fc2_output = group_gemm_same_nk(
a=fc1_weighted_output,
b=fc2_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=True,
)
# Gather back to original token positions (sum over topk)
expert_output = moe_gather(fc2_output, scatter_index)
ctx.num_experts = num_experts
ctx.save_for_backward(
gate_weights,
fc1_weight,
fc2_weight,
hidden_states,
scatter_index,
scatter_output,
cumsum_t,
fc1_1_output,
fc1_2_output,
fc1_activation,
scattered_gate_weight,
fc1_weighted_output,
)
return expert_output
@staticmethod
def backward(ctx, grad_output):
(
gate_weights,
fc1_weight,
fc2_weight,
hidden_states,
scatter_index,
scatter_output,
cumsum_t,
fc1_1_output,
fc1_2_output,
fc1_activation,
scattered_gate_weight,
fc1_weighted_output,
) = ctx.saved_tensors
num_experts = ctx.num_experts
hidden_dim = grad_output.shape[-1]
grad_output = grad_output.reshape(-1, hidden_dim).contiguous()
# Recompute max_M from cumsum
splits = torch.zeros(num_experts, dtype=cumsum_t.dtype, device=cumsum_t.device)
splits[0] = cumsum_t[0]
splits[1:] = cumsum_t[1:] - cumsum_t[:-1]
max_M = int(splits.max().item())
# Step 1: Scatter grad_output to expert buffer
grad_fc2_output = moe_scatter(grad_output, scatter_index)
# Step 2: FC2 backward
# dX for fc2: grad_fc2_output @ fc2_weight (transpose_b=False since fc2 is (E, hidden, inter))
grad_fc1_weighted_output = group_gemm_same_nk(
a=grad_fc2_output,
b=fc2_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=False,
)
# dW for fc2: grad_fc2_output.T @ fc1_weighted_output
grad_fc2_weight = None
if fc2_weight.requires_grad:
grad_fc2_weight = torch.empty_like(fc2_weight)
group_gemm_same_mn(
a=grad_fc2_output,
b=fc1_weighted_output,
c=grad_fc2_weight,
cumsum_K=cumsum_t,
)
# Step 3: Routing weight backward
grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
grad_scattered_gate_weight = torch.sum(fc1_activation * grad_fc1_weighted_output, dim=-1)
grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten().long()]
grad_gate_weight = grad_gate_weight.reshape(gate_weights.shape)
# Recompute silu activation for backward
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
# Step 4: SiLU gate backward
grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
grad_fc1_2_output = fc1_1_activation * grad_fc1_activation
# SiLU backward: d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)
# Merge fc1 gradients back to (total_M, 2*inter)
grad_fc1_output = torch.cat([grad_fc1_1_output, grad_fc1_2_output], dim=-1)
# Step 5: FC1 backward
# dX for fc1: grad_fc1_output @ fc1_weight (transpose_b=False)
grad_scatter_output = group_gemm_same_nk(
a=grad_fc1_output,
b=fc1_weight,
cumsum_M=cumsum_t,
max_M=max_M,
transpose_b=False,
)
# dW for fc1: grad_fc1_output.T @ scatter_output
grad_fc1_weight = None
if fc1_weight.requires_grad:
grad_fc1_weight = torch.empty_like(fc1_weight)
group_gemm_same_mn(
a=grad_fc1_output,
b=scatter_output,
c=grad_fc1_weight,
cumsum_K=cumsum_t,
)
# Step 6: Gather gradients back to original positions
grad_hidden_states = moe_gather(grad_scatter_output, scatter_index)
grad_hidden_states = grad_hidden_states.reshape(hidden_states.shape)
return (
None, # num_experts
grad_gate_weight, # gate_weights
None, # expert_index
grad_hidden_states, # hidden_states
grad_fc1_weight, # fc1_weight
grad_fc2_weight, # fc2_weight
)
# ---------------------------------------------------------------------------
# Patched forward functions
# ---------------------------------------------------------------------------
def _triton_moe_experts_forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""Replacement forward for v5+ MoE expert modules with stacked 3D weights."""
return TritonFusedMoeFunction.apply(
self.num_experts,
top_k_weights.to(hidden_states.dtype),
top_k_index,
hidden_states,
self.gate_up_proj,
self.down_proj,
)
# ---------------------------------------------------------------------------
# Legacy (transformers < 5.0) support: weight stacking + SparseMoeBlock patch
# ---------------------------------------------------------------------------
class _StackedExpertWeights(torch.nn.Module):
"""Lightweight container holding stacked 3D expert weight tensors."""
def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, num_experts: int):
super().__init__()
self.gate_up_proj = torch.nn.Parameter(gate_up_proj)
self.down_proj = torch.nn.Parameter(down_proj)
self.num_experts = num_experts
def _stack_expert_weights(module: torch.nn.Module) -> None:
"""Replace nn.ModuleList of individual experts with stacked 3D parameter tensors."""
experts = module.experts
num_experts = len(experts)
gate_up_list = []
for expert in experts:
gate_w = expert.gate_proj.weight.data # (inter, hidden)
up_w = expert.up_proj.weight.data # (inter, hidden)
gate_up_list.append(torch.cat([gate_w, up_w], dim=0)) # (2*inter, hidden)
gate_up_proj = torch.stack(gate_up_list, dim=0) # (E, 2*inter, hidden)
down_proj = torch.stack([e.down_proj.weight.data for e in experts], dim=0) # (E, hidden, inter)
module.experts = _StackedExpertWeights(gate_up_proj, down_proj, num_experts)
logger.info(
f"cuda_fused_moe: Stacked {num_experts} expert weights into "
f"gate_up_proj {tuple(gate_up_proj.shape)}, down_proj {tuple(down_proj.shape)}"
)
def _triton_moe_sparse_block_forward(self, hidden_states: torch.Tensor):
"""Replacement forward for legacy SparseMoeBlock with inline routing."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = TritonFusedMoeFunction.apply(
self.num_experts,
routing_weights,
selected_experts,
hidden_states,
self.experts.gate_up_proj,
self.experts.down_proj,
)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
# ---------------------------------------------------------------------------
# Module mapping
# ---------------------------------------------------------------------------
_TRITON_MOE_MAPPING: dict[str, dict[str, object]] = {
"MixtralForCausalLM": {
"MixtralExperts": _triton_moe_experts_forward,
},
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": _triton_moe_experts_forward,
"Qwen3MoeSparseMoeBlock": _triton_moe_sparse_block_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": _triton_moe_experts_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": _triton_moe_experts_forward,
},
}
# ---------------------------------------------------------------------------
# Kernel registration
# ---------------------------------------------------------------------------
@register_kernel
class CudaFusedMoEKernel(BaseKernel):
"""Pure-Triton fused MoE kernel for NVIDIA CUDA GPUs.
Replaces HuggingFace per-expert Python loops with a fully fused Triton pipeline:
- Forward: scatter + grouped GEMMs + gather (single kernel per GEMM)
- Backward: all dX and dW via grouped GEMMs (no Python loops)
Requires: CUDA GPU + Triton
"""
_kernel_id = "cuda_fused_moe"
_device = DeviceType.CUDA
@classmethod
def check_deps(cls) -> bool:
if not super().check_deps():
return False
try:
import triton # noqa: F401
return True
except ImportError:
logger.info("cuda_fused_moe: Triton not available, kernel disabled.")
return False
@classmethod
def apply(cls, **kwargs) -> HFModel:
model = kwargs.get("model")
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
logger.warning("cuda_fused_moe: Dependencies not met. Skipping kernel application.")
return model
archs = getattr(model.config, "architectures", None) or []
target_mapping = None
for arch in archs:
if arch in _TRITON_MOE_MAPPING:
target_mapping = _TRITON_MOE_MAPPING[arch]
break
if target_mapping is None:
logger.info(
f"cuda_fused_moe: Model architecture {archs} not supported. "
f"Supported: {list(_TRITON_MOE_MAPPING.keys())}"
)
return model
patched_count = 0
for module in model.modules():
class_name = module.__class__.__name__
if class_name not in target_mapping:
continue
target_fn = target_mapping[class_name]
if hasattr(module, "gate_up_proj") and hasattr(module, "down_proj"):
module.forward = types.MethodType(target_fn, module)
patched_count += 1
elif (
hasattr(module, "experts")
and isinstance(module.experts, torch.nn.ModuleList)
and hasattr(module, "gate")
):
_stack_expert_weights(module)
module.forward = types.MethodType(target_fn, module)
patched_count += 1
if patched_count > 0:
logger.info(f"cuda_fused_moe: Patched {patched_count} MoE expert modules with pure Triton pipeline.")
else:
logger.warning("cuda_fused_moe: No MoE expert modules found to patch.")
return model

View File

@@ -228,6 +228,30 @@ class NpuMoeFused:
routed_out = self.experts(hidden_states, routing_weights, router_indices)
return routed_out
@staticmethod
def npu_moe_experts_v5_forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
Transformers v5 stores expert weights in F.linear layout:
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
down_proj: [num_experts, hidden_dim, intermediate_dim]
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
"""
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
hidden_states, top_k_index.to(torch.int32)
)
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
gate_up_proj = self.gate_up_proj.transpose(1, 2)
down_proj = self.down_proj.transpose(1, 2)
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
class Qwen3NpuMoeFused:
"""Container for Qwen3 NPU fused MoE forward functions."""
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
# moe patch config mapping
kernel_moe_mapping = {
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
if is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
}
}
if not is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
else:
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
},
}

View File

@@ -0,0 +1,417 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Pure-Triton grouped GEMM and MoE scatter/gather kernels.
# Design adapted from VeOmni (ByteDance-Seed/VeOmni) group_gemm kernels.
"""Pure-Triton MoE kernels: grouped GEMM, scatter, and gather.
Provides four kernel types for fused MoE forward+backward without Python loops:
- group_gemm_same_nk: Variable-M grouped GEMM (forward & backward dX)
- group_gemm_same_mn: Variable-K grouped GEMM (backward dW)
- moe_scatter: Token dispatch to sorted expert buffers
- moe_gather: Token reduction from expert buffers
"""
import torch
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Triton helper: grouped tile indexing with L2 cache-friendly swizzle
# ---------------------------------------------------------------------------
@triton.jit
def _get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
return pid_m, pid_n
# ---------------------------------------------------------------------------
# group_gemm_same_nk: All experts share same N, K; variable M per expert
# Used for: forward (x @ W.T) and backward dX (grad @ W)
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
],
key=["N", "K"],
)
@triton.jit
def _group_gemm_same_nk_kernel(
a_ptr,
b_ptr,
c_ptr,
cumsum_M,
max_M,
G: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
TRANSPOSE_B: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP: tl.constexpr,
):
pid_m, pid_n = _get_pid_mn(tl.program_id(0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
gid = tl.program_id(1).to(tl.int64)
gtid_start = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0).to(tl.int64)
gtid_end = tl.load(cumsum_M + gid).to(tl.int64)
m_size = gtid_end - gtid_start
if pid_m * BLOCK_M >= m_size:
return
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# a is (total_M, K) row-major, offset by expert start
a_base = a_ptr + gtid_start * K
# b is (G, N, K) if TRANSPOSE_B else (G, K, N)
b_base = b_ptr + gid * K * N
# c is (total_M, N) row-major
c_base = c_ptr + gtid_start * N
if TRANSPOSE_B:
# b layout: (G, N, K), we compute a @ b.T = a(M,K) @ b(N,K).T -> (M,N)
a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_base + offs_n[:, None] * K + offs_k[None, :]
else:
# b layout: (G, K, N), we compute a @ b = a(M,K) @ b(K,N) -> (M,N)
a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_base + offs_k[:, None] * N + offs_n[None, :]
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k_start in range(0, K, BLOCK_K):
k_offs = k_start + offs_k
k_mask = k_offs < K
a_block = tl.load(a_ptrs, mask=(offs_m[:, None] < m_size) & k_mask[None, :], other=0.0)
if TRANSPOSE_B:
b_block = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
acc += tl.dot(a_block, tl.trans(b_block))
else:
b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a_block, b_block)
if TRANSPOSE_B:
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K
else:
a_ptrs += BLOCK_K
b_ptrs += BLOCK_K * N
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
c_mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def group_gemm_same_nk(
a: torch.Tensor,
b: torch.Tensor,
cumsum_M: torch.Tensor,
max_M: int,
transpose_b: bool = False,
) -> torch.Tensor:
"""Grouped GEMM where all groups share same N, K dimensions but variable M.
Args:
a: (total_M, K) input tensor, rows grouped by expert
b: (G, N, K) if transpose_b else (G, K, N) weight tensor
cumsum_M: (G,) cumulative token counts per expert
max_M: maximum tokens any single expert has
transpose_b: if True, compute a @ b.T; else compute a @ b
Returns:
c: (total_M, N) output tensor
"""
if transpose_b:
G, N, K = b.shape
else:
G, K, N = b.shape
c = torch.empty((a.shape[0], N), dtype=a.dtype, device=a.device)
_group_gemm_same_nk_kernel[
(lambda meta: (triton.cdiv(max_M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))
](
a_ptr=a,
b_ptr=b,
c_ptr=c,
cumsum_M=cumsum_M,
max_M=max_M,
G=G,
N=N,
K=K,
TRANSPOSE_B=transpose_b,
)
return c
# ---------------------------------------------------------------------------
# group_gemm_same_mn: All experts share same M, N (weight dims); variable K
# Used for: backward dW (grad.T @ input)
# ---------------------------------------------------------------------------
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
],
key=["M", "N"],
)
@triton.jit
def _group_gemm_same_mn_kernel(
a_ptr,
b_ptr,
c_ptr,
cumsum_K,
G: tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP: tl.constexpr,
):
pid_m, pid_n = _get_pid_mn(tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP)
gid = tl.program_id(1).to(tl.int64)
gtid_start = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0).to(tl.int64)
gtid_end = tl.load(cumsum_K + gid).to(tl.int64)
k_size = gtid_end - gtid_start
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# c is (G, M, N)
c_base = c_ptr + gid * M * N
if k_size == 0:
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.dtype.element_ty), mask=c_mask)
return
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
offs_k = tl.arange(0, BLOCK_K)
# a is (total_K, M), compute a.T @ b -> (M, N)
# b is (total_K, N)
a_base = a_ptr + gtid_start * M
b_base = b_ptr + gtid_start * N
for k_start in range(0, k_size, BLOCK_K):
k_offs = k_start + offs_k
k_mask = k_offs < k_size
a_ptrs = a_base + k_offs[:, None] * M + offs_m[None, :]
a_block_t = tl.trans(tl.load(a_ptrs, mask=k_mask[:, None] & (offs_m[None, :] < M), other=0.0))
# Load b block: (BLOCK_K, BLOCK_N)
b_ptrs = b_base + k_offs[:, None] * N + offs_n[None, :]
b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a_block_t, b_block)
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def group_gemm_same_mn(
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
cumsum_K: torch.Tensor,
) -> None:
"""Grouped GEMM where all groups produce same (M, N) output; variable K reduction.
Computes: c[g] = a[s:e].T @ b[s:e] for each group g,
where s, e are defined by cumsum_K boundaries.
Args:
a: (total_K, M) input tensor grouped by expert
b: (total_K, N) input tensor grouped by expert
c: (G, M, N) output tensor (pre-allocated)
cumsum_K: (G,) cumulative token counts per expert
"""
G, M, N = c.shape
_group_gemm_same_mn_kernel[(lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))](
a_ptr=a,
b_ptr=b,
c_ptr=c,
cumsum_K=cumsum_K,
G=G,
M=M,
N=N,
)
# ---------------------------------------------------------------------------
# moe_scatter: Dispatch tokens to sorted expert buffer positions
# ---------------------------------------------------------------------------
@triton.jit
def _moe_scatter_kernel(
x_ptr,
out_ptr,
index_ptr,
M,
N: tl.constexpr,
TOPK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Scatter: for each token i, copy x[i] to out[index[i, k]] for k in 0..topk-1."""
pid_m = tl.program_id(0).to(tl.int64)
pid_n = tl.program_id(1)
if pid_m >= M:
return
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
# Load input row
x_ptrs = x_ptr + pid_m * N + offs_n
x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0)
# Store to each topk destination
for k in tl.static_range(TOPK):
dst_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
out_ptrs = out_ptr + dst_idx * N + offs_n
tl.store(out_ptrs, x_vals, mask=n_mask)
def moe_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
"""Scatter tokens to sorted expert buffer.
For each token i and topk slot k, copies x[i] to output[index[i, k]].
Args:
x: (M, N) input hidden states
index: (M, topk) scatter indices
Returns:
out: (M * topk, N) scattered output
"""
M, N = x.shape
topk = index.shape[1]
out = torch.empty(M * topk, N, dtype=x.dtype, device=x.device)
BLOCK_N = min(triton.next_power_of_2(N), 1024)
grid = (M, triton.cdiv(N, BLOCK_N))
_moe_scatter_kernel[grid](
x_ptr=x,
out_ptr=out,
index_ptr=index,
M=M,
N=N,
TOPK=topk,
BLOCK_N=BLOCK_N,
)
return out
# ---------------------------------------------------------------------------
# moe_gather: Reduce expert outputs back to token positions (sum over topk)
# ---------------------------------------------------------------------------
@triton.jit
def _moe_gather_kernel(
x_ptr,
out_ptr,
index_ptr,
M,
N: tl.constexpr,
TOPK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Gather: for each token i, out[i] = sum_k(x[index[i, k]]) over topk."""
pid_m = tl.program_id(0).to(tl.int64)
pid_n = tl.program_id(1)
if pid_m >= M:
return
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
n_mask = offs_n < N
acc = tl.zeros([BLOCK_N], dtype=tl.float32)
for k in tl.static_range(TOPK):
src_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
x_ptrs = x_ptr + src_idx * N + offs_n
x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0).to(tl.float32)
acc += x_vals
out_ptrs = out_ptr + pid_m * N + offs_n
tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=n_mask)
def moe_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
"""Gather and reduce expert outputs back to original token positions.
For each token i, sums x[index[i, k]] over all topk slots.
Args:
x: (M * topk, N) expert outputs in sorted buffer
index: (M, topk) scatter indices (same as used in moe_scatter)
Returns:
out: (M, N) gathered output
"""
M, topk = index.shape
N = x.shape[1]
out = torch.empty(M, N, dtype=x.dtype, device=x.device)
BLOCK_N = min(triton.next_power_of_2(N), 1024)
grid = (M, triton.cdiv(N, BLOCK_N))
_moe_gather_kernel[grid](
x_ptr=x,
out_ptr=out,
index_ptr=index,
M=M,
N=N,
TOPK=topk,
BLOCK_N=BLOCK_N,
)
return out

View File

@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
This must follow the module's forward semantics. Do not infer it from trained
weight values because standard RMSNorm weights can also be close to zero.
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
residual_rmsnorm_classes = {
"Qwen3_5RMSNorm",
"Qwen3_5MoeRMSNorm",
"Qwen3NextRMSNorm",
}
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
return class_name in residual_rmsnorm_classes
def npu_rms_norm_forward(self, hidden_states):
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
if hasattr(self, "weight") and self.weight is not None:
if _should_use_residual_rmsnorm(self):
if getattr(self, "_npu_use_residual_rmsnorm", False):
effective_weight = 1.0 + self.weight.float()
else:
effective_weight = self.weight.float()
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
if "Gated" in module.__class__.__name__:
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model

View File

@@ -58,7 +58,7 @@ class Registry:
device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type:
if get_current_accelerator().type not in device:
return
if not kernel_id:

View File

@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64)
else:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
if position_ids is not None:
global_position_ids = [
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_position_ids, position_ids, group=self.spg)
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
attention_mask = None
else:
attention_mask = attention_mask.to(torch.int64)
if attention_mask is None:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else:
attention_mask = attention_mask.to(torch.int64)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
context_layer = self.attn_fn(
q,

View File

@@ -12,23 +12,272 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from math import ceil
from typing import Any
import torch
from torch.utils.data import default_collate
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.plugin import BasePlugin
from ...utils.types import BatchInfo, BatchInput, DataLoader
from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
class BatchingPlugin(BasePlugin):
def compute_length(self, data_provider: DataLoader) -> int:
def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
"""Return the raw data provider batch size for this batching strategy."""
return self["get_data_provider_batch_size"](batch_info)
def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
"""Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule.
"""
raise NotImplementedError()
return self["compute_length"](data_provider, batch_info)
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
def fill_buffer(
self,
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
"""Fill the buffer with data."""
raise NotImplementedError()
return self["fill_buffer"](buffer, batch_info, next_samples)
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
"""Generate a batch from the buffer."""
raise NotImplementedError()
return self["generate_batch"](buffer, batch_info)
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
"""Return sample counts for micro batches formed by one padded-token budget."""
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
max_sample_len = 0
used = 0
is_complete = False
while index + used < len(samples):
sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
padded_tokens = max(max_sample_len, sample_len) * (used + 1)
if used > 0 and padded_tokens > budget:
is_complete = True
break
max_sample_len = max(max_sample_len, sample_len)
used += 1
if max_sample_len * used >= budget:
is_complete = True
break
if used == 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
current_tokens = 0
used = 0
is_complete = False
while index + used < len(samples):
sample = samples[index + used]
sample_len = min(len(sample["input_ids"]), cutoff_len)
if current_tokens + sample_len > budget:
is_complete = True
break
current_tokens += sample_len
used += 1
if used <= 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {}
position_ids: list[int] = []
for sample_index, sample in enumerate(samples):
# Padding-free still truncates each sample by cutoff_len before packing
# all samples into one contiguous sequence.
sample_len = min(len(sample["input_ids"]), cutoff_len)
if sample_len <= 0:
continue
for key, value in sample.items():
if key in ("attention_mask", "position_ids") or isinstance(value, str):
continue
if key not in packed:
packed[key] = []
sliced_value = list(value[:sample_len])
if sample_index > 0 and sliced_value:
if key == "labels":
sliced_value[0] = IGNORE_INDEX
elif key == "loss_weights":
sliced_value[0] = 0.0
packed[key].extend(sliced_value)
position_ids.extend(range(sample_len))
if not position_ids:
return None
packed["position_ids"] = position_ids
packed["attention_mask"] = [1] * len(position_ids)
return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
@BatchingPlugin("padding_free").register("compute_length")
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
return len(data_provider)
@BatchingPlugin("padding_free").register("fill_buffer")
def fill_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]:
samples = next_samples(False)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("padding_free").register("generate_batch")
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
num_micro_batch = batch_info["num_micro_batch"]
cutoff_len = batch_info["cutoff_len"]
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return None
samples = buffer.get(batch_size)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
if packed_micro_batch is None:
return None
batch.append(packed_micro_batch)
return batch
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_batching").register("compute_length")
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_batching").register("fill_buffer")
def fill_dynamic_batching_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_batching").register("generate_batch")
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
if packed_batch is None:
return None
batch.append(packed_batch)
return batch

View File

@@ -381,7 +381,7 @@ class FSDP2Engine:
with torch.no_grad():
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
grad_norm = grad_norm.full_tensor()
for param in model.parameters():

View File

@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
if dist_config.get("cp_size", 1) > 1:
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(
@@ -78,14 +81,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
@DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .deepspeed import save_checkpoint
return save_checkpoint(model, optimizer, ckpt_dir)
return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .deepspeed import load_checkpoint
return load_checkpoint(model, optimizer, ckpt_dir)
return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)

View File

@@ -0,0 +1,183 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from ..accelerator.interface import Dim, DistributedInterface
from ..config import InputArgument, TrainingArguments, get_args
from ..config.arg_utils import ModelClass
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..utils import logging
from ..utils.types import BatchInput, HFModel, Tensor
logger = logging.get_logger(__name__)
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
"""Validate RM dataset format early for clearer error messages."""
if len(train_dataset) == 0:
raise ValueError(f"RM training dataset is empty: {dataset_path}")
sample = train_dataset[0]
if "chosen_messages" in sample and "rejected_messages" in sample:
return
dataset_name = sample.get("_dataset_name", "unknown")
sample_keys = sorted(sample.keys())
raise ValueError(
"RM training requires pair-format samples containing chosen/rejected responses. "
f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
"Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
"or set converter='pair' for raw chosen/rejected fields)."
)
def _init_score_head(model: HFModel) -> None:
"""Initialize the score head for RM training with small Gaussian weights.
Uses Gaussian initialization so that different parameters have distinct values,
providing better gradient flow than zero initialization while keeping initial
scores small enough that the starting loss is close to ln(2).
"""
unwrapped = model.module if hasattr(model, "module") else model
score = getattr(unwrapped, "score", None)
if score is not None and hasattr(score, "weight"):
hidden_size = score.weight.shape[-1]
std = 1.0 / (hidden_size * 10)
with torch.no_grad():
score.weight.normal_(mean=0.0, std=std)
if score.bias is not None:
score.bias.zero_()
logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}")
class RMTrainer(BaseTrainer):
def __init__(
self,
args: TrainingArguments,
model: HFModel,
renderer,
train_dataset,
callbacks=None,
) -> None:
cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1
if cp_size > 1:
raise NotImplementedError("RM trainer currently only supports cp_size == 1.")
super().__init__(args, model, renderer, train_dataset, callbacks)
def _shard_model(self) -> None:
if self.args.dist_config is None:
if DistributedInterface().get_world_size(Dim.DP) > 1:
from torch.nn.parallel import DistributedDataParallel as DDP
device_ids = None if self.device.type == "cpu" else [self.device.index]
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)
else:
super()._shard_model()
@property
def _unwrapped_model(self):
"""Access the underlying model, unwrapping DDP/FSDP wrappers if present."""
model = self.model
if hasattr(model, "module"):
model = model.module
return model
def compute_loss(self, batch: BatchInput) -> Tensor:
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
token_type_ids = batch.get("token_type_ids")
if token_type_ids is None:
raise ValueError(
"RM training requires pair data with token_type_ids. "
"Ensure the dataset has chosen_messages/rejected_messages."
)
token_type_ids = token_type_ids.to(self.device, non_blocking=True)
# Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding).
# Transformers v5 models natively support this format in _update_causal_mask,
# constructing the correct block-diagonal causal mask internally for all attention backends.
model_attention_mask = token_type_ids
# Build position_ids that reset at each document boundary.
batch_size, seq_len = token_type_ids.shape
arange = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
chosen_mask = token_type_ids == 1
rejected_mask = token_type_ids == 2
chosen_lens = chosen_mask.sum(dim=1, keepdim=True)
position_ids = torch.zeros_like(token_type_ids)
position_ids[chosen_mask] = arange[chosen_mask]
position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask]
model_output = self.model(
input_ids=input_ids,
attention_mask=model_attention_mask,
position_ids=position_ids,
use_cache=False,
return_dict=True,
)
rewards = model_output.logits.float().squeeze(-1)
chosen_mask = token_type_ids == 1
rejected_mask = token_type_ids == 2
valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1)
if not torch.any(valid_pair_mask):
raise ValueError(
"No valid RM pairs found in this micro-batch. "
"This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
)
rewards = rewards[valid_pair_mask]
chosen_mask = chosen_mask[valid_pair_mask]
rejected_mask = rejected_mask[valid_pair_mask]
seq_len = rewards.size(-1)
position_index = torch.arange(seq_len, device=self.device).unsqueeze(0)
chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values
rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values
chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1)
rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1)
return -F.logsigmoid(chosen_scores - rejected_scores).mean()
def run_rm(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
model_args.model_class = ModelClass.CLS
DistributedInterface(training_args.dist_config)
train_dataset = DataEngine(data_args.train_dataset)
_validate_rm_dataset_format(train_dataset, data_args.train_dataset)
model_engine = ModelEngine(model_args, is_train=True)
_init_score_head(model_engine.model)
trainer = RMTrainer(
args=training_args,
model=model_engine.model,
renderer=model_engine.renderer,
train_dataset=train_dataset,
)
trainer.fit()
trainer.save_model()
DistributedInterface().destroy()
if __name__ == "__main__":
run_rm()

View File

@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
return
# Human-readable output to stdout
display_logs = {**logs, "total_steps": state.num_training_steps}
display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
logger.info_rank0(parts)

View File

@@ -13,22 +13,46 @@
# limitations under the License.
import random
import numpy as np
import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.helper import is_torch_npu_available
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
def enable_full_determinism(seed: int) -> None:
"""Enable full deterministic mode for reproducible distributed training."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
if is_torch_npu_available():
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
def set_seed(seed: int, full_determinism: bool = False) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
full_determinism: Whether to enable full deterministic mode.
"""
hf_set_seed(seed)
if full_determinism:
enable_full_determinism(seed)
else:
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool:

View File

@@ -33,6 +33,10 @@ class StatefulBuffer:
def size(self) -> int:
return self._buffer_size
@property
def samples(self) -> list[ModelInput]:
return self._buffer
def put(self, samples: list[ModelInput]) -> None:
"""Add samples to the buffer."""
num_tokens = sum(len(sample["input_ids"]) for sample in samples)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from enum import StrEnum, unique
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
@@ -54,6 +54,13 @@ else:
ProcessGroup = None
@unique
class AttentionFunction(StrEnum):
EAGER = "eager"
SDPA = "sdpa"
FLASH_ATTENTION_2 = "flash_attention_2"
class DatasetInfo(TypedDict, total=False):
path: str
"""Local file path."""
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
"""Number of micro batches."""
cutoff_len: int
"""Cutoff length."""
data_iter: Iterator[list[ModelInput]]
"""Data iterator."""
class ModelOutput(NamedTuple):

View File

@@ -58,10 +58,3 @@ def test_multi_device():
master_port = find_available_port()
world_size = 2
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
if __name__ == "__main__":
"""
python tests_v1/accelerator/test_interface.py
"""
test_all_device()

View File

@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
assert training_args.bf16 is False
assert training_args.dist_config is None
assert sample_args.sample_backend == "hf"
if __name__ == "__main__":
"""
python -m tests_v1.config.test_args_parser
"""
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
test_get_args_from_yaml(tmp_path=Path(tmp_dir))

View File

@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
for index in indexes:
print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
"""
python -m tests_v1.core.test_data_engine
"""
test_map_dataset(1)

View File

@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":
"""
python -m tests_v1.core.test_model_loader
"""
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -16,6 +16,164 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import (
BatchingPlugin,
_get_dynamic_micro_batch_sizes,
_get_dynamic_padding_free_micro_batch_sizes,
)
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
def _make_model_input(length: int, start: int = 0):
input_ids = list(range(start, start + length))
return {
"input_ids": input_ids,
"attention_mask": [1] * length,
"labels": input_ids.copy(),
"loss_weights": [1.0] * length,
}
class _RestartableDataProvider:
def __init__(self, batches):
self.batches = batches
self.num_iters = 0
def __iter__(self):
self.num_iters += 1
return iter(self.batches)
def test_padding_free():
buffer = StatefulBuffer()
# Input samples:
# sample 0 input_ids: [0, 1]
# sample 1 input_ids: [10, 11, 12, 13]
buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}
batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)
# Output batch:
# sample 1 is truncated to [10, 11, 12]
# both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 5)
assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
assert len(buffer) == 0
def test_batching_plugin_data_provider_batch_sizes():
batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 3,
"cutoff_len": 10,
}
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
# Input samples:
# sample lengths: [3, 4, 6, 2, 8, 9]
# input_ids:
# [0, 1, 2]
# [10, 11, 12, 13]
# [20, 21, 22, 23, 24, 25]
# [30, 31]
# [40, 41, 42, 43, 44, 45, 46, 47]
# [50, 51, 52, 53, 54, 55, 56, 57, 58]
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Dynamic batching output plan:
# dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
# as the padded-token budget for one training micro batch.
# [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
assert len(buffer) == 3
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
# Input buffer:
# only one sample with length [6].
# cutoff_len * micro_batch_size gives a padded-token budget of 20.
# this buffer has not filled the budget and has no next sample to prove overflow,
# so dynamic batching cannot produce a batch yet.
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
# Input data provider:
# each iterator pass yields one sample with length [6].
# each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
# _fill_buffer keeps restarting the iterator until the next appended sample
# proves that the previous dynamic micro batch has reached its budget boundary.
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_batching"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
batch_generator._fill_buffer()
# Filled buffer after restart:
# existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
# because adding it to the first dynamic micro batch would exceed the budget.
assert data_provider.num_iters == 4
assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic batching returns [micro_batch_0]
# micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert len(batch_generator._buffer) == 1
def test_normal_batching():
@@ -45,8 +203,166 @@ def test_normal_batching():
assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
def test_dynamic_padding_free():
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
# input_ids breakdown:
# sample 0: [0,1,2] (length=3)
# sample 1: [10,11,12,13] (length=4)
# sample 2: [20,21,22,23,24,25] (length=6)
# sample 3: [30,31] (length=2)
# sample 4: [40-47] (length=8)
# sample 5: [50-58] (length=9)
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1 # num_micro_batch=1
packed_batch = batch[0]
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
assert packed_batch["input_ids"].shape == (1, 15)
# Verify input_ids concatenation (first label of non-initial samples set to IGNORE_INDEX)
assert packed_batch["input_ids"].tolist() == [
[
0,
1,
2, # Sample 0
10,
11,
12,
13, # Sample 1
20,
21,
22,
23,
24,
25, # Sample 2
30,
31,
] # Sample 3
]
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
assert packed_batch["labels"].tolist() == [
[
0,
1,
2, # Sample 0
IGNORE_INDEX,
11,
12,
13, # Sample 1
IGNORE_INDEX,
21,
22,
23,
24,
25, # Sample 2
IGNORE_INDEX,
31,
] # Sample 3
]
# Verify attention_mask
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
# Verify position_ids
assert packed_batch["position_ids"].tolist() == [
[
0,
1,
2, # Sample 0
0,
1,
2,
3, # Sample 1
0,
1,
2,
3,
4,
5, # Sample 2
0,
1,
] # Sample 3
]
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
assert len(buffer) == 2
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
Data provider yields one sample of length 6 per iteration.
_fill_buffer keeps restarting iterator until next sample exceeds budget.
Budget = 2 * 10 = 20 tokens.
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
So buffer will have 4 samples after fill_buffer.
"""
python -m tests_v1.core.utils.test_batching
"""
test_normal_batching()
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_padding_free"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
batch_generator._fill_buffer()
# Buffer after restarts:
# 3 samples can fit (18 tokens)
# 4th sample is kept in buffer for next batch
# => num_iters = 4
assert data_provider.num_iters == 4
assert _get_dynamic_padding_free_micro_batch_sizes(
batch_generator._buffer.samples, batch_generator._batch_info
) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic_padding_free returns [micro_batch_0]
# 3 samples packed into shape [1, 18]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 18)
assert len(batch_generator._buffer) == 1

View File

@@ -227,17 +227,3 @@ def test_process_dpo_samples():
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_rendering
"""
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()

View File

@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
],
}
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
if __name__ == "__main__":
"""
python -m tests_v1.plugins.data_plugins.test_converter
"""
test_alpaca_converter(1)
test_sharegpt_converter()
test_pair_converter(1)

View File

@@ -52,12 +52,3 @@ def test_init_on_default():
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
"""
python tests_v1/plugins/model_plugins/test_init_plugin.py
"""
test_init_on_meta()
test_init_on_rank0()
test_init_on_default()

View File

@@ -35,10 +35,3 @@ def test_sync_sampler():
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
"""
python tests_v1/sampler/test_cli_sampler.py
"""
test_sync_sampler()