12 Commits

Author SHA1 Message Date
Shanay Mehta
aab9b400bb [model] Add DeepSpeed Z3 leaf module for Qwen3-Next (#10194)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 19:54:37 +08:00
P. Clawmogorov
50599c719b [misc] remove safe_serialization arg for transformers v5 compatibility (#10208)
Co-authored-by: P. Clawmogorov <262173731+Alm0stSurely@users.noreply.github.com>
2026-02-24 11:14:19 +08:00
Kingsley
a0f3ad0cee [mca] update supported models (#10196) 2026-02-20 22:02:49 +08:00
jiaqiw09
f80e15dbb4 [ci] fix ut huggingface hub 429 error when transformers>=5.0.0 (#10155) 2026-02-12 22:14:10 +08:00
sunyi0505
991267fd3b [v1] support quantization (#10161) 2026-02-12 20:37:41 +08:00
浮梦
5c52afa30d [v1] support deepspeed (#10181) 2026-02-12 17:24:30 +08:00
Junyou Su
675ce8cc7f [algo] add ASFT (#10174) 2026-02-12 13:12:14 +08:00
jiaqiw09
ab073f4c13 [v1] add LoRA/Freeze support and merge workflow (#10157) 2026-02-12 13:02:09 +08:00
Shanay Mehta
184304b5b4 [model] add liger kernel support for Qwen3-Next (#10176)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 21:47:48 +08:00
Xue Yadong
d3ebd5678d [model] support GLM-OCR SFT (#10183) 2026-02-10 21:41:01 +08:00
浮梦
1d5e8ebcd0 [v1] init commit for v1 docs (#10145)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
Co-authored-by: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-02-09 19:43:55 +08:00
Shanay Mehta
ea644d04ec [model] support GLM-4.7-Flash SFT (#10173)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 10:40:44 +08:00
98 changed files with 3748 additions and 79 deletions

77
.github/workflows/docs.yml vendored Normal file
View File

@@ -0,0 +1,77 @@
# Builds the Sphinx documentation (zh + en) and publishes it to GitHub Pages.
#
# Pull requests that touch docs/** only run the `build` job, which acts as a
# validation step. The `deploy` job is skipped for pull_request events so that
# unreviewed PR content is never published (and fork PRs, whose token cannot
# write to Pages, do not fail the workflow).
name: Build and Deploy Sphinx Docs

on:
  push:
    branches: ["main"]
    paths:
      - "docs/**"
  pull_request:
    branches: ["main"]
    paths:
      - "docs/**"
  workflow_dispatch:

permissions:
  contents: read
  pages: write
  id-token: write

# Serialize Pages deployments; never cancel one that is already running.
concurrency:
  group: "pages"
  cancel-in-progress: false

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          pip install -r docs/requirements.txt

      # Each language is an independent Sphinx project (docs/zh, docs/en).
      # The generated root index.html redirects to the zh site and keeps the
      # query string and fragment; .nojekyll stops Pages from dropping
      # directories that start with an underscore (e.g. _static).
      - name: Build Sphinx
        run: |
          sphinx-build -b html docs/zh docs/_build/html/zh
          sphinx-build -b html docs/en docs/_build/html/en
          printf '%s\n' \
          '<!DOCTYPE html>' \
          '<html>' \
          ' <head>' \
          ' <meta charset="utf-8" />' \
          ' <meta http-equiv="refresh" content="0; url=zh/index.html" />' \
          ' <script>window.location.href="zh/index.html"+window.location.search+window.location.hash;</script>' \
          ' <title>Redirecting...</title>' \
          ' </head>' \
          ' <body>' \
          ' <a href="zh/index.html">Redirecting...</a>' \
          ' </body>' \
          '</html>' \
          > docs/_build/html/index.html
          touch docs/_build/html/.nojekyll

      - name: Setup Pages
        uses: actions/configure-pages@v5

      - name: Upload artifact
        uses: actions/upload-pages-artifact@v3
        with:
          path: docs/_build/html

  deploy:
    # Publish only for pushes to main and manual dispatches, never for PRs.
    if: github.event_name != 'pull_request'
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    runs-on: ubuntu-latest
    needs: build
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4

View File

@@ -61,6 +61,7 @@ jobs:
uv venv
uv pip install -e .
uv pip install -r requirements/dev.txt
uv pip install -r requirements/bitsandbytes.txt
- name: Check quality
run: |

20
docs/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
#
# `?=` assigns only when the variable is not already defined, so values
# exported in the environment (e.g. `SPHINXOPTS=-W make html`) are honoured.
# A plain `=` would silently override them.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

50
docs/_static/css/lang-switcher.css vendored Normal file
View File

@@ -0,0 +1,50 @@
/* Language switcher dropdown injected by js/switcher.js. */

/* Wrapper: centres the <select> horizontally and vertically. */
.lang-switcher {
  display: flex;
  align-items: center;
  justify-content: center;
}

/* Pill-shaped select. The native arrow is removed and replaced by a chevron
   drawn with two small diagonal gradients on the right-hand side. */
.lang-switcher__select {
  appearance: none;
  -webkit-appearance: none;
  -moz-appearance: none;
  padding: 6px 28px 6px 10px;
  border-radius: 999px;
  border: 1px solid rgba(0, 0, 0, 0.18);
  background-color: #ffffff;
  color: #333333;
  font-size: 13px;
  line-height: 18px;
  box-shadow: 0 1px 2px rgba(0, 0, 0, 0.08);
  cursor: pointer;
  background-image: linear-gradient(45deg, transparent 50%, #667085 50%),
    linear-gradient(135deg, #667085 50%, transparent 50%);
  background-position: calc(100% - 16px) 50%, calc(100% - 11px) 50%;
  background-size: 5px 5px, 5px 5px;
  background-repeat: no-repeat;
}

/* The default outline is replaced by a focus ring drawn with box-shadow. */
.lang-switcher__select:focus {
  outline: none;
  border-color: rgba(41, 128, 185, 0.65);
  box-shadow: 0 0 0 3px rgba(41, 128, 185, 0.18);
}

/* Variant used inside the theme's sidebar search box (.wy-side-nav-search):
   translucent background with light text and a light chevron. */
.wy-side-nav-search .lang-switcher {
  margin-top: 10px;
}

.wy-side-nav-search .lang-switcher__select {
  border-color: rgba(255, 255, 255, 0.18);
  background-color: rgba(255, 255, 255, 0.08);
  color: #ffffff;
  box-shadow: none;
  background-image: linear-gradient(45deg, transparent 50%, rgba(255, 255, 255, 0.75) 50%),
    linear-gradient(135deg, rgba(255, 255, 255, 0.75) 50%, transparent 50%);
}

/* In browsers where <option> inherits the select's text colour, the white
   text above would be drawn on the popup's default light background and be
   unreadable. Pin the option colours to the light-theme values. */
.wy-side-nav-search .lang-switcher__select option {
  color: #333333;
  background-color: #ffffff;
}

.wy-side-nav-search .lang-switcher__select:focus {
  border-color: rgba(255, 255, 255, 0.45);
  box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
}

93
docs/_static/js/switcher.js vendored Normal file
View File

@@ -0,0 +1,93 @@
/**
 * Language switcher for the docs site.
 *
 * On pages whose path contains "/zh/" or "/en/", injects a <select> that
 * navigates to the same path under the other language, and hides sidebar
 * TOC captions that are language labels (plus the list that follows a
 * caption belonging to the other language).
 */
document.addEventListener('DOMContentLoaded', function () {
  var pathname = window.location.pathname || '';
  var onZhPage = pathname.indexOf('/zh/') !== -1;
  var onEnPage = pathname.indexOf('/en/') !== -1;
  if (!(onZhPage || onEnPage)) {
    return;
  }
  // "/zh/" takes precedence when both segments appear in the path.
  var activeLang = onZhPage ? 'zh' : 'en';

  // Caption texts (lower-cased) recognised as language labels.
  var ZH_LABELS = ['中文', 'chinese', 'zh'];
  var EN_LABELS = ['english', 'en'];

  // Creates one <option> for the language dropdown.
  function makeOption(value, text, selected) {
    var option = document.createElement('option');
    option.value = value;
    option.textContent = text;
    option.selected = selected;
    return option;
  }

  // Builds the wrapper <div> containing the language <select>.
  function buildSwitcher() {
    var wrapper = document.createElement('div');
    wrapper.className = 'lang-switcher';

    var dropdown = document.createElement('select');
    dropdown.setAttribute('aria-label', 'Language');
    dropdown.className = 'lang-switcher__select';

    var zhOption = makeOption('zh', 'Simplified Chinese', onZhPage);
    var enOption = makeOption('en', 'English', onEnPage);
    dropdown.appendChild(zhOption);
    dropdown.appendChild(enOption);

    dropdown.addEventListener('change', function () {
      var chosen = dropdown.value;
      if (chosen === activeLang) {
        return;
      }
      // Swap only the first "/<lang>/" segment; keep query string and hash.
      var swapped = pathname.replace('/' + activeLang + '/', '/' + chosen + '/');
      window.location.href = swapped + window.location.search + window.location.hash;
    });

    wrapper.appendChild(dropdown);
    return wrapper;
  }

  // Hides every language-label caption; for captions of the other language
  // the <ul> immediately following the caption is hidden as well.
  function hideOtherLanguageToc() {
    var captions = document.querySelectorAll('p.caption');
    Array.prototype.forEach.call(captions, function (caption) {
      var textNode = caption.querySelector('.caption-text');
      if (!textNode) {
        return;
      }
      var label = (textNode.textContent || '').trim().toLowerCase();
      var labelsZh = ZH_LABELS.indexOf(label) !== -1;
      var labelsEn = EN_LABELS.indexOf(label) !== -1;
      if (!labelsZh && !labelsEn) {
        return;
      }

      var belongsToOtherLang = activeLang === 'zh' ? labelsEn : labelsZh;
      var sibling = caption.nextElementSibling;
      var siblingIsList = !!(sibling && sibling.tagName && sibling.tagName.toLowerCase() === 'ul');

      caption.style.display = 'none';
      if (belongsToOtherLang && siblingIsList) {
        sibling.style.display = 'none';
      }
    });
  }

  // Prefer the theme's sidebar search box; otherwise pin to the top-right.
  var sidebar = document.querySelector('.wy-side-nav-search');
  var switcher = buildSwitcher();
  if (sidebar) {
    switcher.style.marginTop = '8px';
    switcher.style.display = 'flex';
    switcher.style.justifyContent = 'center';
    sidebar.appendChild(switcher);
  } else {
    switcher.style.position = 'fixed';
    switcher.style.top = '12px';
    switcher.style.right = '12px';
    switcher.style.zIndex = '9999';
    document.body.appendChild(switcher);
  }

  // Run now, again on window load, and after two short delays.
  hideOtherLanguageToc();
  window.addEventListener('load', hideOtherLanguageToc);
  setTimeout(hideOtherLanguageToc, 50);
  setTimeout(hideOtherLanguageToc, 300);
});

37
docs/conf.py Normal file
View File

@@ -0,0 +1,37 @@
# Configuration file for the Sphinx documentation builder.
#
# Shared base settings. The per-language configs (docs/en/conf.py and
# docs/zh/conf.py) pull everything defined here in with `from conf import *`
# and then override the language-specific values.
# NOTE(review): `os` and `sys` are not referenced in this file as shown.
import os
import sys
# Define common settings here
# Project metadata.
project = 'LlamaFactory'
copyright = '2024, LlamaFactory Team'
author = 'LlamaFactory Team'
# Sphinx extensions; `myst_parser` is listed so Markdown sources can be built.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
'myst_parser',
]
templates_path = ['_templates']
# Paths ignored when Sphinx collects source files.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# HTML output: theme plus the language-switcher assets under _static/.
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
html_js_files = [
'js/switcher.js',
]
html_css_files = [
'css/lang-switcher.css',
]
# MyST (Markdown) syntax extensions.
myst_enable_extensions = [
"colon_fence",
"deflist",
]
# Heading depth for MyST auto-generated anchors — presumably levels 1-3;
# confirm against the MyST-Parser documentation.
myst_heading_anchors = 3

View File

@@ -0,0 +1,3 @@
# Custom Kernels
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Fused Operators
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Triton
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# DeepSpeed
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# FSDP
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Parallel (DP, TP, EP, SP, CP)
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# LoRA
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Quantization
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

20
docs/en/conf.py Normal file
View File

@@ -0,0 +1,20 @@
# Sphinx configuration for the English docs: the shared settings from the
# parent directory's conf.py plus English-specific overrides.
import os
import sys
# Add parent dir to path to allow importing conf.py
# NOTE(review): '..' is resolved against the current working directory, not
# this file's location — assumes the build evaluates this file with the cwd
# set to this directory; confirm.
sys.path.insert(0, os.path.abspath('..'))
from conf import *
# Language settings
language = 'en'
html_search_language = 'en'
# Static files
# Point to the root _static directory
html_static_path = ['../_static']
# Add custom JS for language switcher
# (same value as the shared conf.py sets; restated here explicitly)
html_js_files = [
'js/switcher.js',
]

View File

@@ -0,0 +1,3 @@
# Data Processing
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# DataEngine
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# ModelEngine
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Trainer
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Data Plugins
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Initialization
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Kernels
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Rendering
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Getting Started
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Data Argument
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Model Argument
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Sample Argument
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# Training Argument
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

62
docs/en/index.rst Normal file
View File

@@ -0,0 +1,62 @@
LlamaFactory Docs
=================
.. toctree::
:maxdepth: 1
:caption: Getting Started
getting-started
installation
llamaboard-web-ui
.. toctree::
:maxdepth: 1
:caption: Data Preparation
data-preparation/data-processing
.. toctree::
:maxdepth: 1
:caption: Training
training/sft
training/dpo
.. toctree::
:maxdepth: 1
:caption: Inference
inference/deploy
.. toctree::
:maxdepth: 1
:caption: Advanced
advanced/lora-and-quantization/lora
advanced/lora-and-quantization/quantization
advanced/distributed/fsdp
advanced/distributed/deepspeed
advanced/distributed/parallel-dp-tp-ep-sp-cp
advanced/custom-kernels/triton
advanced/custom-kernels/fused-operators
.. toctree::
:maxdepth: 1
:caption: Hyperparameters
hyperparameters/data-argument
hyperparameters/model-argument
hyperparameters/sample-argument
hyperparameters/training-argument
.. toctree::
:maxdepth: 1
:caption: Dev Guide
dev-guide/core/data-engine
dev-guide/core/model-engine
dev-guide/core/trainer
dev-guide/plugins/data-plugins
dev-guide/plugins/model-plugins/initialization
dev-guide/plugins/model-plugins/kernels
dev-guide/plugins/model-plugins/rendering

View File

@@ -0,0 +1,3 @@
# Deploy
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

3
docs/en/installation.md Normal file
View File

@@ -0,0 +1,3 @@
# Installation
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

View File

@@ -0,0 +1,3 @@
# LlamaBoard Web UI
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

3
docs/en/training/dpo.md Normal file
View File

@@ -0,0 +1,3 @@
# DPO
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

3
docs/en/training/sft.md Normal file
View File

@@ -0,0 +1,3 @@
# SFT
This page is not yet available in English. Use the language switcher to view Simplified Chinese.

35
docs/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
REM Run from the directory containing this script, whatever the caller's cwd.
pushd %~dp0
REM Command file for Sphinx documentation
REM Usage: make.bat ^<target^>   (no argument prints the Sphinx help)
REM SPHINXBUILD may be preset in the environment; it defaults to sphinx-build.
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
if "%1" == "" goto help
REM Probe the builder once; errorlevel 9009 means the command was not found.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to your PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
REM Forward the requested target to Sphinx "make mode" (-M).
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

3
docs/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
sphinx>=6.0.0
sphinx-rtd-theme>=1.2.0
myst-parser>=2.0.0

View File

@@ -0,0 +1,93 @@
# LLaMA-Factory Kernels 系统
## 概述
LLaMA-Factory Kernels 系统用于管理不同硬件设备提供的高性能计算内核(kernel)实现。该系统通过替换模型中的关键模块(如 RMSNorm、SwiGLU、RoPE、MoE 等)为硬件优化的版本,从而显著提升模型训练和推理的性能。
Kernels 系统采用基于注册表的自动发现机制,能够根据当前运行环境自动检测可用的硬件设备(NPU、CUDA 等),并使能相应的高性能 kernels。这种设计使得用户无需关心底层实现细节,只需简单调用接口即可获得性能提升。
## 核心特性
- **自动注册机制**:基于 `@register_kernel` 装饰器实现自动注册系统。系统启动时会自动扫描 `ops` 目录下的 kernel 实现,并将其注册到全局注册表中。
- **设备适配感知**:自动检测当前硬件设备(NPU、CUDA 等)并应用相应的优化。系统会跳过不支持的设备,确保在不同环境下都能正常工作。
- **模块化设计**:每个 kernel 独立实现,互不干扰。可以单独应用某个 kernel,也可以批量应用所有默认的 kernels。
- **后向兼容**:kernel 替换不修改模型权重,保持数值一致性。优化后的实现与原始实现保持精度一致(在浮点误差范围内)。
- **灵活扩展**:通过继承 `BaseKernel` 基类并使用装饰器,可以轻松添加新的 kernel 实现,支持新的硬件设备或优化算法。
## 使用方式
### 1. 通过训练 YAML 配置文件使用
要在训练过程中使能 kernels,只需在配置文件中增加如下配置,即可自动使能所有默认可用的 kernels:
```yaml
...
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
...
```
### 2. 调用 API 使能
#### 2.1 apply_default_kernels 使能所有默认 kernels
`apply_default_kernels` API 能够自动应用当前设备上所有默认注册的 kernels
```python
from transformers import AutoModelForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels import apply_default_kernels
# 加载模型
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
# 自动应用所有默认 kernels
model = apply_default_kernels(model, include_kernels="auto")
```
#### 2.2 apply_kernel 使能特定 kernel
如果需要更精细的控制,例如在某些场合单独应用某个 kernel,可以手动调用 `apply_kernel` 函数并传入 kernel ID:
```python
from transformers import AutoModelForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
# 加载模型
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
# 手动应用各个 kernels
# 注意kernel ID 必须与定义时的 _kernel_id 一致
model = apply_kernel("npu_fused_rope", model=model)
model = apply_kernel("npu_fused_rmsnorm", model=model)
model = apply_kernel("npu_fused_swiglu", model=model)
model = apply_kernel("npu_fused_moe", model=model)
```
### 3. 查询已注册的可用 kernels
可以通过 `get_default_kernels` 获取当前环境中所有已注册且可用的默认 kernel ID
```python
from llamafactory.v1.plugins.model_plugins.kernels import get_default_kernels
# 获取默认 kernel 列表
available_kernels = get_default_kernels()
print(f"Available kernels: {available_kernels}")
# 输出示例: ['npu_fused_rmsnorm', 'npu_fused_swiglu', 'npu_fused_rope', 'npu_fused_moe']
```
### 当前已实现的 kernels
| Kernel ID | 功能 | 支持的设备 | 备注 |
|-----------|------|-----------|------|
| [npu_fused_rmsnorm](./fused-operators.md#npufusedrmsnorm) | RMSNorm 融合算子 | NPU | NPU 设备的高性能 RMSNorm 实现 |
| [npu_fused_swiglu](./fused-operators.md#npufusedswiglu) | SwiGLU 融合算子 | NPU | NPU 设备的高性能 SwiGLU 实现 |
| [npu_fused_rope](./fused-operators.md#npufusedrope) | RoPE 融合算子 | NPU | NPU 设备的高性能 RoPE 实现 |
| [npu_fused_moe](./fused-operators.md#npufusedmoe) | MoE 融合算子 | NPU | MoE 融合算子,适配 Qwen3-MoE 等模型 |
我们会持续适配更多的 kernels。如果您需要自己开发新的 kernels,请参考我们的 [Kernel 开发文档](../../dev-guide/plugins/model-plugins/kernels.md),欢迎您向 LLaMA-Factory 贡献代码。

View File

@@ -0,0 +1,104 @@
# Fused Operators
LLaMA-Factory 提供了一系列针对特定硬件优化的融合算子。这些算子位于 `src/llamafactory/v1/plugins/model_plugins/kernels/ops` 目录下。
系统启动时,`scan_all_kernels` 函数会自动扫描该目录,注册所有可用的算子。您可以通过 `apply_default_kernels(model, include_kernels="auto")` 一键启用它们,或者使用 `apply_kernel` 单独启用。
以下是当前支持的融合算子详情:
## NpuFusedRMSNorm
RMSNorm(Root Mean Square Layer Normalization)是一种常用于大模型的归一化方法。在推理或训练中,RMSNorm 融合算子将 bias、residual 等操作进行融合,可以减少显存访问次数,加速计算。
Ascend npu 通过 `torch_npu.npu_rms_norm` 接口提供 RMSNorm 融合算子调用接口,支持 float16、bfloat16、float 等数据格式。RMSNorm 算子常见于 Qwen 等 LLM 模型中,由于 torch 侧没有提供 RMSNorm 算子的接口,因此在模型中通常是以自定义类的形式出现,通过替换 RMSNorm 类的 `forward` 方法即可使能。
```python
def _npu_rms_forward(self, hidden_states):
"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
hidden_states: Input hidden states tensor, same shape as the baseline.
Returns:
Normalized tensor consistent with the baseline RMSNorm behavior.
"""
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
```
在 LlamaFactory 中,通过 `NpuRMSNormKernel` 提供使能该融合算子的入口,只需要调用 `apply_kernel("npu_fused_rmsnorm", model=model)` 即可针对已适配的模型使能 npu RMSNorm 融合算子。
## NpuFusedSwiGlu
SwiGLU(Swish-Gated Linear Unit)是一种结合了 Swish 激活函数和门控线性单元(GLU)的混合激活函数,其主要功能是对输入张量进行门控线性变换,近年来被广泛应用于 LLM 模型中的 MLP 层。SwiGLU 融合算子将分割、激活、矩阵乘等多个操作融合为单一硬件指令,避免多次内核启动开销。
Ascend npu 通过 `torch_npu.npu_swiglu` 接口提供 SwiGLU 融合算子调用接口,支持 float16、bfloat16、float 等数据格式。SwiGLU 算子常见于 Qwen 等 LLM 模型中,由于 torch 侧没有提供 SwiGLU 算子的接口,因此在模型中通常是以自定义类的形式出现,通过替换 SwiGLU 类的 `forward` 方法即可使能。替换过程可参考如下示例:
```python
# 原始 MLP forward 方法:
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
# 替换后的 forward 方法:
def _npu_swiglu_forward(self, hidden_state):
return self.down_proj(
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
)
```
在 LLaMA-Factory 中,通过 `NpuSwiGluKernel` 提供使能该融合算子的入口,只需要调用 `apply_kernel("npu_fused_swiglu", model=model)` 即可针对已适配的模型使能 npu SwiGLU 融合算子。对于未适配的模型,如有需要,您可根据示例以及[开发者文档](../../dev-guide/plugins/model-plugins/kernels.md)自行适配。
## NpuFusedRoPE
RoPE(Rotary Positional Embedding,旋转式位置嵌入)是一种位置编码技术,广泛应用于 Qwen 等 LLM 模型中,用于有效编码文本序列的位置信息。它结合了绝对位置编码的稳定性与相对位置编码的灵活性,同时具备优秀的长度泛化能力。传统 RoPE 算子通常在 LLM 等模型结构中通过自定义函数的形式实现。RoPE 融合算子将原计算流程合并为单个硬件优化算子,从而提升性能。
Ascend npu 通过 `torch_npu.npu_rotary_mul` 提供 RoPE 融合算子调用接口,支持 float16、bfloat16、float32 等数据格式。以 Qwen3 系列模型为例,通过替换其 `apply_rotary_pos_emb` 函数即可实现 RoPE 融合算子使能:
```python
# 原始 apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# 替换 RoPE 融合算子后:
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
```
在 LLaMA-Factory 中,通过 `NpuRoPEKernel` 提供使能该融合算子的入口,只需要调用 `apply_kernel("npu_fused_rope", model=model)` 即可针对已适配的模型使能 npu RoPE 融合算子。对于未适配的模型,如有需要,您可根据示例以及[开发者文档](../../dev-guide/plugins/model-plugins/kernels.md)自行适配。
## NpuFusedMoE
MoE(Mixture of Experts)模型通过稀疏激活扩展容量。在原生 Transformers 实现中,使用串行循环遍历专家,导致内核启动开销大、硬件利用率低。
**MoE 融合算子** 利用 **GMM(Grouped Matrix Multiplication,分组矩阵乘)** 技术,支持在单个硬件指令内并行处理多组不同形状(行数不一)的矩阵乘法,消减循环开销,同时无需额外的显存复制,显著提升训练性能。
Ascend npu 通过 `torch_npu.npu_grouped_matmul` 等接口提供底层支持,通过替换模型中的 MoE Block forward 方法,即可利用 NPU 的分组矩阵乘能力。
核心逻辑替换如下(简化示意):
```python
def _npu_moe_forward(self, hidden_states, routing_weights, router_indices):
# 1. 排序:将乱序的 Token 按指派的专家归类,并生成索引映射
permuted_states, row_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices)
# 2. 统计:计算每个专家需要处理的 Token 数量
tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
# 3. 计算 (GMM):一次性并行计算所有专家的权重,自动适配不同专家的输入长度
inter_states = torch_npu.npu_grouped_matmul(permuted_states, self.gate_up_proj_weights, split_sizes=tokens_per_expert, ...)
inter_states = torch_npu.npu_swiglu(inter_states)
output = torch_npu.npu_grouped_matmul(inter_states, self.down_proj_weights, split_sizes=tokens_per_expert, ...)
# 4. 还原:将结果恢复成原始 Token 顺序并应用路由权重
return torch_npu.npu_moe_token_unpermute(output, row_map, routing_weights)
```
在 LLaMA-Factory 中,通过 `NpuFusedMoEKernel` 提供使能该融合算子的入口。只需要调用 `apply_kernel("npu_fused_moe", model=model)` 即可针对已适配的模型使能 NPU MoE 融合算子。对于未适配的模型,您也可以参考上述示例代码以及[开发者文档](../../dev-guide/plugins/model-plugins/kernels.md)自行适配。

View File

@@ -0,0 +1 @@
# Triton

View File

@@ -0,0 +1 @@
# DeepSpeed

View File

@@ -0,0 +1 @@
# FSDP

View File

@@ -0,0 +1 @@
# Parallel(DP, TP, EP, SP, CP)

View File

@@ -0,0 +1,3 @@
# LoRA
参数管理(二级参数形式)

View File

@@ -0,0 +1 @@
# Quantization

20
docs/zh/conf.py Normal file
View File

@@ -0,0 +1,20 @@
# Sphinx configuration for the Simplified Chinese docs: the shared settings
# from the parent directory's conf.py plus Chinese-specific overrides.
import os
import sys
# Add parent dir to path to allow importing conf.py
# NOTE(review): '..' is resolved against the current working directory, not
# this file's location — assumes the build evaluates this file with the cwd
# set to this directory; confirm.
sys.path.insert(0, os.path.abspath('..'))
from conf import *
# Language settings
language = 'zh_CN'
html_search_language = 'zh'
# Static files
# Point to the root _static directory
html_static_path = ['../_static']
# Add custom JS for language switcher
# (same value as the shared conf.py sets; restated here explicitly)
html_js_files = [
'js/switcher.js',
]

View File

@@ -0,0 +1,479 @@
# LLaMA-Factory v1 数据预处理
## 总览
LLaMA-Factory `v1` 采用了全新的数据处理架构,主要包含以下核心组件:
- **DataEngine**:数据引擎,负责数据集的加载、索引和转换等各种插件的接入和调用,并提供数据访问接口
- **DataConverterPlugin**:数据转换器,将非标准格式转换为统一的标准格式
- **DataLoaderPlugin**:数据加载插件,支持多种文件格式的加载
- **DataIndexPlugin**:数据索引插件,支持数据集的采样和权重调整
- **DataSelectorPlugin**:数据选择插件,支持灵活的数据访问方式
与 LLaMA-Factory `v0` 版本相比,`v1` 版本采用了统一的数据格式(Messages Format),所有数据都会被转换为标准的对话消息列表;此外,`v1` 版本通过 DataEngine 与 Plugin 机制,提供了自定义数据处理流的接口,具有更好的可扩展性和一致性。
---
## 目录
- [基本用法](#基本用法)
- [标准数据格式](#标准数据格式)
- [数据集配置文件](#数据集配置文件)
- [完整示例](#完整示例)
---
## 基本用法
### 在训练配置文件,可以通过如下方式配置数据集:
<details open>
<summary>方式 1:使用 HF Hub Repo ID</summary>
直接指定 HF Hub 上的数据集 Repo ID,DataEngine 会自动从 HF Hub 下载并加载数据集。
注:使用 Repo ID 直接加载的数据集需要为标准格式
**训练配置文件示例:**
```yaml
# example_sft.yaml
...
dataset: llamafactory/v1-sft-demo # HF Hub Repo ID
...
```
</details>
<details>
<summary>方式 2:使用 HF Hub 上的 YAML 配置文件</summary>
`dataset` 字段指定 HF Hub 上的 `dataset_info.yaml` 的 URI,DataEngine 会自动下载该配置文件并根据其中的配置加载数据集。
**训练配置文件示例:**
```yaml
# example_sft.yaml
...
dataset: llamafactory/v1-sft-demo/dataset_info.yaml # 远程 dataset_info.yaml 路径
...
```
</details>
<details>
<summary>方式 3:使用本地 HF 数据集文件路径</summary>
`dataset` 字段指定本地的数据集文件路径(`.json`、`.jsonl` 等)。
注:直接指定数据集文件路径,要求该数据文件的格式已为标准格式
**训练配置文件示例:**
```yaml
# example_sft.yaml
...
dataset: ~/data/v1_sft_demo.jsonl # 本地数据集文件绝对路径
...
```
</details>
<details>
<summary>方式 4:使用本地 YAML 配置文件路径</summary>
`dataset` 字段指定本地的 `dataset_info.yaml` 配置文件路径,DataEngine 会根据该配置加载其中的数据集。
**训练配置文件示例:**
```yaml
# example_sft.yaml
...
dataset: ~/data/dataset_info.yaml # 本地 dataset_info.yaml 文件路径
...
```
</details>
---
## 标准数据格式
v1 使用统一的 **Messages 格式**作为标准数据格式。每个样本都是一个包含 `messages` 字段的 JSON 对象。
针对 alpaca、sharegpt 以及 dpo 等格式的数据,可以通过内置的 `DataConverterPlugin` 插件自动将其转化为标准格式;对于其他自定义格式的数据,用户也可通过自定义 `DataConverterPlugin` 来实现数据格式标准化,这部分内容参见 [`DataConverterPlugin`](../dev-guide/plugins/data-plugins.md#data-converter-plugin)。
### 1. SFT监督微调样本格式
```json
{
"messages": [
{
"role": "system",
"content": [{"type": "text", "value": "You are a helpful assistant."}],
"loss_weight": 0.0
},
{
"role": "user",
"content": [{"type": "text", "value": "Hello, who are you?"}],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [{"type": "text", "value": "I am an AI assistant."}],
"loss_weight": 1.0
}
]
}
```
#### 字段说明:
- **messages**: 消息列表,包含一轮或多轮对话
- **role**: 消息角色,可选值:
- `"system"`: 系统提示
- `"user"`: 用户输入
- `"assistant"`: 模型回复
- **content**: 内容列表,每个元素包含:
- **type**: 内容类型,可选值:
- `"text"`: 文本内容
- `"image_url"`: 图像 URL多模态
- `"audio_url"`: 音频 URL多模态
- `"video_url"`: 视频 URL多模态
- `"tools"`: 工具描述
- `"tool_calls"`: 工具调用
- `"reasoning"`: 推理过程
- **value**: 具体内容(字符串)
- **loss_weight**: 损失权重(浮点数)
- `0.0`: 不计算损失(用于提示词部分)
- `1.0`: 完全计算损失(用于回复部分)
- 可设置为其他值以调整不同部分的学习权重
- **_dataset_name** (可选): 数据集名称,由 DataEngine 自动添加
- **extra_info** (可选): 额外信息字段
### 2. DPO偏好对齐样本格式
```json
{
"chosen_messages": [
{
"role": "user",
"content": [{"type": "text", "value": "用户提问"}],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [{"type": "text", "value": "更优的回答"}],
"loss_weight": 1.0
}
],
"rejected_messages": [
{
"role": "user",
"content": [{"type": "text", "value": "用户提问"}],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [{"type": "text", "value": "较差的回答"}],
"loss_weight": 1.0
}
]
}
```
### 3. 多模态支持
对于多模态数据,可以在 `content` 列表中添加非文本类型的内容:
```json
{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "value": "这张图片里有什么?"},
{"type": "image_url", "value": "path/to/image.jpg"}
],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [{"type": "text", "value": "图片中有一只猫。"}],
"loss_weight": 1.0
}
]
}
```
**说明**:`image_url`、`audio_url`、`video_url` 的路径可以是相对路径或绝对路径,具体加载方式由 `DataLoaderPlugin` 决定。
---
## 数据集配置文件
### 1. dataset_info.yaml 配置文件格式
`dataset_info.yaml` 支持同时配置多个数据集,支持分别从 HF Hub 和本地获取数据集,数据集默认会混合并打乱顺序。
**示例配置文件:`data/dataset_info.yaml`**
```yaml
# 数据集 1使用本地文件 + Alpaca 转换器
identity:
file_name: ~/data/identity.json #本地数据集文件绝对路径
converter: alpaca # 使用 alpaca 转换器
# 数据集 2指定自定义数据集目录
alpaca_en_demo:
file_name: ~/data/alpaca_en_demo.json # 数据集文件名
converter: alpaca # 转换器插件
size: 500 # 只使用 500 个样本
weight: 0.5 # 数据集权重,用于控制该数据集的采样频率
split: train # 数据集划分,默认为 train
streaming: false # 是否流式加载,默认为 false
# 数据集 3从 Hugging Face Hub 加载
hf_dataset:
hf_hub_url: llamafactory/v1-sft-demo # HF repo ID
streaming: false
# 数据集 4已经是标准格式无需转换器
standard:
file_name: ~/data/v1_sft_demo.jsonl # 本地标准数据集文件路径
# 数据集 5自定义数据集和 converter 插件
custom_dataset:
file_name: custom_data.json
converter: custom_converter
weight: 1.0
```
### 2. 配置字段说明
#### 数据源配置(二者必选其一):
- **hf_hub_url** (str): Hugging Face Hub 数据集仓库 ID
- 示例:`"llamafactory/v1-sft-demo"`
- 如果指定,则从 HF Hub 加载数据集
- **file_name** (str): 本地文件路径
- 支持格式:`.json`、`.jsonl`、`.csv`、`.parquet`、`.arrow`、`.txt`
#### 可选配置:
- **split** (str): 数据集划分,默认为 `"train"`
- **converter** (str): 数据转换器名称
- 可选值:`"alpaca"`(更多转换器持续添加中,也可在 data_plugin 中添加自定义 converter
- 如果不指定,则假定数据已是标准格式
- **size** (int): 使用的样本数量,默认使用全部
- **weight** (float): 数据集权重,用于混合数据集时的采样频率,默认为 1.0
- **streaming** (bool): 是否流式加载,默认为 `False`
---
## 完整示例
### 1. 基础使用示例
```python
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
# 使用本地标准格式数据集文件(也可传入 dataset_info.yaml 配置文件路径)
data_args = DataArguments(
dataset="~/data/v1_sft_demo.jsonl",
cutoff_len=2048
)
# 初始化 DataEngine
engine = DataEngine(data_args=data_args)
# 查看数据集信息
print(f"数据集总样本数: {len(engine)}")
print(f"数据集列表: {list(engine.datasets.keys())}")
# 访问数据样本
sample = engine[0]
print(f"样本格式: {sample.keys()}")
print(f"消息列表: {sample['messages']}")
# 批量访问
batch = engine[0:10]
print(f"批量样本数: {len(batch)}")
```
### 2. 输出示例
**查看数据集信息输出:**
```
数据集总样本数: 500
数据集列表: ['default']
样本格式: dict_keys(['_dataset_name', 'messages'])
消息列表: [{'role': 'user', 'content': [{'type': 'text', 'value': 'hi'}], 'loss_weight': 0.0}, {'role': 'assistant', 'content': [{'type': 'text', 'value': 'Hello! I am {{name}}, an AI assistant developed by {{author}}. How can I assist you today?'}], 'loss_weight': 1.0}]
批量样本数: 10
```
**访问单个样本输出:**
```python
{
'_dataset_name': 'alpaca_en_demo',
'messages': [
{
'role': 'user',
'content': [{'type': 'text', 'value': 'What is the capital of France?'}],
'loss_weight': 0.0
},
{
'role': 'assistant',
'content': [{'type': 'text', 'value': 'The capital of France is Paris.'}],
'loss_weight': 1.0
}
]
}
```
### 3. 混合多数据集配置文件示例
**配置文件:`data/mixed_datasets.yaml`**
```yaml
dataset_1:
file_name: alpaca_en_demo.json
converter: alpaca
weight: 1.0
dataset_2:
file_name: identity.json
converter: alpaca
weight: 2.0
dataset_3:
hf_hub_url: llamafactory/v1-sft-demo
weight: 1.5
```
### 4. 多模态数据示例
**数据文件:`data/multimodal_demo.jsonl`**
标准化后数据示例:
```json
[
{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "value": "Who are they?"},
{"type": "image_url", "value": "mllm_demo_data/1.jpg"}
],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [
{"type": "text", "value": "They're Kane and Gretzka from Bayern Munich."}
],
"loss_weight": 1.0
},
{
"role": "user",
"content": [
{"type": "text", "value": "What are they doing?"},
{"type": "image_url", "value": "mllm_demo_data/1.jpg"}
],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [
{"type": "text", "value": "They are celebrating on the soccer field."}
],
"loss_weight": 1.0
}
]
},
{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "value": "Who is he?"},
{"type": "image_url", "value": "mllm_demo_data/2.jpg"}
],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [
{"type": "text", "value": "He's Thomas Muller from Bayern Munich."}
],
"loss_weight": 1.0
},
{
"role": "user",
"content": [
{"type": "text", "value": "Why is he on the ground?"}
],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [
{"type": "text", "value": "Because he's sliding on his knees to celebrate."}
],
"loss_weight": 1.0
}
]
}
]
```
```python
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
data_args = DataArguments(dataset="data/multimodal_demo.jsonl")
engine = DataEngine(data_args=data_args)
# 访问多模态样本
sample = engine[0]
print("用户消息内容:")
for content_item in sample['messages'][0]['content']:
print(f" 类型: {content_item['type']}, 值: {content_item['value']}")
```
---
**注意事项**
1. 所有数据最终都会转换为标准的 Messages 格式
2. 通过 `converter` 插件可以支持多种数据格式
3. 通过 `weight` 和 `size` 参数可以灵活控制数据分布
4. 支持同时使用本地数据集和 HuggingFace Hub 数据集
5. 多模态数据通过在 `content` 中添加不同类型的元素来支持
6. 更多细节信息请参考我们的 [API REFERENCE](../dev-guide/core/data-engine.md#data-engine)

View File

@@ -0,0 +1,253 @@
# DataEngine
## 1. DataEngine 简介
`DataEngine` 是 LLaMA-Factory v1 数据处理的核心类,继承自 PyTorch 的 `Dataset`,负责各种插件的接入,其他功能(如数据格式转换、数据加载等)均通过插件的形式实现并接入 `DataEngine`
`DataEngine`接受一个唯一入参:`DataArguments` 实例,所有的元数据集信息均通过该参数配置传入。
## 2. DataEngine 与 DataArguments 接口定义
```python
@dataclass
class DataArguments:
""" `DataEngine`初始化入参
args:
dataset (str): 数据集路径,远程数据集 repo id / dataset_info.yaml 路径,或本地数据集路径/dataset_info.yaml路径
cutoff_len (int): 数据截止长度,即单条样本允许的最大长度(token 数)
"""
...
class DataEngine(Dataset):
"""数据引擎DataEngine
`DataEngine` 负责数据集的加载与统一管理,支持:
- 从本地路径或 Hugging Face Hub 加载数据
- 通过插件机制加载自定义数据
- 构建统一的数据索引
- 支持流式streaming与非流式数据访问
attr:
args (DataArguments): 数据参数配置
datasets (dict[str, HFDataset]): 数据集名称到数据对象的映射
dataset_infos (dict[str, DatasetInfo]): 数据集名称到元信息的映射
data_index (list[tuple[str, int]]): 数据索引列表,每项为 (dataset_name, sample_index)
streaming (bool): 是否为流式数据集
"""
def __init__(self, data_args: DataArguments) -> None:
"""初始化 `DataEngine`
初始化时自动执行以下步骤:
1. 调用 `get_dataset_info` 从 `data_args` 读取并解析数据集元信息
2. 调用 `load_dataset`,根据配置加载数据集
3. 调用 `build_data_index`,构建统一的索引列表
args:
data_args (DataArguments): 数据参数配置对象
"""
...
def get_dataset_info(self) -> None:
"""从配置文件或远程仓库加载数据集元信息
根据 `self.args.dataset` 确定数据源,数据源支持如下选项:
- 本地 YAML 配置文件路径
- Hugging Face Hub 上的 YAML 配置文件路径
- 本地数据集文件路径
- Hugging Face Hub 数据集 repo id
"""
...
def load_dataset(self) -> None:
"""根据数据集元信息加载所有数据集
每个数据集条目可以包含以下字段:
- `hf_hub_url`: 使用 `datasets.load_dataset` 加载
- 本地数据文件:通过 `DataLoaderPlugin` 插件加载
- `streaming`: 是否启用流式模式
更新:
self.datasets (dict): 数据集名称到已加载数据对象的映射
self.streaming (bool): 如果任一数据集为流式模式,则设置为 True
"""
...
def build_data_index(self) -> None:
"""构建统一的数据索引
为所有数据集创建全局索引列表 `(dataset_name, sample_index)`
当启用流式模式时,生成固定长度(例如 1000的占位索引
否则,为每条样本建立索引。
插件 `DataIndexPlugin` 可根据数据集大小或权重调整索引分布
"""
...
def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
"""将原始样本转换为统一格式
根据 `dataset_info` 中的 `converter` 字段,调用对应的转换插件,
将原始样本标准化为统一的数据结构。
args:
raw_sample (dict[str, Any]): 原始数据样本
dataset_name (str): 样本所属的数据集名称
return:
Sample: 转换后的标准化格式样本
"""
...
def __len__(self) -> int:
"""返回数据集的总样本数
return:
int: 数据集长度
如果为流式数据集,返回 `-1`
"""
...
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
"""根据索引或选择器获取样本
args:
index (Union[int, Any]): 数据索引int 或 list[int]
return:
Union[Sample, list[Sample]]: 单个样本或样本列表
"""
...
def __iter__(self) -> Iterable:
"""返回数据集迭代器
用于非流式数据集的顺序或随机访问
流式模式下需要实现异步加载逻辑
return:
Iterable: 数据集迭代器。
"""
...
async def __aiter__(self) -> AsyncIterable:
"""返回异步数据集迭代器
用于流式数据集或异步数据加载场景
允许在异步环境中以流的方式读取样本
return:
AsyncIterable: 异步迭代器,按顺序产出样本
"""
...
```
`DataArguments` 参数说明:
`dataset`: 数据集路径,支持本地或远程,当传入本地数据集文件路径时,需要满足该数据集为标准格式;否则需要传入 `dataset_info.yaml` 来配置数据集的 `converter` 等元信息,以告知 `DataEngine` 应当如何处理该数据。
`cutoff_len`: 数据的截止长度,即单条样本允许的最大长度(token 数),而非数据集的样本数量。
---
## 3. DataEngine 核心方法
### 3.1 `get_dataset_info`:加载数据元信息
根据 `dataset` 参数加载数据集配置,获取数据位置、数据格式、插件配置等所有数据元信息,在实例化 `DataEngine` 时会自动调用此方法。
### 3.2 加载数据集:`load_dataset`
遍历所有数据源,根据不同的数据源加载数据,在实例化 `DataEngine` 时会自动调用此方法。
```python
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
if "hf_hub_url" in value:
# 从 HF Hub 加载
dataset = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
else:
# 使用 DataLoaderPlugin 加载本地文件
dataset = DataLoaderPlugin(args=self.args).auto_load_data(value)
self.datasets[key] = dataset
```
### 3.3 `build_data_index`:构建数据索引
为每个数据集创建索引列表 `[(dataset_name, sample_index), ...]`, `DataIndexPlugin`插件在此处被调用,可控制各数据集的采样频率、采样方式等,在实例化`DataEngine`时会自动调用此方法。
```python
for dataset_name, dataset in self.datasets.items():
# 创建基础索引
data_index = [(dataset_name, idx) for idx in range(len(dataset))]
# 根据 size 和 weight 调整索引
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight:
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
```
### 3.4 `_convert_data_sample`:数据格式标准化
将原始数据转换为标准格式,`DataConverterPlugin` 插件在此处被调用,具体调用的插件由 `get_dataset_info` 方法获取的 `converter` 信息指定,若 `converter` 为空则假定数据集为标准格式,此方法由 `DataEngine` 的 `__getitem__` 方法调用。
```python
def _convert_data_sample(self, raw_sample: dict, dataset_name: str) -> Sample:
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
# 使用指定的转换器
from ..plugins.data_plugins.converter import get_converter
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
else:
# 已经是标准格式
return {"_dataset_name": dataset_name, **raw_sample}
```
---
## 4. 初始化
`DataEngine` 初始化过程只需传入一个构建好的 `DataArguments` 即可,后续可通过该 `DataEngine` 访问数据集中的数据。
```python
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
# 1. 创建数据参数
data_args = DataArguments(
dataset="~/data/v1_sft_demo.jsonl",
cutoff_len=2048
)
# 2. 初始化 Data Engine
data_engine = DataEngine(data_args=data_args)
# 3. 访问数据
sample = data_engine[0] # 获取第一个样本
```
## 5. 数据访问方式
实例化后的`DataEngine`支持整数索引、列表索引、以及切片等访问方式,其数据读取用法可等价于 Python 列表。
```python
sample = data_engine[0] # 获取第一个样本
sample = data_engine[0:10] # 获取前 10 个样本
sample = data_engine[[0, 5, 10]] # 获取指定索引的样本
```

View File

@@ -0,0 +1 @@
# ModelEngine

View File

@@ -0,0 +1 @@
# Trainer

View File

@@ -0,0 +1,467 @@
# Data Plugins
## 1. Data Plugins 简介
## DataConverterPlugin
### 1. DataConverterPlugin 简介
DataConverter 负责将非标准格式的数据集转换为 v1 的标准 Messages 格式。这使得用户可以继续使用现有的数据集(如 Alpaca 格式),而无需手动转换。针对自定义格式的数据集,用户也可以通过构建对应的自定义 DataConverter 插件,来负责其数据格式标准化。
当前,LLaMA-Factory 已内置了 `Alpaca Converter` 和 `Pair Converter`,这两类数据集可以直接使用对应的 converter 进行标准化,无需自定义转换器。
### 2. Alpaca Converter 详解
#### 2.1 Alpaca 格式
Alpaca 格式是一种常见的指令微调数据格式:
```json
{
"system": "You are a helpful assistant.",
"instruction": "Describe a process of making crepes.",
"input": "",
"output": "Making crepes is an easy and delicious process..."
}
```
#### 2.2 Alpaca Converter 接口定义
```python
class AlpacaSample(TypedDict, total=False):
"""Alpaca 格式数据样本结构
attr:
system (str, 可选): 系统提示信息system prompt用于设定对话背景或模型行为。
instruction (str, 可选): 用户指令user instruction通常为任务描述。
input (str, 可选): 额外的输入内容input text可与 instruction 拼接。
output (str, 可选): 模型生成的目标输出expected response
"""
...
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""将 Alpaca 样本转换为 SFTSupervised Fine-Tuning标准样本格式
`alpaca_converter` 将 Alpaca 数据集中一条样本转换为通用的 `SFTSample` 格式
该格式用于监督微调SFT或多轮对话建模
转换逻辑:
- 若存在 `system` 字段则生成一条系统消息loss_weight = 0.0
- 若存在 `instruction` 或 `input` 字段则合并为一条用户消息loss_weight = 0.0
- 若存在 `output` 字段则生成一条助手机器人回复消息loss_weight = 1.0
args:
raw_sample (AlpacaSample): 原始 Alpaca 数据样本
return:
SFTSample: 转换后的标准化样本,格式如下:
{
"messages": [
{"role": "system", "content": [{"type": "text", "value": "..."}], "loss_weight": 0.0},
{"role": "user", "content": [{"type": "text", "value": "..."}], "loss_weight": 0.0},
{"role": "assistant", "content": [{"type": "text", "value": "..."}], "loss_weight": 1.0},
]
}
example:
>>> raw = {"instruction": "请将以下句子翻译成英文:", "input": "你好", "output": "Hello"}
>>> alpaca_converter(raw)
{
"messages": [
{"role": "user", "content": [{"type": "text", "value": "请将以下句子翻译成英文:你好"}], "loss_weight": 0.0},
{"role": "assistant", "content": [{"type": "text", "value": "Hello"}], "loss_weight": 1.0}
]
}
"""
```
#### 2.3 转换过程
`alpaca_converter` 函数将 Alpaca 格式转换为标准格式,转换逻辑如下:
```python
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
messages = []
# 1. 添加系统提示词(如果存在)
if "system" in raw_sample:
messages.append({
"role": "system",
"content": [{"type": "text", "value": raw_sample["system"]}],
"loss_weight": 0.0
})
# 2. 添加用户输入instruction + input
if "instruction" in raw_sample or "input" in raw_sample:
user_content = raw_sample.get("instruction", "") + raw_sample.get("input", "")
messages.append({
"role": "user",
"content": [{"type": "text", "value": user_content}],
"loss_weight": 0.0
})
# 3. 添加模型回复
if "output" in raw_sample:
messages.append({
"role": "assistant",
"content": [{"type": "text", "value": raw_sample["output"]}],
"loss_weight": 1.0
})
return {"messages": messages}
```
#### 2.4 转换示例
**输入(Alpaca 格式):**
```json
{
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris."
}
```
**输出(标准格式):**
```json
{
"messages": [
{
"role": "user",
"content": [{"type": "text", "value": "What is the capital of France?"}],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [{"type": "text", "value": "The capital of France is Paris."}],
"loss_weight": 1.0
}
]
}
```
### 3. 自定义转换器
#### 3.1 创建自定义转换器
如果用户有自己的数据格式,可以轻松添加自定义转换器将其标准化,实现过程可参考如下示例:
```python
# src/llamafactory/v1/plugins/data_plugins/converter.py
from typing import TypedDict, NotRequired
from ...extras.types import SFTSample
# 1. 定义输入格式的类型
class MyCustomSample(TypedDict, total=False):
question: str
answer: str
context: NotRequired[str]
# 2. 实现转换逻辑
def custom_converter(raw_sample: MyCustomSample) -> SFTSample:
messages = []
# 构建用户消息
user_text = raw_sample["question"]
if "context" in raw_sample:
user_text = f"Context: {raw_sample['context']}\n\nQuestion: {user_text}"
messages.append({
"role": "user",
"content": [{"type": "text", "value": user_text}],
"loss_weight": 0.0
})
# 构建助手消息
messages.append({
"role": "assistant",
"content": [{"type": "text", "value": raw_sample["answer"]}],
"loss_weight": 1.0
})
return {"messages": messages}
# 3. 注册 custom_converter
#src/llamafactory/v1/plugins/data_plugins/converter.py: CONVERTERS
CONVERTERS = {
"alpaca": alpaca_converter,
"custom": custom_converter, # 添加自定义转换器
}
```
#### 3.2 使用自定义转换器
在 YAML 配置中指定转换器名称:
```yaml
my_dataset:
file_name: custom_data.json
converter: custom
```
---
## DataLoaderPlugin
### 1. DataLoaderPlugin 简介
`DataLoaderPlugin` 负责从本地文件加载数据集,当前支持如下文件格式:
- **JSON**: `.json`
- **JSONL**: `.jsonl`
- **CSV**: `.csv`
- **Parquet**: `.parquet`
- **Arrow**: `.arrow`
- **Text**: `.txt`
### 2. DataLoaderPlugin 接口定义
```python
@dataclass
class DataLoaderPlugin:
"""数据加载插件DataLoaderPlugin
负责根据数据集信息(`DatasetInfo`)自动加载本地或远程数据集。
支持多种文件格式(如 CSV、JSON、Parquet、Text、Arrow并可选择是否以流式方式加载。
通常由 `DataEngine` 调用,用于统一封装数据加载逻辑。
"""
args: DataArguments
"""数据参数对象,包含数据目录、缓存路径、分片等配置信息。"""
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""获取数据集文件格式
根据输入文件路径自动判断应使用的 HuggingFace `load_dataset` 构建器类型。
通过文件扩展名推断数据类型,例如 `.csv`、`.jsonl`、`.parquet`、`.txt` 等。
args:
path (str): 数据集文件路径,用于识别文件类型。
return:
Literal["arrow", "csv", "json", "parquet", "text"]:
数据构建器名称,用于 `datasets.load_dataset()`。
example:
>>> _get_builder_name("data/train.jsonl")
"json"
"""
...
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
"""根据传入的 `dataset_info` 自动选择合适的加载方式
args:
dataset_info (DatasetInfo): 数据集元信息,通常包含:
- `file_name`: 数据文件路径
- `split`: 数据划分(如 "train"、"test"
- `streaming`: 是否启用流式加载
return:
HFDataset: 加载完成的 Hugging Face 数据集对象。
example:
>>> plugin = DataLoaderPlugin(args)
>>> ds = plugin.auto_load_data({"file_name": "~/data.json", "split": "train"})
"""
...
def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
"""从文件或目录加载数据集
根据输入路径自动识别文件类型CSV、JSON、Parquet、Text 等),
并通过 `datasets.load_dataset()` 加载数据集。
若 `streaming=True`,则将结果转换为迭代式数据集。
args:
filepath (str): 文件路径或目录路径。
split (str): 数据划分名称(如 "train"、"validation")。
streaming (bool): 是否启用流式加载模式。
return:
HFDataset: 加载后的数据集对象。
example:
>>> plugin.load_data_from_file("data/train.json", "train", False)
"""
...
```
---
## DataIndexPlugin
### 1. DataIndexPlugin 简介
`DataIndexPlugin` 负责调整数据索引,支持通过配置 `size`, `weight` 等参数控制数据集样本数量和采样频率。
- 使用 `size` 参数 限制使用的样本数量:
```yaml
my_dataset:
file_name: large_dataset.json
size: 1000 # 只使用前 1000 个样本
```
- 使用 `weight` 参数调整数据集在混合数据中的采样频率:
```yaml
dataset_a:
file_name: data_a.json
weight: 1.0
dataset_b:
file_name: data_b.json
weight: 2.0 # dataset_b 的样本出现频率是 dataset_a 的 2 倍
```
**说明**:`weight` 参数适用于在多个数据集混合训练时,调整不同数据集的采样频率
- 当 `weight=1.0` 时,数据集按原始比例采样
- 当 `weight=2.0` 时,该数据集的索引会复制 2 倍,使其样本出现频率翻倍
### 2. DataIndexPlugin 接口定义
```python
@dataclass
class DataIndexPlugin:
"""数据索引插件DataIndexPlugin
根据 `size` 和 `weight` 调整数据索引列表,控制数据集的样本数量和采样频率
通常在多数据集混合训练时使用,以控制不同数据集在总体样本中的占比。
在 `DataEngine.build_data_index` 中被自动调用,用于实现样本重采样或加权分布。
"""
def adjust_data_index(
self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
) -> list[tuple[str, int]]:
"""调整数据索引列表
根据 `size` 或 `weight` 参数对输入的数据索引进行采样、扩展或缩减。
若两个参数同时存在,将依次执行基于大小和基于权重的调整。
args:
data_index (list[tuple[str, int]]):
数据索引列表,每个元素为 `(dataset_name, sample_index)`。
size (Optional[int]):
目标样本数量,若指定则根据该数量裁剪或重复样本。
weight (Optional[float]):
数据集权重,用于控制数据集在混合训练中的采样比例。
return:
list[tuple[str, int]]:
调整后的数据索引列表。
example:
>>> plugin = DataIndexPlugin()
>>> adjusted = plugin.adjust_data_index([("ds1", i) for i in range(100)], size=50, weight=None)
>>> len(adjusted)
50
"""
...
def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
"""根据目标大小调整数据索引
通过裁剪或重复样本,使索引总数等于 `size`。
常用于统一不同数据集的样本数量。
args:
data_index (list[tuple[str, int]]):
原始数据索引列表。
size (int):
目标样本数量。
return:
list[tuple[str, int]]:
调整后长度等于 `size` 的数据索引列表。
example:
>>> plugin.adjust_by_size([("ds1", i) for i in range(10)], 20)
"""
...
def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
"""根据权重调整数据索引
通过加权采样或重复样本,使数据集样本出现频率符合指定权重。
常用于多数据源训练中按比例平衡样本。
args:
data_index (list[tuple[str, int]]):
原始数据索引列表。
weight (float):
数据集权重(相对比例,可与其他数据集共同归一化)。
return:
list[tuple[str, int]]:
调整后的加权数据索引列表。
example:
>>> plugin.adjust_by_weight([("ds1", i) for i in range(10)], 0.5)
"""
...
```
---
## DataSelectorPlugin
### 1. DataSelectorPlugin 简介
`DataSelectorPlugin` 为 `DataEngine` 提供基于索引访问数据的功能,由 `DataEngine` 的 `__getitem__` 方法自动调用。
### 2. DataSelectorPlugin 接口定义
```python
@dataclass
class DataSelectorPlugin:
"""根据索引选择数据集样本。
配合 `DataEngine` 使用,通过统一的 `data_index` 结构(包含数据集名与样本索引)来实现灵活的数据选择
"""
data_index: list[tuple[str, int]]
"""数据索引列表,每个元素为 (dataset_name, sample_index)。"""
def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
"""选择数据集样本
根据输入类型从 `data_index` 中选择对应的样本索引
支持三种索引方式:
- 切片slice返回对应范围内的样本
- 索引列表list[int]):返回指定索引处的多个样本
- 其他类型输入将触发异常。
args:
index (Union[slice, list[int], Any]): 数据样本索引
可以是切片(`slice`)或索引列表
return:
Union[tuple[str, int], list[tuple[str, int]]]:
- 若为单个索引:返回一个 `(dataset_name, sample_index)`
- 若为多个索引或切片:返回多个样本的列表
raise:
ValueError: 当输入索引类型不受支持时抛出。
"""
...
```

View File

@@ -0,0 +1,197 @@
# Kernels plugins
## 概览
LLaMA-Factory 通过 Kernels plugins 系统依据不同硬件设备提供高性能计算内核(kernel)实现。该系统通过注册表机制管理所有 kernel,通过 `@register_kernel` 装饰器实现 kernel 定义后自动注册,由 `apply_kernel` 方法来使能指定的 kernel,`apply_default_kernels` 可使能注册表中当前环境所有可用的默认 kernels。
## 架构设计
### 核心组件
#### 1. Registry注册表
`Registry` 是一个用于管理所有 kernel 实现的静态类。它维护一个字典结构:`{kernel_id: KernelClass}`
```python
# 注册表结构示例
{
"npu_fused_rmsnorm": NpuRMSNormKernel,
"npu_fused_swiglu": NpuSwiGluKernel,
...
}
```
#### 2. register_kernel (装饰器)
`@register_kernel` 是 `Registry.register` 的别名。所有 kernel 类均应使用该装饰器进行注册。
**注册机制**
- 装饰器检查类是否继承自 `BaseKernel`
- 检查类是否定义了 `_kernel_id` 和 `_device` 属性。
- 检查 `_device` 是否与当前运行环境的加速器类型匹配。如果不匹配,则跳过注册。
- 如果一切符合要求,将 kernel 类注册到全局注册表中。
#### 3. BaseKernel基类
所有 kernel 的实现都必须继承自 `BaseKernel` 抽象基类。`BaseKernel` 定义了 kernel 的基本属性和接口。
#### 4. 标识系统
**Kernel ID** (`_kernel_id`)
每个 kernel 必须拥有一个唯一的字符串标识符,例如 `"npu_fused_rmsnorm"`
**Device Type** (`_device`)
kernel 必须声明其支持的设备类型,例如 `DeviceType.NPU` 或 `DeviceType.CUDA`。
## Kernel 系统 API 设计
### **Registry**:全局 kernel 注册表
`Registry` 类提供了注册和获取 kernel 的接口:
```python
class Registry:
@classmethod
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
"""注册一个 kernel 类"""
...
@classmethod
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
"""根据 ID 获取 kernel 类"""
...
```
### **BaseKernel**
`BaseKernel` 定义了所有 kernel 必须实现的协议:
- `_kernel_id`: 类属性,kernel 的唯一标识符。
- `_device`: 类属性,kernel 支持的设备类型。
- `check_deps()`: 类方法,检查 kernel 的依赖项是否满足(如 `torch_npu` 是否安装)。
- `apply(**kwargs)`: 抽象类方法,实现 kernel 的具体应用逻辑。
```python
class BaseKernel(ABC):
_kernel_id: Any = ""
_device: DeviceType = DeviceType.CPU
@classmethod
def check_deps(cls) -> bool:
"""检查依赖项"""
...
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
"""应用 kernel 到模型"""
...
```
### **scan_all_kernels**
`scan_all_kernels` 函数会自动扫描 `ops` 目录下的所有 `.py` 文件并导入它们,从而触发 `@register_kernel` 装饰器完成自动注册。
### **apply_kernel**
对模型使能指定的 kernel。
```python
def apply_kernel(kernel_id: str, **kwargs) -> HFModel:
"""应用指定的 kernel 到模型
Args:
kernel_id: 目标 kernel 的 ID
**kwargs: 传递给 kernel.apply 的参数,通常包含 model
"""
```
**用法示例**
```python
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
model = apply_kernel("npu_fused_rmsnorm", model=model)
```
### **apply_default_kernels**
对模型使能所有默认注册的 kernel。这是一个高级 API,通常在模型加载流程中自动调用。
```python
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
"""应用所有默认 kernel
Args:
model: HFModel 实例
include_kernels: 包含的 kernel ID 列表(逗号分隔字符串),或者 "auto"/True 表示全部
"""
```
## 扩展 Kernels
如果用户有针对特定模型或者设备的 kernel,可以按照下述步骤去实现并接入 LLaMA-Factory。
### 创建新 Kernel 的步骤
#### 1. 创建 Kernel 实现文件
在 `src/llamafactory/v1/plugins/model_plugins/kernels/ops` 下的相应子目录中创建新的 kernel 实现文件,例如 `mlp/cuda_swiglu.py`:
```python
import torch
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
# 实现具体的 kernel 函数
def _cuda_swiglu_forward(self, hidden_state):
# ... CUDA 优化实现 ...
pass
@register_kernel
class CudaSwiGluKernel(BaseKernel):
_kernel_id = "cuda_fused_swiglu"
_device = DeviceType.CUDA
@classmethod
def apply(cls, **kwargs) -> HFModel:
model = kwargs.get("model")
if model is None:
raise ValueError("model is required")
if not cls.check_deps():
raise RuntimeError("Dependencies not met")
# 遍历模型并替换 forward 方法
for name, module in model.named_modules():
# ... 匹配和替换逻辑 ...
pass
return model
```
#### 2. 自动发现
由于 `scan_all_kernels` 会自动扫描 `ops` 目录,只要文件位于该目录下且没有语法错误,系统启动时会自动导入并注册,无需手动修改注册表代码。
#### 3. 测试 Kernel
创建测试用例验证 kernel 的正确性:
```python
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
# ... 加载模型 ...
model = apply_kernel("cuda_fused_swiglu", model=model)
# ... 验证 forward 是否被替换 ...
```
## 异常处理
### 依赖不可用
`BaseKernel.check_deps()` 默认会检查当前设备类型是否匹配。子类可以重写此方法以添加额外的依赖检查(如检查特定的库是否安装)。如果 `check_deps()` 返回 `False`,`apply()` 方法应当抛出异常或进行相应处理。
### Kernel ID 未找到
如果调用 `apply_kernel` 时传入了不存在的 `kernel_id`,会抛出 `ValueError`

View File

@@ -0,0 +1,71 @@
# Getting Started
## 训练方法
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|:---------------------:| ------------------ | ------------------ | ------------------ | ------------------ |
| 指令监督微调 | :white_check_mark: | | | |
| 奖励模型训练 | | | | |
| DPO 训练 | | | | |
## 软件依赖
| 必需项 | 至少 | 推荐 |
|:---------------------:|--------|--------|
| python | 3.11 | 3.12 |
| torch | 2.7.1 | 2.7.1 |
| torch-npu(Ascend NPU) | 2.7.1 | 2.7.1 |
| torchvision | 0.22.1 | 0.22.1 |
| transformers | 5.0.0 | 5.0.0 |
| datasets | 3.2.0 | 4.0.0 |
| peft | 0.18.1 | 0.18.1 |
| 可选项 | 至少 | 推荐 |
|:----------------:|--------|--------|
| CUDA(NVIDIA GPU) | 11.6 | 12.2 |
| deepspeed | 0.18.4 | 0.18.4 |
| flash-attn(NVIDIA GPU) | 2.5.6 | 2.7.2 |
## 如何使用
### 安装 LLaMA Factory
> [!IMPORTANT]
> 此步骤为必需。
#### 从源码安装
```bash
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
cd LlamaFactory
pip install -e .
```
### 数据准备
关于数据集文件的格式,请参考 [data-preparation/README.md](data-preparation/README.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
> [!NOTE]
> 使用自定义数据集或自定义数据集格式时,请参照 [data-preparation/README.md](data-preparation/README.md) 进行配置,如有必要,请重新实现自定义数据集的数据处理逻辑,包括对应的`converter`。
您也可以使用 **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**、**[DataFlow](https://github.com/OpenDCAI/DataFlow)** 和 **[GraphGen](https://github.com/open-sciencelab/GraphGen)** 构建用于微调的合成数据。
### 快速开始
下面的命令展示了对 Qwen3-0.6B 模型使用 FSDP2 进行全参**微调**,其中后两行 `llamafactory-cli` 命令等价。
```bash
export USE_V1=1
llamafactory-cli sft examples/v1/train_full/train_full_fsdp2.yaml
llamafactory-cli train examples/v1/train_full/train_full_fsdp2.yaml
```
高级用法请参考 [advanced](./advanced/README.md),包括多卡多机微调、分布式、LoRA、量化以及各种加速特性等。

View File

@@ -0,0 +1 @@
# Data Argument

62
docs/zh/index.rst Normal file
View File

@@ -0,0 +1,62 @@
LlamaFactory 文档
=================
.. toctree::
:maxdepth: 1
:caption: Getting Started
getting-started
installation
llamaboard-web-ui
.. toctree::
:maxdepth: 1
:caption: Data Preparation
data-preparation/data-processing
.. toctree::
:maxdepth: 1
:caption: Training
training/sft
training/dpo
.. toctree::
:maxdepth: 1
:caption: Inference
inference/deploy
.. toctree::
:maxdepth: 1
:caption: Advanced
advanced/lora-and-quantization/lora
advanced/lora-and-quantization/quantization
advanced/distributed/fsdp
advanced/distributed/deepspeed
advanced/distributed/parallel-dp-tp-ep-sp-cp
advanced/custom-kernels/triton
advanced/custom-kernels/fused-operators
.. toctree::
:maxdepth: 1
:caption: Hyperparameters
hyperparameters/data-argument
hyperparameters/model-argument
hyperparameters/sample-argument
hyperparameters/training-argument
.. toctree::
:maxdepth: 1
:caption: Dev Guide
dev-guide/core/data-engine
dev-guide/core/model-engine
dev-guide/core/trainer
dev-guide/plugins/data-plugins
dev-guide/plugins/model-plugins/initialization
dev-guide/plugins/model-plugins/kernels
dev-guide/plugins/model-plugins/rendering

View File

@@ -0,0 +1 @@
# Deploy

1
docs/zh/installation.md Normal file
View File

@@ -0,0 +1 @@
# Installation

View File

@@ -0,0 +1 @@
# LlamaBoard Web UI

1
docs/zh/training/dpo.md Normal file
View File

@@ -0,0 +1 @@
# DPO

1
docs/zh/training/sft.md Normal file
View File

@@ -0,0 +1 @@
# SFT

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: models/Llama-2-7b
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z0_config.json
use_asft_loss: true
asft_alpha: 0.1
### dataset
dataset: med
template: llama2
cutoff_len: 2048
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama2-7b/full/asft2
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 2.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: models/Qwen2.5-7B
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z0_config.json
use_asft_loss: true
asft_alpha: 0.05
### dataset
dataset: math
template: qwen
cutoff_len: 2048
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2-7b/full/asft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,38 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# Freeze Configuration
peft_config:
name: freeze
freeze_trainable_layers: 2 # Train the last 2 layers
freeze_trainable_modules: all # In these layers, train specific modules
freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_freeze
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 2.0e-5
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,25 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto
dist_config:
name: deepspeed
config_file: examples/deepspeed/ds_z3_config.json
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

View File

@@ -0,0 +1,7 @@
model: Qwen/Qwen3-4B
peft_config:
name: lora
adapter_name_or_path: ./outputs/test_lora
export_dir: ./merge_lora_model
export_size: 5
infer_dtype: auto

View File

@@ -0,0 +1,39 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
name: lora
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: all
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,43 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
name: lora
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: all
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
# Quantization Config
quant_config:
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
quantization_bit: 4 # choice: 8/4(bnb)
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_quantization
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -1 +1 @@
liger-kernel>=0.5.5
liger-kernel>=0.6.3

View File

@@ -213,6 +213,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
and getattr(self.model.config, "model_type", None)
in [
"glm4v",
"glm_ocr",
"Keye",
"qwen2_vl",
"qwen2_5_vl",

View File

@@ -459,6 +459,18 @@ class ReasoningTemplate(Template):
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
@dataclass
class Glm47ReasoningTemplate(ReasoningTemplate):
    r"""Reasoning template variant for GLM-4.7.

    An empty thought is rendered as the closing thought word alone; a
    non-empty thought is wrapped in the opening and closing thought words.
    """

    @override
    def add_thought(self, content: str = "") -> str:
        # Non-empty thought: opening word + content + closing word.
        if content:
            return self.thought_words[0] + content + self.thought_words[1]

        # Empty thought: only the closing word is emitted.
        return self.thought_words[1]
TEMPLATES: dict[str, "Template"] = {}
@@ -1049,6 +1061,39 @@ register_template(
)
# copied from glm4 template
register_template(
name="glm_ocr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
)
# copied from glm4_moe template
register_template(
name="glm4_7",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4_moe"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
thought_words=("<think>", "</think>"),
efficient_eos=True,
template_class=Glm47ReasoningTemplate,
)
# copied from glm4 template
register_template(
name="glmz1",

View File

@@ -65,6 +65,7 @@ MCA_SUPPORTED_MODELS = {
"qwen2_vl",
"qwen2_5_vl",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3",
"qwen3_moe",
"qwen3_next",
@@ -939,6 +940,29 @@ register_model_group(
)
register_model_group(
models={
"GLM-4.7-Flash": {
DownloadSource.DEFAULT: "zai-org/GLM-4.7-Flash",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.7-Flash",
},
},
template="glm4_7",
)
register_model_group(
models={
"GLM-OCR": {
DownloadSource.DEFAULT: "zai-org/GLM-OCR",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-OCR",
},
},
template="glm_ocr",
multimodal=True,
)
register_model_group(
models={
"GLM-Z1-0414-9B-Chat": {

View File

@@ -490,6 +490,14 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether to use the DFT loss."},
)
use_asft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the ASFT loss."},
)
asft_alpha: float = field(
default=0.1,
metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
)
use_eaft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the EAFT loss."},

View File

@@ -77,6 +77,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
elif model_type == "qwen3_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
elif model_type == "qwen3_next":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel

View File

@@ -77,6 +77,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Glm4MoeMoE])
if model_type == "glm4_moe_lite":
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import Glm4MoeLiteMoE
_set_z3_leaf_modules(model, [Glm4MoeLiteMoE])
if model_type == "glm4v_moe":
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextMoE
@@ -137,6 +142,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
if model_type == "qwen3_next":
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:

View File

@@ -239,6 +239,15 @@ _register_composite_model(
)
_register_composite_model(
model_type="glm_ocr",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="internvl",
)

View File

@@ -82,9 +82,34 @@ def _check_model_support(model_args: "ModelArguments"):
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if config.model_type not in MCA_SUPPORTED_MODELS:
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
raise ValueError(
f"Model {config.model_type} is not supported by mcore_adapter."
"You can try to upgrade mcore_adapter to the latest version for more supported models."
)
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
return
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
params_to_freeze.extend(["vision_model.pos_embed"])
if finetuning_args.freeze_multi_modal_projector:
params_to_freeze.extend(["multi_modal_projector"])
if finetuning_args.freeze_language_model:
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
if params_to_freeze:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in params_to_freeze):
p.requires_grad_(False)
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -161,22 +186,8 @@ def run_sft(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
# optional freezing for qwen2_vl, qwen2_5_vl
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
if finetuning_args.freeze_multi_modal_projector:
params_to_freeze.extend(["multi_modal_projector"])
if finetuning_args.freeze_language_model:
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
if params_to_freeze:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in params_to_freeze):
p.requires_grad_(False)
# optional freezing for qwen_vl series
_freeze_model_parameters(model, finetuning_args)
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
data_collator = SFTDataCollatorWith4DAttentionMask(
@@ -229,6 +240,8 @@ def run_dpo(
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
_freeze_model_parameters(model, finetuning_args)
if finetuning_args.use_ref_model:
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
ref_model = AutoModel.from_config(ref_config)

View File

@@ -17,6 +17,7 @@
import json
import os
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
processor: Optional["ProcessorMixin"],
model_args: Optional["ModelArguments"] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
ref_model: Optional["torch.nn.Module"] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
self.ref_model = ref_model
if ref_model is not None:
from trl.models.utils import prepare_deepspeed, prepare_fsdp
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
if self.accelerator.is_fsdp2:
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
else:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
if finetuning_args.use_dft_loss:
from ..trainer_utils import dft_loss_func
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
elif finetuning_args.use_asft_loss:
from ..trainer_utils import asft_loss_func
self.compute_loss_func = partial(
asft_loss_func,
asft_alpha=finetuning_args.asft_alpha,
)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
@override
def compute_loss(self, model, inputs, *args, **kwargs):
return super().compute_loss(model, inputs, *args, **kwargs)
if self.finetuning_args.use_asft_loss:
with torch.no_grad():
ref_outputs = self.ref_model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
ref_logits = ref_outputs.logits
outputs = model(**inputs)
return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
else:
return super().compute_loss(model, inputs, *args, **kwargs)
@override
def prediction_step(

View File

@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
@@ -52,6 +52,10 @@ def run_sft(
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
ref_model = None
if finetuning_args.use_asft_loss:
ref_model = create_ref_model(model_args, finetuning_args)
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
@@ -124,6 +128,7 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,

View File

@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
import torch.nn.functional as F
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
@@ -681,6 +682,88 @@ def _dft_cross_entropy(
return loss
def asft_loss_func(
    outputs,
    labels: torch.Tensor,
    ref_logits: torch.Tensor,
    asft_alpha: float = 0.1,
    ignore_index: int = -100,
) -> torch.Tensor:
    r"""Compute the ASFT loss from policy outputs and reference-model logits.

    The policy logits, the labels and the reference logits are shifted for
    next-token prediction, flattened to ``[N, vocab]`` / ``[N]`` and handed to
    ``_asft_cross_entropy``.

    Args:
        outputs: Mapping-like model output; read through ``.get("logits")`` and
            ``.get("loss")``.
        labels: Label ids; positions equal to ``ignore_index`` are excluded.
        ref_logits: Logits of the reference model for the same inputs.
        asft_alpha: Weight of the KL term.
        ignore_index: Label value marking positions to ignore.

    Returns:
        The scalar loss. If ``outputs`` carries no logits, its ``"loss"`` entry
        is returned instead (or a zero tensor if that is missing too).
    """
    logits = outputs.get("logits")
    if logits is None:
        return outputs.get("loss", torch.tensor(0.0))

    logits = logits.float()
    # Bring the reference logits to the same device and precision as the policy
    # logits so the KL term is not computed in mixed dtypes / across devices.
    ref_logits = ref_logits.to(device=logits.device, dtype=logits.dtype)

    # shift for causal LM
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_ref_logits = ref_logits[..., :-1, :].contiguous()

    vocab_size = shift_logits.size(-1)

    # flatten
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1).to(shift_logits.device)

    return _asft_cross_entropy(
        policy_logits=shift_logits,
        policy_labels=shift_labels,
        ref_logits=shift_ref_logits,
        asft_alpha=asft_alpha,
        ignore_index=ignore_index,
    )
def _asft_cross_entropy(
    policy_logits: torch.Tensor,
    policy_labels: torch.Tensor,
    ref_logits: torch.Tensor,
    asft_alpha: float = 0.1,
    ignore_index: int = -100,
) -> torch.Tensor:
    r"""Return the DFT loss plus ``asft_alpha`` times the KL term against the reference logits."""
    dft_term = _dft_cross_entropy(policy_logits, policy_labels, ignore_index=ignore_index)
    kl_term = _kl_divergence(policy_logits, ref_logits, policy_labels, ignore_index=ignore_index)
    return dft_term + asft_alpha * kl_term
def _kl_divergence(
policy_logits: torch.Tensor,
ref_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
) -> torch.Tensor:
# log p(y|x)
log_p = F.log_softmax(policy_logits, dim=-1)
# q(y|x)
q = F.softmax(ref_logits, dim=-1)
# token-wise KL
kl = F.kl_div(
log_p,
q,
reduction="none",
).sum(dim=-1) # [N]
# mask padding tokens
mask = (labels != ignore_index).float()
return (kl * mask).sum() / mask.sum()
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",

View File

@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
model = model.to(output_dtype)
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
model.save_pretrained(
save_directory=model_args.export_dir,
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
# Prepare save arguments (safe_serialization removed in transformers v5.0.0)
save_kwargs = {
"save_directory": model_args.export_dir,
"max_shard_size": f"{model_args.export_size}GB",
}
if not is_transformers_version_greater_than("5.0.0"):
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
model.save_pretrained(**save_kwargs)
if model_args.export_hub_model_id is not None:
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
push_kwargs = {
"max_shard_size": f"{model_args.export_size}GB",
}
if not is_transformers_version_greater_than("5.0.0"):
push_kwargs["safe_serialization"] = not model_args.export_legacy_format
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
**push_kwargs,
)
if finetuning_args.stage == "rm":

View File

@@ -76,19 +76,28 @@ class BaseTrainer:
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
self._accelerate_engine = None
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
if shard_need_optimizer:
if dist_name == "deepspeed":
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self._deepspeed_engine = DistributedPlugin("deepspeed")(
self.model,
self.args.dist_config,
num_micro_batch=self.train_batch_generator.num_micro_batch,
micro_batch_size=self.args.micro_batch_size,
)
self._init_optimizer()
self._shard_model()
self._init_lr_scheduler()
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
self.model, self.optimizer, self.lr_scheduler
)
else:
# fsdp2 / DDP / no dist
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
@@ -171,25 +180,35 @@ class BaseTrainer:
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
num_micro = len(micro_batches)
for i, micro_batch in enumerate(micro_batches):
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
loss.backward()
if self._deepspeed_engine is not None:
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
self._deepspeed_engine.backward(loss)
else:
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
if self._deepspeed_engine is not None:
# deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm()
else:
self.optimizer.step()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
@@ -203,7 +222,14 @@ class BaseTrainer:
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
DistributedPlugin(self.args.dist_config.name).save_model(
self.model, self.args.output_dir, self.renderer.processor
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -90,6 +90,26 @@ class ModelEngine:
Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
init_kwargs = {"device_map": init_device}
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
init_kwargs=init_kwargs,
config=self.model_config,
tokenizer=self.processor,
model_args=self.args,
is_trainable=self.is_train,
)
if self.args.model_class == ModelClass.LLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
@@ -107,14 +127,8 @@ class ModelEngine:
AutoClass = AutoModel
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
assert self.args.quant_config is None, "Quantization is not supported with meta device."
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
@@ -122,8 +136,8 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
device_map=init_device,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)
if self.args.peft_config is None:

View File

@@ -125,6 +125,11 @@ def launch():
run_chat()
elif command == "merge":
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
merge_and_export_model()
elif command == "env":
raise NotImplementedError("Environment information is not implemented yet.")

View File

@@ -12,14 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
import re
from typing import Literal, TypedDict, Union
from peft import LoraConfig, PeftModel, get_peft_model
import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from ...config import InputArgument, get_args
from ...core.model_engine import ModelEngine
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
logger = logging.get_logger(__name__)
class LoraConfigDict(TypedDict, total=False):
name: Literal["lora"]
"""Plugin name."""
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
"""Lora rank."""
lora_alpha: int
"""Lora alpha."""
target_modules: list[str]
lora_dropout: float
"""Lora dropout."""
target_modules: Union[list[str], str]
"""Target modules."""
use_rslora: bool
"""Use RS-LoRA."""
use_dora: bool
"""Use DoRA."""
modules_to_save: list[str]
"""Modules to save."""
adapter_name_or_path: Union[list[str], str]
"""Path to the adapter(s)."""
export_dir: str
"""Path to the export directory."""
export_size: int
"""Shard size for the export model."""
export_hub_model_id: str
"""Hub model ID for the export model."""
infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
"""Inference data type for the export model."""
export_legacy_format: bool
"""Use legacy format for the export model."""
class FreezeConfigDict(TypedDict, total=False):
@@ -36,22 +64,280 @@ class FreezeConfigDict(TypedDict, total=False):
"""Plugin name."""
freeze_trainable_layers: int
"""Freeze trainable layers."""
freeze_trainable_modules: list[str] | None
freeze_trainable_modules: Union[list[str], str]
"""Freeze trainable modules."""
freeze_extra_modules: list[str]
"""Freeze extra modules."""
cast_trainable_params_to_fp32: bool
"""Cast trainable params to fp32."""
class PeftPlugin(BasePlugin):
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
return super().__call__(model, config)
return super().__call__(model, config, is_train)
def _find_all_linear_modules(model: HFModel) -> list[str]:
    r"""Collect the leaf names of all modules LoRA can be applied to.

    A module qualifies when its class name contains ``Linear`` but not
    ``Embedding`` and its qualified name contains none of the forbidden
    output-head names. Only the last dotted component of each name is kept,
    de-duplicated.
    """
    forbidden_modules = {"lm_head", "output_layer", "output"}
    linear_names = set()
    for qualified_name, module in model.named_modules():
        class_name = module.__class__.__name__
        is_forbidden = any(forbidden_module in qualified_name for forbidden_module in forbidden_modules)
        if not is_forbidden and "Linear" in class_name and "Embedding" not in class_name:
            linear_names.add(qualified_name.split(".")[-1])

    return list(linear_names)
def merge_adapters(model: HFModel, adapter_name_or_path: Union[list[str], str]) -> HFModel:
    r"""Load each adapter in order, merge it into ``model`` and return the merged model.

    A single path is accepted as well as a list of paths.
    """
    if isinstance(adapter_name_or_path, list):
        adapter_paths = adapter_name_or_path
    else:
        adapter_paths = [adapter_name_or_path]

    for path in adapter_paths:
        peft_model = PeftModel.from_pretrained(model, path)
        model = peft_model.merge_and_unload()
        logger.info_rank0(f"Merged adapter from {path}")

    return model
def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
    r"""Load adapter(s) into the model.

    How the adapters are used depends on the mode:
    - Training: exactly one adapter is allowed; it is attached with
      ``PeftModel.from_pretrained`` so training can continue on it.
    - Otherwise: every given adapter is merged into the model via ``merge_adapters``.

    Raises:
        ValueError: If more than one adapter is given while ``is_train`` is set.
    """
    if isinstance(adapter_name_or_path, list):
        adapter_paths = adapter_name_or_path
    else:
        adapter_paths = [adapter_name_or_path]

    # TODO
    # Adapters fix for deepspeed and quant
    # Adapters fix for vision

    if is_train and len(adapter_paths) > 1:
        raise ValueError(
            "When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported. "
            "Training will continue on the specified adapter. "
            "Please merge multiple adapters before starting a new LoRA adapter."
        )

    if not is_train:
        # inference / export: fold every adapter into the base weights
        if adapter_paths:
            model = merge_adapters(model, adapter_paths)
        return model

    resume_path = adapter_paths[0]
    if resume_path is not None:
        model = PeftModel.from_pretrained(model, resume_path, is_trainable=is_train)
        logger.info_rank0(
            f"Resuming training from existing LoRA adapter at {resume_path}. "
            "LoRA hyperparameters will be loaded from the adapter itself; "
            "the current LoRA configuration will be ignored. "
            "Merge the adapter into the base model before training if you want to start a new adapter."
        )

    return model
@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
adapter_name_or_path = config.get("adapter_name_or_path")
if adapter_name_or_path:
return load_adapter(model, adapter_name_or_path, is_train)
logger.info_rank0("Fine-tuning method: LoRA")
target_modules = config.get("target_modules", "all")
# Handle target modules
if target_modules == "all":
target_modules = _find_all_linear_modules(model)
elif isinstance(target_modules, str):
target_modules = [target_modules]
logger.info_rank0(f"LoRA target modules: {target_modules}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=not is_train,
r=config.get("r", 8),
lora_alpha=config.get("lora_alpha", 16),
lora_dropout=config.get("lora_dropout", 0.05),
use_rslora=config.get("use_rslora", False),
use_dora=config.get("use_dora", False),
target_modules=target_modules,
modules_to_save=config.get("modules_to_save", None),
)
model = get_peft_model(model, peft_config)
if is_train:
model.print_trainable_parameters()
return model
@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
    r"""Apply freeze tuning: leave only selected layers/modules trainable.

    Args:
        model: The model whose parameters are (un)frozen in place.
        config: Freeze options. ``freeze_trainable_layers`` > 0 selects the last
            n layers, otherwise the first ``-n`` layers. ``freeze_trainable_modules``
            names the per-layer modules to train (``"all"`` for whole layers).
            ``freeze_extra_modules`` names additional non-layer modules to train.
            Both module options accept a list or a comma-separated string.
        is_train: When false the model is returned untouched.

    Returns:
        The same model object with ``requires_grad`` updated.

    Raises:
        ValueError: If the layer count cannot be read from ``model.config`` or a
            requested module name is not found among the model's parameters.
    """
    logger.info_rank0("Fine-tuning method: Freeze")
    if not is_train:
        return model

    freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
    freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
    freeze_extra_modules = config.get("freeze_extra_modules", [])
    cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)

    if isinstance(freeze_trainable_modules, str):
        freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]

    if isinstance(freeze_extra_modules, str):
        freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]

    # Get number of layers
    num_layers = (
        getattr(model.config, "num_hidden_layers", None)
        or getattr(model.config, "num_layers", None)
        or getattr(model.config, "n_layer", None)
    )
    if not num_layers:
        raise ValueError("Current model does not support freeze tuning.")

    if freeze_trainable_layers > 0:
        # last n layers
        trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
    else:
        # first n layers
        trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))

    # Identify hidden and non-hidden modules
    hidden_modules = set()
    non_hidden_modules = set()
    for name, _ in model.named_parameters():
        # layers 0 and 1 are used as representatives to discover per-layer module names
        if ".0." in name:
            hidden_modules.add(name.split(".0.")[-1].split(".")[0])
        elif ".1." in name:
            hidden_modules.add(name.split(".1.")[-1].split(".")[0])

        if re.search(r"\.\d+\.", name) is None:
            name_parts = name.split(".")
            # a parameter registered directly on the root module has no parent
            # component; skip it instead of failing with IndexError on [-2]
            if len(name_parts) >= 2:
                non_hidden_modules.add(name_parts[-2])

    # Build list of trainable layer patterns
    trainable_layers = []
    for module_name in freeze_trainable_modules:
        if module_name == "all":
            for idx in trainable_layer_ids:
                trainable_layers.append(f".{idx:d}.")
        elif module_name in hidden_modules:
            for idx in trainable_layer_ids:
                trainable_layers.append(f".{idx:d}.{module_name}")
        else:
            raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")

    # Add extra modules
    if freeze_extra_modules:
        for module_name in freeze_extra_modules:
            if module_name in non_hidden_modules:
                trainable_layers.append(module_name)
            else:
                raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")

    # TODO
    # Multi-modal special handling

    # Set requires_grad
    forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
    for name, param in model.named_parameters():
        if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
            forbidden_module in name for forbidden_module in forbidden_modules
        ):
            param.requires_grad_(True)
            if cast_trainable_params_to_fp32:
                param.data = param.data.to(torch.float32)  # Cast to fp32 for stability
        else:
            param.requires_grad_(False)

    logger.info_rank0(f"Set trainable layers: {trainable_layers}")

    # Count trainable params for verification
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    # guard the percentage against a parameter-less model
    trainable_percent = 100 * trainable_params / all_params if all_params else 0.0
    logger.info_rank0(
        f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {trainable_percent:.4f}"
    )

    return model
def merge_and_export_model(args: InputArgument = None):
    r"""Load a model through ``ModelEngine`` in inference mode and export it.

    The export options are read from ``model_args.peft_config``. Only the
    ``lora`` plugin with ``adapter_name_or_path`` set is accepted. The model is
    optionally converted to the requested dtype, saved to ``export_dir``
    together with the tokenizer/processor, and pushed to the hub when
    ``export_hub_model_id`` is given.

    Raises:
        ValueError: If ``peft_config`` or ``export_dir`` is missing, the plugin
            is not ``lora``, or no adapter path is configured.
    """
    import inspect  # function-scope import: only needed to probe save_pretrained below

    model_args, _, _, _ = get_args(args)

    export_config = model_args.peft_config
    if export_config is None:
        raise ValueError("Please specify peft_config to merge and export model.")

    export_dir = export_config.get("export_dir")
    if export_dir is None:
        raise ValueError("Please specify export_dir.")

    export_size = export_config.get("export_size", 5)
    export_hub_model_id = export_config.get("export_hub_model_id")
    infer_dtype = export_config.get("infer_dtype", "auto")
    export_legacy_format = export_config.get("export_legacy_format", False)

    adapters = None
    if export_config.get("name") == "lora":
        adapters = export_config.get("adapter_name_or_path")
    else:
        raise ValueError("Currently merge and export model function is only supported for lora.")

    if adapters is None:
        raise ValueError("Please set adapter_name_or_path to merge adapters into base model.")

    logger.info_rank0("Loading model for export...")
    model_engine = ModelEngine(model_args, is_train=False)
    model = model_engine.model
    tokenizer = model_engine.processor

    if infer_dtype == "auto":
        if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
            model = model.to(torch.bfloat16)
            logger.info_rank0("Converted model to bfloat16.")
    else:
        target_dtype = getattr(torch, infer_dtype)
        model = model.to(target_dtype)
        logger.info_rank0(f"Converted model to {infer_dtype}.")

    logger.info_rank0(f"Exporting model to {export_dir}...")
    save_kwargs = {"max_shard_size": f"{export_size}GB"}
    # `safe_serialization` is only forwarded when save_pretrained still declares
    # it, so the export keeps working on transformers releases that dropped it.
    if "safe_serialization" in inspect.signature(model.save_pretrained).parameters:
        save_kwargs["safe_serialization"] = not export_legacy_format
    elif export_legacy_format:
        logger.warning_rank0(
            "export_legacy_format is ignored: save_pretrained does not accept safe_serialization."
        )

    model.save_pretrained(export_dir, **save_kwargs)

    if tokenizer is not None:
        try:
            if hasattr(tokenizer, "padding_side"):
                tokenizer.padding_side = "left"
            tokenizer.save_pretrained(export_dir)
        except Exception as e:
            # best effort: a tokenizer failure must not discard the exported weights
            logger.warning(f"Failed to save tokenizer: {e}")

    if export_hub_model_id:
        logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
        model.push_to_hub(export_hub_model_id)
        if tokenizer is not None:
            tokenizer.push_to_hub(export_hub_model_id)

    logger.info_rank0("Model exported successfully.")

View File

@@ -0,0 +1,122 @@
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any
import torch
from transformers import BitsAndBytesConfig
from ...accelerator.helper import get_current_device
from ...config.model_args import ModelArguments
from ...utils import logging
from ...utils.packages import check_version
from ...utils.plugin import BasePlugin
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
logger = logging.get_logger(__name__)
class QuantizationPlugin(BasePlugin):
    r"""Plugin for model quantization."""

    def __call__(
        self,
        init_kwargs: dict[str, Any] = None,
        config: "PretrainedConfig" = None,
        tokenizer: "PreTrainedTokenizer" = None,
        model_args: "ModelArguments" = None,
        is_trainable: bool = False,
    ) -> dict[str, Any]:
        r"""Forward the model init kwargs and loading context to the base plugin call.

        ``init_kwargs`` is passed positionally; everything else is passed by keyword.
        """
        context = {
            "config": config,
            "tokenizer": tokenizer,
            "model_args": model_args,
            "is_trainable": is_trainable,
        }
        return super().__call__(init_kwargs, **context)
@QuantizationPlugin("auto").register()
def quantization_auto(
    init_kwargs: dict[str, Any],
    **kwargs,
) -> dict[str, Any]:
    """Automatic quantization selection, only support bnb currently.

    Args:
        init_kwargs (dict[str, Any]): The kwargs for model initialization.
        **kwargs: Keyword arguments; ``model_args`` is read from here to obtain
            the quantization config, and all of them are forwarded to the
            selected quantization function.

    Returns:
        dict[str, Any]: The updated kwargs for model initialization.

    Raises:
        ValueError: If ``quantization_bit`` is set to something other than 8 or 4.
    """
    model_args: ModelArguments = kwargs.get("model_args", None)
    bits = model_args.quant_config.get("quantization_bit", None)

    if bits is None:
        logger.warning_rank0("No quantization method applied.")
        return init_kwargs

    logger.info_rank0(f"Loading {bits}-bit quantized model.")
    if bits not in [8, 4]:
        raise ValueError(f"Unsupported quantization bit: {bits} for auto quantization.")

    return quantization_with_bnb(init_kwargs, **kwargs)
@QuantizationPlugin("bnb").register()
def quantization_with_bnb(
    init_kwargs: dict[str, Any],
    model_args: "ModelArguments" = None,
    **kwargs,
) -> dict[str, Any]:
    r"""Quantization with BNB.

    Adds a ``BitsAndBytesConfig`` to ``init_kwargs["quantization_config"]``
    according to ``model_args.quant_config`` and, for inference, pins the
    device map to the current device.

    Args:
        init_kwargs: The kwargs for model initialization; updated in place.
        model_args: Model arguments providing ``quant_config``.
        **kwargs: Extra context; only ``is_trainable`` is read here.

    Returns:
        The updated ``init_kwargs``.

    Raises:
        ValueError: If the bit width is neither 4 nor 8, or if a non-4-bit
            model is requested for training.
    """
    logger.info_rank0("Using Bitsandbytes quantization.")
    quant_config = model_args.quant_config
    quantization_bit = quant_config.get("quantization_bit", None)
    if quantization_bit is None:
        # the message now names the default that is actually applied
        logger.warning_rank0("quantization_bit is not specified, default to 4-bit quantization.")
        quantization_bit = 4

    # explicit check instead of `assert`, which is stripped under `python -O`
    if quantization_bit not in (8, 4):
        raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

    if quantization_bit == 8:
        check_version("bitsandbytes>=0.37.0", mandatory=True)
        init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    else:
        check_version("bitsandbytes>=0.39.0", mandatory=True)
        compute_dtype = quant_config.get("compute_dtype", torch.float16)
        init_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=quant_config.get("double_quantization", True),
            bnb_4bit_quant_type=quant_config.get("quantization_type", "nf4"),
            bnb_4bit_quant_storage=compute_dtype,  # crucial for fsdp+qlora
        )

    # TODO: improve deepspeed zero3 and fsdp detection.
    # NOTE(review): the previous condition was `if is_trainable`, which ran the
    # "inference" branch during training and vice versa; it is flipped here so
    # each branch matches its own log message. Please confirm the intent.
    if not kwargs.get("is_trainable", False):
        logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
        init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
    else:
        logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
        # compare the resolved bit width, not the raw config value (None when defaulted)
        if quantization_bit != 4:
            raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

        check_version("bitsandbytes>=0.43.0", mandatory=True)

    logger.info_rank0(f"Quantizing model to {quantization_bit} bit with bitsandbytes.")
    return init_kwargs

View File

@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSpeed integration via accelerate's built-in capabilities.
Instead of manually calling deepspeed.initialize() and syncing config,
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
initialization, backward, gradient accumulation, and model saving.
"""
from typing import Any, Optional
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
class DeepSpeedEngine:
    """DeepSpeed integration built on accelerate's ``Accelerator`` + ``DeepSpeedPlugin``.

    This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach.
    According to the original author's notes, accelerate then takes care of:
    - Config syncing (auto values, batch size, lr, etc.)
    - deepspeed.initialize() call
    - Optimizer / LR scheduler wrapping
    - Backward + gradient accumulation boundary
    - ZeRO-3 parameter gathering for saving
    """

    def __init__(
        self,
        dist_config: dict[str, Any],
        num_micro_batch: Optional[int] = 1,
        micro_batch_size: Optional[int] = 1,
    ):
        """Create the accelerator from a DeepSpeed config file.

        Args:
            dist_config: Distributed config; must contain a non-empty ``config_file`` entry.
            num_micro_batch: Gradient accumulation steps handed to ``Accelerator``.
                ``None`` is treated as the default of 1.
            micro_batch_size: Value used for ``train_micro_batch_size_per_gpu`` when the
                DeepSpeed config leaves it unset or ``"auto"``. ``None`` is treated as 1.

        Raises:
            ValueError: If ``config_file`` is missing from ``dist_config``.
        """
        config_file = dist_config.get("config_file")
        if not config_file:
            raise ValueError("DeepSpeed config_file is required in dist_config")

        # Callers may forward ``kwargs.get(...)`` results, i.e. an explicit None that would
        # otherwise override the defaults and reach accelerate / the DeepSpeed config.
        if num_micro_batch is None:
            num_micro_batch = 1
        if micro_batch_size is None:
            micro_batch_size = 1

        ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
        self.accelerator = Accelerator(
            deepspeed_plugin=ds_plugin,
            gradient_accumulation_steps=num_micro_batch,
        )

        # Resolve "auto" for train_micro_batch_size_per_gpu so that
        # accelerate.prepare() does not require a DataLoader to infer it.
        ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
        if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
            ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size

        logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")

    def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
        """No-op shard; the model is only wrapped later, in :meth:`prepare`.

        Returns:
            ``self``, so the caller receives the engine instance through the hub interface.
        """
        return self

    def prepare(
        self,
        model: HFModel,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[Any] = None,
    ) -> tuple[HFModel, torch.optim.Optimizer, Any]:
        """Prepare model, optimizer and (optionally) lr_scheduler via ``accelerator.prepare``.

        The accelerator is also attached to the prepared model as ``model._accelerator``.

        Returns:
            ``(model, optimizer, lr_scheduler)``; ``lr_scheduler`` is returned unchanged
            (``None``) when none was supplied.
        """
        if lr_scheduler is not None:
            model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
        else:
            model, optimizer = self.accelerator.prepare(model, optimizer)

        model._accelerator = self.accelerator  # type: ignore[assignment]

        logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
        return model, optimizer, lr_scheduler

    def backward(self, loss: torch.Tensor) -> None:
        """Run the backward pass through ``accelerator.backward``.

        Per the original author's notes, this delegates to DeepSpeedEngineWrapper.backward(),
        which respects sync_gradients to control gradient accumulation boundaries:
        When sync_gradients=True:  engine.backward(loss) + engine.step()
        When sync_gradients=False: engine.backward(loss) only
        """
        self.accelerator.backward(loss)

    def get_grad_norm(self) -> float:
        """Return the global gradient norm reported by the DeepSpeed engine.

        Falls back to ``0.0`` when the accelerator has no ``deepspeed_engine_wrapped``
        attribute yet or when the engine reports a falsy norm.
        """
        engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
        if engine_wrapper is not None:
            return engine_wrapper.engine.get_global_grad_norm() or 0.0
        return 0.0
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Write the model and processor to ``output_dir`` through the attached accelerator.

    The accelerator is read from ``model._accelerator`` (set during prepare()). The state
    dict is obtained with ``accelerator.get_state_dict()``; only the main process writes
    files, and every process waits at the end before the completion message is logged.
    """
    acc: Accelerator = model._accelerator  # type: ignore[union-attr]

    bare_model = acc.unwrap_model(model)
    full_state = acc.get_state_dict(model)

    if acc.is_main_process:
        bare_model.save_pretrained(output_dir, state_dict=full_state, max_shard_size="4GB")
        processor.save_pretrained(output_dir, max_shard_size="4GB")

    acc.wait_for_everyone()
    logger.info_rank0(f"Model saved to {output_dir}")

View File

@@ -17,23 +17,24 @@ import os
import torch
import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from transformers import PreTrainedModel
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
if isinstance(no_split_modules, (list, tuple)):
@@ -49,6 +50,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
return None
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Gather the full state dict and let rank 0 write the model and processor.

    The state dict is requested on every rank with ``full_state_dict=True`` and
    ``cpu_offload=True``; only rank 0 logs and saves.
    """
    if DistributedInterface().get_rank() == 0:
        logger.info("Gathering state dict for saving...")

    gather_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_state = get_model_state_dict(model, options=gather_options)

    if DistributedInterface().get_rank() != 0:
        return

    # Save through the inner module when the model exposes one.
    target = model.module if hasattr(model, "module") else model
    target.save_pretrained(output_dir, state_dict=full_state, max_shard_size="4GB")
    processor.save_pretrained(output_dir, max_shard_size="4GB")
    logger.info(f"Model saved to {output_dir}")
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()
@@ -94,7 +109,10 @@ class FSDP2Engine:
cast_forward_inputs=True,
)
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
def is_lora_module_wrap(self, model) -> bool:
    """Return True when at least one submodule of ``model`` is a ``LoraLayer``."""
    for submodule in model.modules():
        if isinstance(submodule, LoraLayer):
            return True
    return False
def prepare_model(self, model: HFModel) -> HFModel:
if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model
@@ -111,6 +129,25 @@ class FSDP2Engine:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
if self.is_lora_module_wrap(model):
lora_modules = []
for module in model.modules():
if len(list(module.children())) != 0:
continue
if any(param.requires_grad for param in module.parameters(recurse=False)):
lora_modules.append(module)
for module in lora_modules:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
logger.info("Applying FSDP wrap for LoRA layer separately.")
for name, module in model.named_modules():
should_wrap = False
@@ -156,7 +193,7 @@ class FSDP2Engine:
return model
@torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0:
logger.info("Materializing sharded model params...")
@@ -176,7 +213,7 @@ class FSDP2Engine:
return model
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta":
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
@@ -184,7 +221,7 @@ class FSDP2Engine:
model = self.prepare_model(model)
return model
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
def _load_from_dcp(self, model: HFModel, dcp_path: str):
import torch.distributed.checkpoint as dcp
try:
@@ -203,7 +240,7 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob
import json

View File

@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
if TYPE_CHECKING:
from ....utils.types import HFModel, Processor
class DistributedPlugin(BasePlugin):
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
    """Shard ``model`` with a freshly built ``FSDP2Engine``; extra kwargs are accepted and ignored."""
    # Imported lazily, inside the function, as in the original.
    from .fsdp2 import FSDP2Engine

    engine = FSDP2Engine(dist_config)
    return engine.shard_model(model)
@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Forward to the ``save_model`` function of the ``fsdp2`` module."""
    from .fsdp2 import save_model as _fsdp2_save_model

    return _fsdp2_save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
return model
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
    """Build a ``DeepSpeedEngine`` and return the result of its ``shard_model``.

    Args:
        model: Model handed to ``DeepSpeedEngine.shard_model``.
        dist_config: Distributed config forwarded to ``DeepSpeedEngine``.
        **kwargs: ``num_micro_batch`` and ``micro_batch_size`` are forwarded when they are
            present and not ``None``; anything else is ignored.

    Returns:
        Whatever ``DeepSpeedEngine.shard_model`` returns.
        NOTE(review): that method is annotated as returning the engine itself, not the
        model, so the ``HFModel`` annotation here looks inaccurate -- confirm with callers.
    """
    from .deepspeed import DeepSpeedEngine

    # Forward only values that were actually supplied. Passing ``kwargs.get(...)`` directly
    # would hand an explicit None to the engine and override its constructor defaults.
    engine_kwargs = {
        key: kwargs[key] for key in ("num_micro_batch", "micro_batch_size") if kwargs.get(key) is not None
    }
    return DeepSpeedEngine(dist_config, **engine_kwargs).shard_model(model)
@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Forward to the ``save_model`` function of the ``deepspeed`` module."""
    from .deepspeed import save_model as _deepspeed_save_model

    return _deepspeed_save_model(model, output_dir, processor)

View File

@@ -33,7 +33,7 @@ def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config)
train_dataset = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args)
model_engine = ModelEngine(model_args, is_train=True)
trainer = SFTTrainer(
args=training_args,
model=model_engine.model,

View File

@@ -21,6 +21,13 @@ from functools import lru_cache
from typing import TYPE_CHECKING
from packaging import version
from transformers.utils.versions import require_version
from . import logging
from .env import is_env_enabled
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
@lru_cache
def is_transformers_version_greater_than(content: str):
    """Return whether the installed ``transformers`` version is at least ``content`` (result is cached)."""
    installed = _get_package_version("transformers")
    threshold = version.parse(content)
    return installed >= threshold
def check_version(requirement: str, mandatory: bool = False) -> None:
    r"""Check a package requirement, unless the check is disabled and not mandatory.

    Args:
        requirement: Requirement specifier passed to ``require_version``.
        mandatory: When True, the ``DISABLE_VERSION_CHECK`` switch is ignored.
    """
    skip_check = is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory
    if skip_check:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    # These two packages get an extra pip flag in the suggested command.
    no_isolation = any(pkg in requirement for pkg in ("gptqmodel", "autoawq"))
    pip_command = f"pip install {requirement} --no-build-isolation" if no_isolation else f"pip install {requirement}"

    hint = f"To fix: run `{pip_command}`."
    if not mandatory:
        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

    require_version(requirement, hint)

View File

@@ -166,3 +166,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
def fix_valuehead_cpu_loading():
"""Fix valuehead model loading."""
patch_valuehead_model()
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
    """Swap ``TokenizersBackend._patch_mistral_regex`` for a pass-through during the session.

    The replacement returns the tokenizer it is given unchanged. If the class cannot be
    imported or lacks the method, the fixture does nothing. The original attribute is put
    back once the session-scoped fixture resumes after its ``yield``.
    """
    try:
        from transformers.tokenization_utils_fast import TokenizersBackend
    except ImportError:
        # Import failed, so there is nothing to patch.
        TokenizersBackend = None

    if TokenizersBackend is None or not hasattr(TokenizersBackend, "_patch_mistral_regex"):
        yield
        return

    def _return_tokenizer(cls, tokenizer, *args, **kwargs):
        return tokenizer

    saved_impl = TokenizersBackend._patch_mistral_regex
    TokenizersBackend._patch_mistral_regex = _return_tokenizer

    yield

    # Teardown: put the saved implementation back.
    TokenizersBackend._patch_mistral_regex = saved_impl

View File

@@ -172,3 +172,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
    """Swap ``TokenizersBackend._patch_mistral_regex`` for a pass-through during the session.

    The replacement returns the tokenizer it is given unchanged. If the class cannot be
    imported or lacks the method, the fixture does nothing. The original attribute is put
    back once the session-scoped fixture resumes after its ``yield``.
    """
    try:
        from transformers.tokenization_utils_fast import TokenizersBackend
    except ImportError:
        # Import failed, so there is nothing to patch.
        TokenizersBackend = None

    if TokenizersBackend is None or not hasattr(TokenizersBackend, "_patch_mistral_regex"):
        yield
        return

    def _return_tokenizer(cls, tokenizer, *args, **kwargs):
        return tokenizer

    saved_impl = TokenizersBackend._patch_mistral_regex
    TokenizersBackend._patch_mistral_regex = _return_tokenizer

    yield

    # Teardown: put the saved implementation back.
    TokenizersBackend._patch_mistral_regex = saved_impl

View File

@@ -0,0 +1,156 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from llamafactory.v1.plugins.model_plugins import peft as peft_module
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
TINY_MODEL = "llamafactory/tiny-random-qwen3"
@pytest.fixture(scope="module")
def model_path():
    """Identifier of the tiny model shared by the tests in this module."""
    path = TINY_MODEL
    return path
@pytest.fixture(scope="function")
def model(model_path):
    """Load a fresh causal-LM from ``model_path`` for each test."""
    loaded = AutoModelForCausalLM.from_pretrained(model_path)
    return loaded
@pytest.fixture(scope="function")
def tokenizer(model_path):
    """Load a fresh tokenizer from ``model_path`` for each test."""
    loaded = AutoTokenizer.from_pretrained(model_path)
    return loaded
@pytest.fixture(scope="function")
def adapter_path(tmp_path):
    """Save a dummy LoRA adapter under ``tmp_path`` and return its directory as a string."""
    lora_settings = dict(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    lora_config = LoraConfig(**lora_settings)

    # Wrap the tiny base model so there is an adapter to write out.
    wrapped = get_peft_model(AutoModelForCausalLM.from_pretrained(TINY_MODEL), lora_config)

    adapter_dir = tmp_path / "test_adapter"
    wrapped.save_pretrained(adapter_dir)
    return str(adapter_dir)
def test_find_all_linear_modules(model):
    """The discovered linear modules must include q_proj and v_proj for tiny-random-qwen3."""
    found = set(peft_module._find_all_linear_modules(model))
    required = {"q_proj", "v_proj"}
    assert required.issubset(found)
def test_get_lora_model(model):
    """get_lora_model must return a PeftModel carrying the requested LoRA settings."""
    config = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
    lora_model = peft_module.get_lora_model(model, config, is_train=True)
    assert isinstance(lora_model, PeftModel)

    default_cfg = lora_model.peft_config["default"]
    assert default_cfg.r == 8
    assert "q_proj" in lora_model.peft_config["default"].target_modules
def test_get_freeze_model_layers(model):
    """Layer-wise freezing: last-layer params are trainable, first-layer params are frozen."""
    # Keep only the last layer trainable.
    config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
    frozen_model = peft_module.get_freeze_model(model, config, is_train=True)

    num_layers = frozen_model.config.num_hidden_layers
    assert num_layers > 0

    last_layer_tag = f"layers.{num_layers - 1}"
    for name, param in frozen_model.named_parameters():
        if last_layer_tag in name:
            assert param.requires_grad, f"{name} should be trainable"
            continue
        if "layers.0" in name and num_layers > 1:
            assert not param.requires_grad, f"{name} should be frozen"
def test_get_freeze_model_modules(model):
    """Module-wise freezing: only self_attn params of the last layer may be trainable."""
    config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
    frozen_model = peft_module.get_freeze_model(model, config, is_train=True)

    num_layers = frozen_model.config.num_hidden_layers
    last_layer_tag = f"layers.{num_layers - 1}"
    for name, param in frozen_model.named_parameters():
        expect_trainable = last_layer_tag in name and "self_attn" in name
        if expect_trainable:
            assert param.requires_grad, f"{name} should be trainable"
        else:
            assert not param.requires_grad, f"{name} should be frozen"
def test_load_adapter_single_for_inference(model, adapter_path):
    """With is_train=False, loading one adapter must not return a PeftModel."""
    loaded = peft_module.load_adapter(model, adapter_path, is_train=False)
    is_peft = isinstance(loaded, PeftModel)
    assert not is_peft
def test_load_adapter_resume_train(model, adapter_path):
    """With is_train=True, loading one adapter must return a PeftModel."""
    loaded = peft_module.load_adapter(model, adapter_path, is_train=True)
    is_peft = isinstance(loaded, PeftModel)
    assert is_peft
def test_load_adapter_train_multiple_disallowed(model, adapter_path):
    """Passing two adapters with is_train=True must raise the single-adapter ValueError."""
    adapters = [adapter_path, adapter_path]
    with pytest.raises(ValueError, match="only a single LoRA adapter"):
        peft_module.load_adapter(model, adapters, is_train=True)
def test_load_adapter_infer_multiple_merges(model, adapter_path):
    """With is_train=False, loading two adapters must not return a PeftModel."""
    adapters = [adapter_path, adapter_path]
    loaded = peft_module.load_adapter(model, adapters, is_train=False)
    assert not isinstance(loaded, PeftModel)
def test_merge_and_export_model(tmp_path, adapter_path):
    """merge_and_export_model must create the export directory with the expected files."""
    export_dir = tmp_path / "export"

    peft_config = {
        "name": "lora",
        "adapter_name_or_path": adapter_path,
        "export_dir": str(export_dir),
        "export_size": 1,
        "infer_dtype": "float16",
    }
    merge_and_export_model({"model": TINY_MODEL, "peft_config": peft_config})

    assert export_dir.exists()
    for artifact in ("config.json", "model.safetensors", "tokenizer_config.json"):
        assert (export_dir / artifact).exists()

View File

@@ -0,0 +1,51 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine
bitsandbytes = pytest.importorskip("bitsandbytes")
def check_quantization_status(model):
    """Collect the names of submodules that are Bitsandbytes 8-bit or 4-bit linear layers.

    Returns:
        A dict with a single ``"bnb"`` key mapping to the list of matching module names.
    """
    bnb_linear_types = (bitsandbytes.nn.modules.Linear8bitLt, bitsandbytes.nn.modules.Linear4bit)
    bnb_names = [name for name, module in model.named_modules() if isinstance(module, bnb_linear_types)]
    return {"bnb": bnb_names}
@pytest.mark.runs_on(["cuda"])
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
def test_quantization_plugin(name, quantization_bit):
    """Loading through ModelEngine with a quant_config must yield Bitsandbytes linear layers."""
    quant_config = {
        "name": name,
        "quantization_bit": quantization_bit,
    }
    model_args = ModelArguments(model="llamafactory/tiny-random-qwen3", quant_config=quant_config)

    engine = ModelEngine(model_args=model_args)
    quantized_info = check_quantization_status(engine.model)
    print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")

    # At least one backend must have reported a non-empty list of quantized modules.
    assert any(v for v in quantized_info.values()), "model is not quantized properly."