[v1] init commit for v1 docs (#10145)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
Co-authored-by: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
浮梦
2026-02-09 19:43:55 +08:00
committed by GitHub
parent ea644d04ec
commit 1d5e8ebcd0
63 changed files with 2237 additions and 0 deletions

View File

@@ -0,0 +1,253 @@
# DataEngine
## 1. DataEngine 简介
`DataEngine` 是 LLaMA-Factory v1 数据处理的核心类,继承自 PyTorch 的 `Dataset`,负责各种插件的接入,其他功能(如数据格式转换、数据加载等)均通过插件的形式实现并接入 `DataEngine`
`DataEngine`接受一个唯一入参:`DataArguments` 实例,所有的元数据集信息均通过该参数配置传入。
## 2. DataEngine 与 DataArguments 接口定义
```python
@dataclass
class DataArguments:
""" `DataEngine`初始化入参
args:
dataset (str): 数据集路径,远程数据集 repo id / dataset_info.yaml 路径,或本地数据集路径/dataset_info.yaml路径
cutoff_len (int): 数据集截止长度,即数据集最大样本采样数量
"""
...
class DataEngine(Dataset):
"""数据引擎DataEngine
`DataEngine` 负责数据集的加载与统一管理,支持:
- 从本地路径或 Hugging Face Hub 加载数据
- 通过插件机制加载自定义数据
- 构建统一的数据索引
- 支持流式streaming与非流式数据访问
attr:
args (DataArguments): 数据参数配置
datasets (dict[str, HFDataset]): 数据集名称到数据对象的映射
dataset_infos (dict[str, DatasetInfo]): 数据集名称到元信息的映射
data_index (list[tuple[str, int]]): 数据索引列表,每项为 (dataset_name, sample_index)
streaming (bool): 是否为流式数据集
"""
def __init__(self, data_args: DataArguments) -> None:
"""初始化 `DataEngine`
初始化时自动执行以下步骤:
1. 调用 `get_dataset_info` 从 `data_args` 读取并解析数据集元信息
2. 调用 `load_dataset`,根据配置加载数据集
3. 调用 `build_data_index`,构建统一的索引列表
args:
data_args (DataArguments): 数据参数配置对象
"""
...
def get_dataset_info(self) -> None:
"""从配置文件或远程仓库加载数据集元信息
根据 `self.args.dataset` 确定数据源,数据源支持如下选项:
- 本地 YAML 配置文件路径
- Hugging Face Hub 上的 YAML 配置文件路径
- 本地数据集文件路径
- Hugging Face Hub 数据集 repo id
"""
...
def load_dataset(self) -> None:
"""根据数据集元信息加载所有数据集
每个数据集条目可以包含以下字段:
- `hf_hub_url`: 使用 `datasets.load_dataset` 加载
- 本地数据文件:通过 `DataLoaderPlugin` 插件加载
- `streaming`: 是否启用流式模式
更新:
self.datasets (dict): 数据集名称到已加载数据对象的映射
self.streaming (bool): 如果任一数据集为流式模式,则设置为 True
"""
...
def build_data_index(self) -> None:
"""构建统一的数据索引
为所有数据集创建全局索引列表 `(dataset_name, sample_index)`
当启用流式模式时,生成固定长度(例如 1000的占位索引
否则,为每条样本建立索引。
插件 `DataIndexPlugin` 可根据数据集大小或权重调整索引分布
"""
...
def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
"""将原始样本转换为统一格式
根据 `dataset_info` 中的 `converter` 字段,调用对应的转换插件,
将原始样本标准化为统一的数据结构。
args:
raw_sample (dict[str, Any]): 原始数据样本
dataset_name (str): 样本所属的数据集名称
return:
Sample: 转换后的标准化格式样本
"""
...
def __len__(self) -> int:
"""返回数据集的总样本数
return:
int: 数据集长度
如果为流式数据集,返回 `-1`
"""
...
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
"""根据索引或选择器获取样本
args:
index (Union[int, Any]): 数据索引int 或 list[int]
return:
Union[Sample, list[Sample]]: 单个样本或样本列表
"""
...
def __iter__(self) -> Iterable:
"""返回数据集迭代器
用于非流式数据集的顺序或随机访问
流式模式下需要实现异步加载逻辑
return:
Iterable: 数据集迭代器。
"""
...
async def __aiter__(self) -> AsyncIterable:
"""返回异步数据集迭代器
用于流式数据集或异步数据加载场景
允许在异步环境中以流的方式读取样本
return:
AsyncIterable: 异步迭代器,按顺序产出样本
"""
...
```
`DataArguments` 参数说明:
`dataset`: 数据集路径,支持本地或远程,当传入本地数据集文件路径时,需要满足该数据集为标准格式;否则需要传入 `dataset_info.yaml` 来配置数据集的 `converter` 等元信息,以告知 `DataEngine` 应当如何处理该数据。
`cutoff_len`: 数据集的截止长度,即该数据集的最大样本数量。
---
## 3. DataEngine 核心方法
### 3.1 `get_dataset_info`:加载数据元信息
根据 `dataset` 参数加载数据集配置,获取数据位置、数据格式、插件配置等所有数据元信息,在实例化 `DataEngine` 时会自动调用此方法。
### 3.2 加载数据集:`load_dataset`
遍历所有数据源,根据不同的数据源加载数据,在实例化 `DataEngine` 时会自动调用此方法。
```python
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
if "hf_hub_url" in value:
# 从 HF Hub 加载
dataset = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
else:
# 使用 DataLoaderPlugin 加载本地文件
dataset = DataLoaderPlugin(args=self.args).auto_load_data(value)
self.datasets[key] = dataset
```
### 3.3 `build_data_index`:构建数据索引
为每个数据集创建索引列表 `[(dataset_name, sample_index), ...]`, `DataIndexPlugin`插件在此处被调用,可控制各数据集的采样频率、采样方式等,在实例化`DataEngine`时会自动调用此方法。
```python
for dataset_name, dataset in self.datasets.items():
# 创建基础索引
data_index = [(dataset_name, idx) for idx in range(len(dataset))]
# 根据 size 和 weight 调整索引
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight:
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
```
### 3.4 `_convert_data_sample`:数据格式标准化
将原始数据转换为标准格式,`DataConverterPlugin`插件在此处被调用,具体调用的插件由 `get_dataset_info` 方法获取的 `converter` 信息指定,若 `converter` 为空则假定数据集为标准格式,此方法由`DataEngine``__getitem__` 方法调用。
```python
def _convert_data_sample(self, raw_sample: dict, dataset_name: str) -> Sample:
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
# 使用指定的转换器
from ..plugins.data_plugins.converter import get_converter
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
else:
# 已经是标准格式
return {"_dataset_name": dataset_name, **raw_sample}
```
---
## 4. 初始化
`DataEngine` 初始化过程只需传入一个构建好的 `DataArguments` 即可,后续可通过该 `DataEngine` 访问数据集中的数据。
```python
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
# 1. 创建数据参数
data_args = DataArguments(
dataset="~/data/v1_sft_demo.jsonl",
cutoff_len=2048
)
# 2. 初始化 Data Engine
data_engine = DataEngine(data_args=data_args)
# 3. 访问数据
sample = data_engine[0] # 获取第一个样本
```
## 5. 数据访问方式
实例化后的`DataEngine`支持整数索引、列表索引、以及切片等访问方式,其数据读取用法可等价于 Python 列表。
```python
sample = data_engine[0] # 获取第一个样本
sample = data_engine[0:10] # 获取前 10 个样本
sample = data_engine[[0, 5, 10]] # 获取指定索引的样本
```

View File

@@ -0,0 +1 @@
# ModelEngine

View File

@@ -0,0 +1 @@
# Trainer

View File

@@ -0,0 +1,467 @@
# Data Plugins
## 1. Data Plugins 简介
## DataConverterPlugin
### 1. DataConverterPlugin 简介
DataConverter 负责将非标准格式的数据集转换为 v1 的标准 Messages 格式。这使得用户可以继续使用现有的数据集(如 Alpaca 格式),而无需手动转换。针对自定义格式的数据集,用户也可以通过构建对应的自定义 DataConverter 插件,来负责其数据格式标准化。
当前LLaMA-Factory 已内置了 `Alpaca Converter``Pair Converter`,这两类数据集可以直接使用对应的 converter 进行标准化,无需自定义转换器。
### 2. Alpaca Converter 详解
#### 2.1 Alpaca 格式
Alpaca 格式是一种常见的指令微调数据格式:
```json
{
"system": "You are a helpful assistant.",
"instruction": "Describe a process of making crepes.",
"input": "",
"output": "Making crepes is an easy and delicious process..."
}
```
#### 2.2 Alpaca Converter 接口定义
```python
class AlpacaSample(TypedDict, total=False):
"""Alpaca 格式数据样本结构
attr:
system (str, 可选): 系统提示信息system prompt用于设定对话背景或模型行为。
instruction (str, 可选): 用户指令user instruction通常为任务描述。
input (str, 可选): 额外的输入内容input text可与 instruction 拼接。
output (str, 可选): 模型生成的目标输出expected response
"""
...
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""将 Alpaca 样本转换为 SFTSupervised Fine-Tuning标准样本格式
`alpaca_converter` 将 Alpaca 数据集中一条样本转换为通用的 `SFTSample` 格式
该格式用于监督微调SFT或多轮对话建模
转换逻辑:
- 若存在 `system` 字段则生成一条系统消息loss_weight = 0.0
- 若存在 `instruction` 或 `input` 字段则合并为一条用户消息loss_weight = 0.0
- 若存在 `output` 字段则生成一条助手机器人回复消息loss_weight = 1.0
args:
raw_sample (AlpacaSample): 原始 Alpaca 数据样本
return:
SFTSample: 转换后的标准化样本,格式如下:
{
"messages": [
{"role": "system", "content": [{"type": "text", "value": "..."}], "loss_weight": 0.0},
{"role": "user", "content": [{"type": "text", "value": "..."}], "loss_weight": 0.0},
{"role": "assistant", "content": [{"type": "text", "value": "..."}], "loss_weight": 1.0},
]
}
example:
>>> raw = {"instruction": "请将以下句子翻译成英文:", "input": "你好", "output": "Hello"}
>>> alpaca_converter(raw)
{
"messages": [
{"role": "user", "content": [{"type": "text", "value": "请将以下句子翻译成英文:你好"}], "loss_weight": 0.0},
{"role": "assistant", "content": [{"type": "text", "value": "Hello"}], "loss_weight": 1.0}
]
}
"""
```
#### 2.3 转换过程
`alpaca_converter` 函数将 Alpaca 格式转换为标准格式,转换逻辑如下:
```python
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
messages = []
# 1. 添加系统提示词(如果存在)
if "system" in raw_sample:
messages.append({
"role": "system",
"content": [{"type": "text", "value": raw_sample["system"]}],
"loss_weight": 0.0
})
# 2. 添加用户输入instruction + input
if "instruction" in raw_sample or "input" in raw_sample:
user_content = raw_sample.get("instruction", "") + raw_sample.get("input", "")
messages.append({
"role": "user",
"content": [{"type": "text", "value": user_content}],
"loss_weight": 0.0
})
# 3. 添加模型回复
if "output" in raw_sample:
messages.append({
"role": "assistant",
"content": [{"type": "text", "value": raw_sample["output"]}],
"loss_weight": 1.0
})
return {"messages": messages}
```
#### 2.4 转换示例
**输入Alpaca 格式):**
```json
{
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris."
}
```
**输出(标准格式):**
```json
{
"messages": [
{
"role": "user",
"content": [{"type": "text", "value": "What is the capital of France?"}],
"loss_weight": 0.0
},
{
"role": "assistant",
"content": [{"type": "text", "value": "The capital of France is Paris."}],
"loss_weight": 1.0
}
]
}
```
### 3. 自定义转换器
#### 3.1 创建自定义转换器
如果用户有自己的数据格式,可以轻松添加自定义转换器将其标准化,实现过程可参考如下示例:
```python
# src/llamafactory/v1/plugins/data_plugins/converter.py
from typing import TypedDict, NotRequired
from ...extras.types import SFTSample
# 1. 定义输入格式的类型
class MyCustomSample(TypedDict, total=False):
question: str
answer: str
context: NotRequired[str]
# 2. 实现转换逻辑
def custom_converter(raw_sample: MyCustomSample) -> SFTSample:
messages = []
# 构建用户消息
user_text = raw_sample["question"]
if "context" in raw_sample:
user_text = f"Context: {raw_sample['context']}\n\nQuestion: {user_text}"
messages.append({
"role": "user",
"content": [{"type": "text", "value": user_text}],
"loss_weight": 0.0
})
# 构建助手消息
messages.append({
"role": "assistant",
"content": [{"type": "text", "value": raw_sample["answer"]}],
"loss_weight": 1.0
})
return {"messages": messages}
# 3. 注册 custom_converter
#src/llamafactory/v1/plugins/data_plugins/converter.py: CONVERTERS
CONVERTERS = {
"alpaca": alpaca_converter,
"custom": custom_converter, # 添加自定义转换器
}
```
#### 3.2 使用自定义转换器
在 YAML 配置中指定转换器名称:
```yaml
my_dataset:
file_name: custom_data.json
converter: custom
```
---
## DataLoaderPlugin
### 1. DataLoaderPlugin 简介
`DataLoaderPlugin` 负责从本地文件加载数据集,当前支持如下文件格式:
- **JSON**: `.json`
- **JSONL**: `.jsonl`
- **CSV**: `.csv`
- **Parquet**: `.parquet`
- **Arrow**: `.arrow`
- **Text**: `.txt`
### 2. DataLoaderPlugin 接口定义
```python
@dataclass
class DataLoaderPlugin:
"""数据加载插件DataLoaderPlugin
负责根据数据集信息(`DatasetInfo`)自动加载本地或远程数据集。
支持多种文件格式(如 CSV、JSON、Parquet、Text、Arrow并可选择是否以流式方式加载。
通常由 `DataEngine` 调用,用于统一封装数据加载逻辑。
"""
args: DataArguments
"""数据参数对象,包含数据目录、缓存路径、分片等配置信息。"""
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""获取数据集文件格式
根据输入文件路径自动判断应使用的 HuggingFace `load_dataset` 构建器类型。
通过文件扩展名推断数据类型,例如 `.csv`、`.jsonl`、`.parquet`、`.txt` 等。
args:
path (str): 数据集文件路径,用于识别文件类型。
return:
Literal["arrow", "csv", "json", "parquet", "text"]:
数据构建器名称,用于 `datasets.load_dataset()`。
example:
>>> _get_builder_name("data/train.jsonl")
"json"
"""
...
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
"""根据传入的 `dataset_info` 自动选择合适的加载方式
args:
dataset_info (DatasetInfo): 数据集元信息,通常包含:
- `file_name`: 数据文件路径
- `split`: 数据划分(如 "train"、"test"
- `streaming`: 是否启用流式加载
return:
HFDataset: 加载完成的 Hugging Face 数据集对象。
example:
>>> plugin = DataLoaderPlugin(args)
>>> ds = plugin.auto_load_data({"file_name": "~/data.json", "split": "train"})
"""
...
def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
"""从文件或目录加载数据集
根据输入路径自动识别文件类型CSV、JSON、Parquet、Text 等),
并通过 `datasets.load_dataset()` 加载数据集。
若 `streaming=True`,则将结果转换为迭代式数据集。
args:
filepath (str): 文件路径或目录路径。
split (str): 数据划分名称(如 "train"、"validation")。
streaming (bool): 是否启用流式加载模式。
return:
HFDataset: 加载后的数据集对象。
example:
>>> plugin.load_data_from_file("data/train.json", "train", False)
"""
...
```
---
## DataIndexPlugin
### 1. DataIndexPlugin 简介
`DataIndexPlugin` 负责调整数据索引,支持通过配置 `size`, `weight` 等参数控制数据集样本数量和采样频率。
- 使用 `size` 参数 限制使用的样本数量:
```yaml
my_dataset:
file_name: large_dataset.json
size: 1000 # 只使用前 1000 个样本
```
- 使用 `weight` 参数调整数据集在混合数据中的采样频率:
```yaml
dataset_a:
file_name: data_a.json
weight: 1.0
dataset_b:
file_name: data_b.json
weight: 2.0 # dataset_b 的样本出现频率是 dataset_a 的 2 倍
```
**说明**`weight` 参数适用于在多个数据集混合训练时,调整不同数据集的的采样频率
-`weight=1.0` 时,数据集按原始比例采样
-`weight=2.0` 时,该数据集的索引会复制 2 倍,使其样本出现频率翻倍
### 2. DataIndexPlugin 接口定义
```python
@dataclass
class DataIndexPlugin:
"""数据索引插件DataIndexPlugin
根据 `size` 和 `weight` 调整数据索引列表,控制数据集的样本数量和采样频率
通常在多数据集混合训练时使用,以控制不同数据集在总体样本中的占比。
在 `DataEngine.build_data_index` 中被自动调用,用于实现样本重采样或加权分布。
"""
def adjust_data_index(
self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
) -> list[tuple[str, int]]:
"""调整数据索引列表
根据 `size` 或 `weight` 参数对输入的数据索引进行采样、扩展或缩减。
若两个参数同时存在,将依次执行基于大小和基于权重的调整。
args:
data_index (list[tuple[str, int]]):
数据索引列表,每个元素为 `(dataset_name, sample_index)`。
size (Optional[int]):
目标样本数量,若指定则根据该数量裁剪或重复样本。
weight (Optional[float]):
数据集权重,用于控制数据集在混合训练中的采样比例。
return:
list[tuple[str, int]]:
调整后的数据索引列表。
example:
>>> plugin = DataIndexPlugin()
>>> adjusted = plugin.adjust_data_index([("ds1", i) for i in range(100)], size=50, weight=None)
>>> len(adjusted)
50
"""
...
def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
"""根据目标大小调整数据索引
通过裁剪或重复样本,使索引总数等于 `size`。
常用于统一不同数据集的样本数量。
args:
data_index (list[tuple[str, int]]):
原始数据索引列表。
size (int):
目标样本数量。
return:
list[tuple[str, int]]:
调整后长度等于 `size` 的数据索引列表。
example:
>>> plugin.adjust_by_size([("ds1", i) for i in range(10)], 20)
"""
...
def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
"""根据权重调整数据索引
通过加权采样或重复样本,使数据集样本出现频率符合指定权重。
常用于多数据源训练中按比例平衡样本。
args:
data_index (list[tuple[str, int]]):
原始数据索引列表。
weight (float):
数据集权重(相对比例,可与其他数据集共同归一化)。
return:
list[tuple[str, int]]:
调整后的加权数据索引列表。
example:
>>> plugin.adjust_by_weight([("ds1", i) for i in range(10)], 0.5)
"""
...
```
---
## DataSelectorPlugin
### 1. DataSelectorPlugin 简介
`DataSelectorPlugin``DataEngine`提供基于索引访问数据的功能,由 `DataEngine``__getitem__` 方法自动调用。
### 2. DataSelectorPlugin 接口定义
```python
@dataclass
class DataSelectorPlugin:
"""根据索引选择数据集样本。
配合 `DataEngine` 使用,通过统一的 `data_index` 结构(包含数据集名与样本索引)来实现灵活的数据选择
"""
data_index: list[tuple[str, int]]
"""数据索引列表,每个元素为 (dataset_name, sample_index)。"""
def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
"""选择数据集样本
根据输入类型从 `data_index` 中选择对应的样本索引
支持三种索引方式:
- 切片slice返回对应范围内的样本
- 索引列表list[int]):返回指定索引处的多个样本
- 其他类型输入将触发异常。
args:
index (Union[slice, list[int], Any]): 数据样本索引
可以是切片(`slice`)或索引列表
return:
Union[tuple[str, int], list[tuple[str, int]]]:
- 若为单个索引:返回一个 `(dataset_name, sample_index)`
- 若为多个索引或切片:返回多个样本的列表
except:
Raises:
ValueError: 当输入索引类型不受支持时抛出。
...
```

View File

@@ -0,0 +1,197 @@
# Kernels plugins
## 概览
LLaMA-Factory 通过 Kernels plugins 系统依据不同硬件设备提供高性能计算内核kernel实现。该系统通过注册表机制管理所有 kernel通过 `@register_kernel` 装饰器实现 kernel 定义后自动注册,由 `apply_kernel` 方法来使能指定的 kernel`apply_default_kernels` 可使能注册表中当前环境所有可用的默认 kernels。
## 架构设计
### 核心组件
#### 1. Registry注册表
`Registry` 是一个用于管理所有 kernel 实现的静态类。它维护一个字典结构:`{kernel_id: KernelClass}`
```python
# 注册表结构示例
{
"npu_fused_rmsnorm": NpuRMSNormKernel,
"npu_fused_swiglu": NpuSwiGluKernel,
...
}
```
#### 2. register_kernel (装饰器)
`@register_kernel``Registry.register` 的别名。所有 kernel 类均应使用该装饰器进行注册。
**注册机制**
- 装饰器检查类是否继承自 `BaseKernel`
- 检查类是否定义了 `_kernel_id``_device` 属性。
- 检查 `_device` 是否与当前运行环境的加速器类型匹配。如果不匹配,则跳过注册。
- 如果一切符合要求,将 kernel 类注册到全局注册表中。
#### 3. BaseKernel基类
所有 kernel 的实现都必须继承自 `BaseKernel` 抽象基类。`BaseKernel` 定义了 kernel 的基本属性和接口。
#### 4. 标识系统
**Kernel ID** (`_kernel_id`)
每个 kernel 必须拥有一个唯一的字符串标识符,例如 `"npu_fused_rmsnorm"`
**Device Type** (`_device`)
kernel 必须声明其支持的设备类型,例如 `DeviceType.NPU``DeviceType.CUDA`
## Kernel 系统 API 设计
### **Registry**:全局 kernel 注册表
`Registry` 类提供了注册和获取 kernel 的接口:
```python
class Registry:
@classmethod
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
"""注册一个 kernel 类"""
...
@classmethod
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
"""根据 ID 获取 kernel 类"""
...
```
### **BaseKernel**
`BaseKernel` 定义了所有 kernel 必须实现的协议:
- `_kernel_id`: 类属性kernel 的唯一标识符。
- `_device`: 类属性kernel 支持的设备类型。
- `check_deps()`: 类方法,检查 kernel 的依赖项是否满足(如 `torch_npu` 是否安装)。
- `apply(**kwargs)`: 抽象类方法,实现 kernel 的具体应用逻辑。
```python
class BaseKernel(ABC):
_kernel_id: Any = ""
_device: DeviceType = DeviceType.CPU
@classmethod
def check_deps(cls) -> bool:
"""检查依赖项"""
...
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
"""应用 kernel 到模型"""
...
```
### **scan_all_kernels**
`scan_all_kernels` 函数会自动扫描 `ops` 目录下的所有 `.py` 文件并导入它们,从而触发 `@register_kernel` 装饰器完成自动注册。
### **apply_kernel**
对模型使能指定的 kernel。
```python
def apply_kernel(kernel_id: str, **kwargs) -> HFModel:
"""应用指定的 kernel 到模型
Args:
kernel_id: 目标 kernel 的 ID
**kwargs: 传递给 kernel.apply 的参数,通常包含 model
"""
```
**用法示例**
```python
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
model = apply_kernel("npu_fused_rmsnorm", model=model)
```
### **apply_default_kernels**
对模型使能所有默认注册的 kernel。这是一个高级 API通常在模型加载流程中自动调用。
```python
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
"""应用所有默认 kernel
Args:
model: HFModel 实例
include_kernels: 包含的 kernel ID 列表(逗号分隔字符串),或者 "auto"/True 表示全部
"""
```
## 扩展 Kernels
如果用户有针对特定模型或者设备的 kernel可以按照下述步骤去实现并接入 LLaMA-Factory。
### 创建新 Kernel 的步骤
#### 1. 创建 Kernel 实现文件
`src/llamafactory/v1/plugins/model_plugins/kernels/ops` 下的相应子目录中创建新的 kernel 实现文件,例如 `mlp/cuda_swiglu.py`
```python
import torch
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
# 实现具体的 kernel 函数
def _cuda_swiglu_forward(self, hidden_state):
# ... CUDA 优化实现 ...
pass
@register_kernel
class CudaSwiGluKernel(BaseKernel):
_kernel_id = "cuda_fused_swiglu"
_device = DeviceType.CUDA
@classmethod
def apply(cls, **kwargs) -> HFModel:
model = kwargs.get("model")
if model is None:
raise ValueError("model is required")
if not cls.check_deps():
raise RuntimeError("Dependencies not met")
# 遍历模型并替换 forward 方法
for name, module in model.named_modules():
# ... 匹配和替换逻辑 ...
pass
return model
```
#### 2. 自动发现
由于 `scan_all_kernels` 会自动扫描 `ops` 目录,只要文件位于该目录下且没有语法错误,系统启动时会自动导入并注册,无需手动修改注册表代码。
#### 3. 测试 Kernel
创建测试用例验证 kernel 的正确性:
```python
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
# ... 加载模型 ...
model = apply_kernel("cuda_fused_swiglu", model=model)
# ... 验证 forward 是否被替换 ...
```
## 异常处理
### 依赖不可用
`BaseKernel.check_deps()` 默认会检查当前设备类型是否匹配。子类可以重写此方法以添加额外的依赖检查(如检查特定的库是否安装)。如果 `check_deps()` 返回 `False``apply()` 方法应当抛出异常或进行相应处理。
### Kernel ID 未找到
如果调用 `apply_kernel` 时传入了不存在的 `kernel_id`,会抛出 `ValueError`