8 Commits

Author SHA1 Message Date
Yaowei Zheng
7af909522a [version] release v0.9.5 (#10532) 2026-05-30 23:57:09 +08:00
xvxuopop
e016d2480e [fix] Fix NPU FusedMoE and RMSNorm (#10512) 2026-05-30 21:42:54 +08:00
jiaqiw09
7d719182c9 [model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529) 2026-05-30 21:41:41 +08:00
jiaqiw09
01398eb18d [v1] fix padding free with sp (#10513) 2026-05-26 23:49:21 +08:00
cxy
8e68764b65 [v1] Implement dynamic padding-free stretrgy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
2026-05-25 20:40:21 +08:00
Copilot
16ff5a23cb [fix] use getattr for profiler attrs to support MCA TrainingArguments (#10506)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
2026-05-21 17:26:29 +08:00
jiaqiw09
bdcb92d035 [v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469) 2026-05-21 17:14:19 +08:00
sunyi0505
7e20db5735 [v1] support liger_kernel (#10493) 2026-05-21 11:44:56 +08:00
39 changed files with 1130 additions and 188 deletions

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
\[ English | [中文](README_zh.md) \]
@@ -52,14 +50,11 @@ Start local training:
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Official Blog**: https://blog.llamafactory.net/en/
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -78,7 +73,6 @@ Read technical notes:
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [LLaMA Factory Online](#llama-factory-online)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
@@ -117,15 +111,11 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -661,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
llamafactory-cli webui
```
### LLaMA Factory Online
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:

View File

@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
开始云端训练:
- **Colab免费**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW免费试用**https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online在线微调**https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **九章智算云(算力优惠活动)**https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
阅读技术文档:
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
- **框架文档(昇腾 NPU**https://ascend.github.io/docs/sources/llamafactory/
- **官方博客**https://blog.llamafactory.net/
- **官方课程**https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [数据准备](#数据准备)
- [快速开始](#快速开始)
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
- [构建 Docker](#构建-docker)
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
- [从魔搭社区下载](#从魔搭社区下载)
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
<details><summary>全部博客</summary>
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
- [LLaMA Factory微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
@@ -662,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
llamafactory-cli webui
```
### LLaMA Factory Online 在线微调
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
### 构建 Docker
CUDA 用户:

View File

@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: liger_kernel
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.5.dev0"
VERSION = "0.9.5"
def print_env() -> None:

View File

@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.

View File

@@ -20,7 +20,6 @@ import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
@@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback):
if matched:
logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
)
else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback):
bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines))
self._forward_times.clear()

View File

@@ -80,10 +80,10 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
if training_args.enable_torch_profiler:
if getattr(training_args, "enable_torch_profiler", False):
callbacks.append(TorchProfilerCallback(training_args))
if training_args.profile_modules:
if getattr(training_args, "profile_modules", None):
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [
"AttentionFunction",
"BatchingStrategy",
"DataArguments",
"InputArgument",

View File

@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args:
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
return tuple(parsed_args)
return model_args, data_args, training_args, sample_args
if __name__ == "__main__":

View File

@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.SDPA,
metadata={
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
)
def __post_init__(self) -> None:
supported_flash_attn = [item.value for item in AttentionFunction]
if self.flash_attn not in supported_flash_attn:
raise ValueError(
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
)
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
@@ -116,3 +120,9 @@ class TrainingArguments:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
if self.max_steps is None or self.max_steps <= 0:
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
if self.save_epochs is not None:
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")

View File

@@ -34,7 +34,7 @@ import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..config import BatchingStrategy, TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
@@ -147,13 +147,19 @@ class BaseTrainer:
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
raise ValueError(
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
)
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None:
if (
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
):
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
@@ -237,6 +243,7 @@ class BaseTrainer:
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
for micro_batches in self.train_batch_generator:
self.global_step += 1

View File

@@ -120,6 +120,7 @@ class ModelEngine:
init_device = DistributedInterface().current_device
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
@@ -164,6 +165,7 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
attn_implementation=self.args.flash_attn,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
@@ -188,9 +190,12 @@ class ModelEngine:
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model, include_kernels=self.args.kernel_config.get("include_kernels")
)
kernel_config = self.args.kernel_config
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
if kernel_config.name == "liger_kernel":
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
kernel_kwargs["require_logits"] = self.is_train
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
return model

View File

@@ -42,6 +42,8 @@ from .rendering import Renderer
logger = logging.get_logger(__name__)
__all__ = ["BatchGenerator"]
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
}
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
if self.batching_strategy == BatchingStrategy.NORMAL:
batch_size = self.micro_batch_size * self.num_micro_batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
batch_size=batch_size,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
def __len__(self) -> int:
return self._length
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
try:
return next(self._data_iter)
except StopIteration:
if not restart:
return None
# Dynamic batching may restart the provider to fill one token-budgeted batch.
self._data_iter = iter(self._data_provider)
try:
return next(self._data_iter)
except StopIteration:
return None
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer.state_dict(),

View File

@@ -34,7 +34,7 @@ class BaseKernel(ABC):
"""
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
_device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
@classmethod
def get_kernel_id(cls) -> str:
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
def get_device(cls) -> list[DeviceType]:
"""Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
return cls._device
@classmethod
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks.
"""
if cls._device != get_current_accelerator().type:
if get_current_accelerator().type not in cls._device:
return False
return True

View File

@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
apply_kernel(kernel, model=model)
return model
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
model: HFModel,
include_kernels: str = None,
require_logits: bool = False,
) -> HFModel:
"""Applies Liger kernel to the model.
Args:
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with
library defaults. If a comma-separated list (e.g.
``rope,rms_norm``), enable only those ops; names match
``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``,
``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``.
If ``None`` or ``False``, do nothing. Defaults to ``None``.
require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor
of non-fused CE so the forward pass returns ``logits``. Needed
for trainers that compute weighted loss from logits (e.g. v1
SFT with ``loss_weights``). Defaults to ``False`` (fused CE
when supported). The v1 ``run_sft`` entrypoint sets
``require_logits`` to true for ``liger_kernel`` when the key
is omitted so SFT weighted loss keeps working.
Returns:
HFModel: The model with Liger kernel applied.
"""
if not include_kernels:
return model
if include_kernels == "auto" or include_kernels is True:
use_kernels = "auto"
else:
use_kernels = [k.strip() for k in include_kernels.split(",") if k.strip()]
if not use_kernels:
return model
try:
from .liger_kernel_ops import LigerKernel
except ImportError as e:
logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
return model
return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)

View File

@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of Liger Kernel.
Init Phase:
1. Define LigerKernel class.
2. Register Liger kernel.
"""
import inspect
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.logging import get_logger
from ....utils.types import HFModel
from .base import BaseKernel
logger = get_logger(__name__)
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
"qwen3": "apply_liger_kernel_to_qwen3",
"qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
"qwen3_next": "apply_liger_kernel_to_qwen3_next",
"qwen3_5": "apply_liger_kernel_to_qwen3_5",
"qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
"qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
"qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
}
class LigerKernel(BaseKernel):
"""Liger Kernel for optimized model training."""
_device = [DeviceType.CUDA, DeviceType.NPU]
@classmethod
def check_deps(cls) -> bool:
"""Checks if the required dependencies for the kernel are available."""
try:
import liger_kernel # noqa: F401
return super().check_deps()
except ImportError:
logger.warning_rank0(
"Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
)
return False
@classmethod
def apply(cls, **kwargs) -> "HFModel":
"""Applies the Liger kernel to the model.
Args:
**kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
names to enable exclusively, or the string ``"auto"`` to use each
``apply_liger_kernel_to_*`` function's signature defaults (same as calling
upstream with only ``model``). Optional ``require_logits`` forces non-fused
cross entropy when supported.
Returns:
HFModel: The model with Liger kernel applied.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model")
use_kernels = kwargs.get("use_kernels", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(
f"current device is not supported by liger_kernel. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}"
)
require_logits = kwargs.get("require_logits", False)
model_type = getattr(model.config, "model_type", None)
if model_type not in _LIGER_FN_BY_MODEL_TYPE:
logger.warning_rank0("Current model does not support liger kernel.")
return model
import liger_kernel.transformers as liger_transformers
apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])
sig = inspect.signature(apply_liger_kernel).parameters
togglable = [name for name in sig if name != "model"]
def _normalize_op_name(raw: str) -> str:
key = raw.strip().lower().replace("-", "_")
aliases = {
"rmsnorm": "rms_norm",
"flce": "fused_linear_cross_entropy",
"lce": "fused_linear_cross_entropy",
"fused_ce": "fused_linear_cross_entropy",
}
return aliases.get(key, key)
if use_kernels is not None and len(use_kernels) == 0:
return model
if use_kernels != "auto":
selected = {_normalize_op_name(k) for k in use_kernels}
ops = selected - set(togglable)
if ops:
raise ValueError(
f"Unknown Liger op(s) {sorted(ops)} for model_type={model_type}. Valid: {sorted(togglable)}"
)
if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")
call_kwargs = {name: (name in selected) for name in togglable}
call_kwargs["model"] = model
else:
# Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
# and logging reflects enabled ops (omitted kwargs only live in the callee).
call_kwargs = {"model": model}
for name in togglable:
param = sig[name]
if param.default is not inspect.Parameter.empty:
call_kwargs[name] = param.default
if require_logits and "fused_linear_cross_entropy" in sig:
logger.warning_rank0("Current training stage does not support chunked cross entropy.")
call_kwargs["fused_linear_cross_entropy"] = False
call_kwargs["cross_entropy"] = True
apply_liger_kernel(**call_kwargs)
applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
logger.info_rank0(f"These Liger ops are applied to the model: {applied}")
return model

View File

@@ -228,6 +228,30 @@ class NpuMoeFused:
routed_out = self.experts(hidden_states, routing_weights, router_indices)
return routed_out
@staticmethod
def npu_moe_experts_v5_forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
Transformers v5 stores expert weights in F.linear layout:
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
down_proj: [num_experts, hidden_dim, intermediate_dim]
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
"""
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
hidden_states, top_k_index.to(torch.int32)
)
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
gate_up_proj = self.gate_up_proj.transpose(1, 2)
down_proj = self.down_proj.transpose(1, 2)
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
class Qwen3NpuMoeFused:
"""Container for Qwen3 NPU fused MoE forward functions."""
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
# moe patch config mapping
kernel_moe_mapping = {
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
if is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
}
}
if not is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
else:
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
},
}

View File

@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
This must follow the module's forward semantics. Do not infer it from trained
weight values because standard RMSNorm weights can also be close to zero.
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
residual_rmsnorm_classes = {
"Qwen3_5RMSNorm",
"Qwen3_5MoeRMSNorm",
"Qwen3NextRMSNorm",
}
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
return class_name in residual_rmsnorm_classes
def npu_rms_norm_forward(self, hidden_states):
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
if hasattr(self, "weight") and self.weight is not None:
if _should_use_residual_rmsnorm(self):
if getattr(self, "_npu_use_residual_rmsnorm", False):
effective_weight = 1.0 + self.weight.float()
else:
effective_weight = self.weight.float()
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
if "Gated" in module.__class__.__name__:
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model

View File

@@ -58,7 +58,7 @@ class Registry:
device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type:
if get_current_accelerator().type not in device:
return
if not kernel_id:

View File

@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64)
else:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
if position_ids is not None:
global_position_ids = [
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_position_ids, position_ids, group=self.spg)
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
attention_mask = None
else:
attention_mask = attention_mask.to(torch.int64)
if attention_mask is None:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else:
attention_mask = attention_mask.to(torch.int64)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
context_layer = self.attn_fn(
q,

View File

@@ -12,23 +12,272 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from math import ceil
from typing import Any
import torch
from torch.utils.data import default_collate
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.plugin import BasePlugin
from ...utils.types import BatchInfo, BatchInput, DataLoader
from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
class BatchingPlugin(BasePlugin):
def compute_length(self, data_provider: DataLoader) -> int:
def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
"""Return the raw data provider batch size for this batching strategy."""
return self["get_data_provider_batch_size"](batch_info)
def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
"""Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule.
"""
raise NotImplementedError()
return self["compute_length"](data_provider, batch_info)
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
def fill_buffer(
self,
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
"""Fill the buffer with data."""
raise NotImplementedError()
return self["fill_buffer"](buffer, batch_info, next_samples)
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
"""Generate a batch from the buffer."""
raise NotImplementedError()
return self["generate_batch"](buffer, batch_info)
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
"""Return sample counts for micro batches formed by one padded-token budget."""
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
max_sample_len = 0
used = 0
is_complete = False
while index + used < len(samples):
sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
padded_tokens = max(max_sample_len, sample_len) * (used + 1)
if used > 0 and padded_tokens > budget:
is_complete = True
break
max_sample_len = max(max_sample_len, sample_len)
used += 1
if max_sample_len * used >= budget:
is_complete = True
break
if used == 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
current_tokens = 0
used = 0
is_complete = False
while index + used < len(samples):
sample = samples[index + used]
sample_len = min(len(sample["input_ids"]), cutoff_len)
if current_tokens + sample_len > budget:
is_complete = True
break
current_tokens += sample_len
used += 1
if used <= 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {}
position_ids: list[int] = []
for sample_index, sample in enumerate(samples):
# Padding-free still truncates each sample by cutoff_len before packing
# all samples into one contiguous sequence.
sample_len = min(len(sample["input_ids"]), cutoff_len)
if sample_len <= 0:
continue
for key, value in sample.items():
if key in ("attention_mask", "position_ids") or isinstance(value, str):
continue
if key not in packed:
packed[key] = []
sliced_value = list(value[:sample_len])
if sample_index > 0 and sliced_value:
if key == "labels":
sliced_value[0] = IGNORE_INDEX
elif key == "loss_weights":
sliced_value[0] = 0.0
packed[key].extend(sliced_value)
position_ids.extend(range(sample_len))
if not position_ids:
return None
packed["position_ids"] = position_ids
packed["attention_mask"] = [1] * len(position_ids)
return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
@BatchingPlugin("padding_free").register("compute_length")
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
return len(data_provider)
@BatchingPlugin("padding_free").register("fill_buffer")
def fill_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]:
samples = next_samples(False)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("padding_free").register("generate_batch")
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
num_micro_batch = batch_info["num_micro_batch"]
cutoff_len = batch_info["cutoff_len"]
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return None
samples = buffer.get(batch_size)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
if packed_micro_batch is None:
return None
batch.append(packed_micro_batch)
return batch
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_batching").register("compute_length")
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_batching").register("fill_buffer")
def fill_dynamic_batching_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_batching").register("generate_batch")
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
if packed_batch is None:
return None
batch.append(packed_batch)
return batch

View File

@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
if dist_config.get("cp_size", 1) > 1:
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(

View File

@@ -13,22 +13,46 @@
# limitations under the License.
import random
import numpy as np
import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.helper import is_torch_npu_available
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
def enable_full_determinism(seed: int) -> None:
"""Enable full deterministic mode for reproducible distributed training."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
if is_torch_npu_available():
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
def set_seed(seed: int, full_determinism: bool = False) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
full_determinism: Whether to enable full deterministic mode.
"""
hf_set_seed(seed)
if full_determinism:
enable_full_determinism(seed)
else:
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool:

View File

@@ -33,6 +33,10 @@ class StatefulBuffer:
def size(self) -> int:
return self._buffer_size
@property
def samples(self) -> list[ModelInput]:
return self._buffer
def put(self, samples: list[ModelInput]) -> None:
"""Add samples to the buffer."""
num_tokens = sum(len(sample["input_ids"]) for sample in samples)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from enum import StrEnum, unique
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
@@ -54,6 +54,13 @@ else:
ProcessGroup = None
@unique
class AttentionFunction(StrEnum):
EAGER = "eager"
SDPA = "sdpa"
FLASH_ATTENTION_2 = "flash_attention_2"
class DatasetInfo(TypedDict, total=False):
path: str
"""Local file path."""
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
"""Number of micro batches."""
cutoff_len: int
"""Cutoff length."""
data_iter: Iterator[list[ModelInput]]
"""Data iterator."""
class ModelOutput(NamedTuple):

View File

@@ -58,10 +58,3 @@ def test_multi_device():
master_port = find_available_port()
world_size = 2
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
if __name__ == "__main__":
"""
python tests_v1/accelerator/test_interface.py
"""
test_all_device()

View File

@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
assert training_args.bf16 is False
assert training_args.dist_config is None
assert sample_args.sample_backend == "hf"
if __name__ == "__main__":
"""
python -m tests_v1.config.test_args_parser
"""
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
test_get_args_from_yaml(tmp_path=Path(tmp_dir))

View File

@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
for index in indexes:
print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
"""
python -m tests_v1.core.test_data_engine
"""
test_map_dataset(1)

View File

@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":
"""
python -m tests_v1.core.test_model_loader
"""
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -16,6 +16,164 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import (
BatchingPlugin,
_get_dynamic_micro_batch_sizes,
_get_dynamic_padding_free_micro_batch_sizes,
)
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
def _make_model_input(length: int, start: int = 0):
input_ids = list(range(start, start + length))
return {
"input_ids": input_ids,
"attention_mask": [1] * length,
"labels": input_ids.copy(),
"loss_weights": [1.0] * length,
}
class _RestartableDataProvider:
def __init__(self, batches):
self.batches = batches
self.num_iters = 0
def __iter__(self):
self.num_iters += 1
return iter(self.batches)
def test_padding_free():
buffer = StatefulBuffer()
# Input samples:
# sample 0 input_ids: [0, 1]
# sample 1 input_ids: [10, 11, 12, 13]
buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}
batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)
# Output batch:
# sample 1 is truncated to [10, 11, 12]
# both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 5)
assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
assert len(buffer) == 0
def test_batching_plugin_data_provider_batch_sizes():
batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 3,
"cutoff_len": 10,
}
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
# Input samples:
# sample lengths: [3, 4, 6, 2, 8, 9]
# input_ids:
# [0, 1, 2]
# [10, 11, 12, 13]
# [20, 21, 22, 23, 24, 25]
# [30, 31]
# [40, 41, 42, 43, 44, 45, 46, 47]
# [50, 51, 52, 53, 54, 55, 56, 57, 58]
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Dynamic batching output plan:
# dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
# as the padded-token budget for one training micro batch.
# [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
assert len(buffer) == 3
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
# Input buffer:
# only one sample with length [6].
# cutoff_len * micro_batch_size gives a padded-token budget of 20.
# this buffer has not filled the budget and has no next sample to prove overflow,
# so dynamic batching cannot produce a batch yet.
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
# Input data provider:
# each iterator pass yields one sample with length [6].
# each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
# _fill_buffer keeps restarting the iterator until the next appended sample
# proves that the previous dynamic micro batch has reached its budget boundary.
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_batching"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
batch_generator._fill_buffer()
# Filled buffer after restart:
# existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
# because adding it to the first dynamic micro batch would exceed the budget.
assert data_provider.num_iters == 4
assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic batching returns [micro_batch_0]
# micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert len(batch_generator._buffer) == 1
def test_normal_batching():
@@ -45,8 +203,166 @@ def test_normal_batching():
assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
def test_dynamic_padding_free():
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
# input_ids breakdown:
# sample 0: [0,1,2] (length=3)
# sample 1: [10,11,12,13] (length=4)
# sample 2: [20,21,22,23,24,25] (length=6)
# sample 3: [30,31] (length=2)
# sample 4: [40-47] (length=8)
# sample 5: [50-58] (length=9)
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1 # num_micro_batch=1
packed_batch = batch[0]
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
assert packed_batch["input_ids"].shape == (1, 15)
# Verify input_ids concatenation (first label of non-initial samples set to IGNORE_INDEX)
assert packed_batch["input_ids"].tolist() == [
[
0,
1,
2, # Sample 0
10,
11,
12,
13, # Sample 1
20,
21,
22,
23,
24,
25, # Sample 2
30,
31,
] # Sample 3
]
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
assert packed_batch["labels"].tolist() == [
[
0,
1,
2, # Sample 0
IGNORE_INDEX,
11,
12,
13, # Sample 1
IGNORE_INDEX,
21,
22,
23,
24,
25, # Sample 2
IGNORE_INDEX,
31,
] # Sample 3
]
# Verify attention_mask
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
# Verify position_ids
assert packed_batch["position_ids"].tolist() == [
[
0,
1,
2, # Sample 0
0,
1,
2,
3, # Sample 1
0,
1,
2,
3,
4,
5, # Sample 2
0,
1,
] # Sample 3
]
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
assert len(buffer) == 2
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
Data provider yields one sample of length 6 per iteration.
_fill_buffer keeps restarting iterator until next sample exceeds budget.
Budget = 2 * 10 = 20 tokens.
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
So buffer will have 4 samples after fill_buffer.
"""
python -m tests_v1.core.utils.test_batching
"""
test_normal_batching()
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_padding_free"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
batch_generator._fill_buffer()
# Buffer after restarts:
# 3 samples can fit (18 tokens)
# 4th sample is kept in buffer for next batch
# => num_iters = 4
assert data_provider.num_iters == 4
assert _get_dynamic_padding_free_micro_batch_sizes(
batch_generator._buffer.samples, batch_generator._batch_info
) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic_padding_free returns [micro_batch_0]
# 3 samples packed into shape [1, 18]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 18)
assert len(batch_generator._buffer) == 1

View File

@@ -227,17 +227,3 @@ def test_process_dpo_samples():
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_rendering
"""
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()

View File

@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
],
}
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
if __name__ == "__main__":
"""
python -m tests_v1.plugins.data_plugins.test_converter
"""
test_alpaca_converter(1)
test_sharegpt_converter()
test_pair_converter(1)

View File

@@ -52,12 +52,3 @@ def test_init_on_default():
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
"""
python tests_v1/plugins/model_plugins/test_init_plugin.py
"""
test_init_on_meta()
test_init_on_rank0()
test_init_on_default()

View File

@@ -35,10 +35,3 @@ def test_sync_sampler():
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
"""
python tests_v1/sampler/test_cli_sampler.py
"""
test_sync_sampler()