[v1] Refactor kernel plugin (#9669)

Co-authored-by: frozenleaves <frozen@Mac.local>
浮梦
2025-12-31 18:26:48 +08:00
committed by GitHub
parent 4e1d69579a
commit 16735b9e35
19 changed files with 777 additions and 433 deletions

View File

@@ -173,7 +173,7 @@ class BaseModelArguments:
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
use_v1_kernels: bool = field(
use_v1_kernels: bool | None = field(
default=False,
metadata={"help": "Whether or not to use high-performance kernels in training."},
)
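
For orientation, a minimal standalone sketch of what the widened `bool | None` type buys (the dataclass below is a hypothetical stand-in, not the project's `BaseModelArguments`): `None` now means "unset", while `False` stays an explicit opt-out.

from dataclasses import dataclass, field

@dataclass
class _ArgsSketch:  # hypothetical stand-in, for illustration only
    use_v1_kernels: bool | None = field(default=False)

args = _ArgsSketch(use_v1_kernels=None)   # unset: downstream code may fall back to its own default
print(bool(args.use_v1_kernels))          # False -> kernels stay disabled unless explicitly enabled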

View File

@@ -216,9 +216,9 @@ def load_model(
"You are try to using future feature about kernels, please note that this feature "
"is not supported for all models. If get any error, please disable this feature, or report the issue."
)
from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_available_kernels(model)
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model)
if is_trainable:

View File

@@ -112,6 +112,13 @@ class ModelLoader:
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model
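
A minimal sketch of the dispatch above, assuming `kernel_config` behaves like a small mapping with a `name` attribute (the stand-in class below is illustrative, not the project's `PluginConfig`; the kernel IDs come from this diff):

class _KernelConfigSketch(dict):  # illustrative stand-in, not the real PluginConfig
    name = "auto"

cfg = _KernelConfigSketch(include_kernels="npu_fused_rmsnorm,npu_fused_swiglu")
# The loader above then dispatches roughly as:
#   model = KernelPlugin(cfg.name)(model=model, include_kernels=cfg.get("include_kernels"))
print(cfg.name, cfg.get("include_kernels"))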

View File

@@ -0,0 +1,87 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of base kernel class.
Init Phase:
1. Define base kernel class.
2. Define abstract methods.
"""
from abc import ABC, abstractmethod
from typing import Any
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
class BaseKernel(ABC):
r"""Base class for all kernel implementations.
Subclasses must implement the abstract methods and define the required class attributes.
"""
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
@classmethod
def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
.. note::
In explicit mode, if a user specifies an implementation but this check fails,
it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks.
"""
if cls._device != get_current_accelerator().type:
return False
return True
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
Returns:
HFModel: The model with the kernel applied.
Raises:
RuntimeError: If the kernel dependencies are not met.
NotImplementedError: If the method is not implemented by the subclass.
Example:
>>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
>>> model = HFModel(config=config)
>>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
"""
if not cls.check_deps():
raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")
raise NotImplementedError
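
As an illustration of the contract above, a minimal hypothetical subclass (an identity kernel bound to the CPU device, reusing this file's imports) could look like the sketch below; it is not one of the kernels shipped in this commit.

class NoOpKernel(BaseKernel):  # hypothetical example, not part of this commit
    _kernel_id = "noop"
    _device = DeviceType.CPU

    @classmethod
    def apply(cls, **kwargs) -> HFModel:
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
        if not cls.check_deps():
            raise RuntimeError(f"{cls.__name__} is not available on the current accelerator.")
        return model  # no-op: the model is returned unchanged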

View File

@@ -1,23 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
class KernelType(str, Enum):
RMSNORM = "rmsnorm"
SWIGLU = "swiglu"
FLASH_ATTENTION = "flash_attention"
ROPE = "rope"
MOE = "moe"

View File

@@ -0,0 +1,132 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel interface.
Init Phase:
1. Scan all kernels.
2. Register default kernels.
3. Define kernel plugin.
"""
import importlib
from pathlib import Path
from ....utils.logging import get_logger
from ....utils.plugin import BasePlugin
from .registry import Registry
logger = get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
Returns:
dict[str, type[BaseKernel]]: A dictionary of registered kernels.
.. note::
This function assumes that the ``ops`` directory is located in the same directory as this file.
It recursively searches for ``.py`` files and constructs the module path for import.
"""
ops_path = Path(__file__).parent / "ops"
if not ops_path.exists():
return
base_package = __package__
for file_path in ops_path.rglob("*.py"):
if file_path.name == "__init__.py":
continue
# calculate the relative path:
# file_path = .../kernels/ops/mlp/npu_swiglu.py
# rel_path = ops/mlp/npu_swiglu.py
rel_path = file_path.relative_to(Path(__file__).parent)
# build module path:
module_name = ".".join(rel_path.parts)[:-3]
full_module_name = f"{base_package}.{module_name}"
try:
importlib.import_module(full_module_name)
except Exception as e:
logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")
return Registry.get_registered_kernels()
default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
"""
return list(default_kernels.keys())
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance.
Returns:
HFModel: The model with applied kernel.
"""
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
return kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
Returns:
HFModel: The model with applied kernels.
"""
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")

View File

@@ -12,22 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused MoE kernels.
Init Phase:
1. Define GMM functions.
2. Define NPU fused MoE functions.
3. Register NPU fused MoE kernel.
"""
import types
import torch
import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.packages import is_transformers_version_greater_than
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaMoEKernel
try:
import torch_npu
except ImportError:
pass
from ......accelerator.helper import DeviceType
from ......utils.packages import is_transformers_version_greater_than
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
x (Tensor): Input tensor.
weight (Tensor): Weight tensor.
group_list (list): List of group sizes.
Returns:
Tensor: The result of the grouped matrix multiplication.
"""
ctx.save_for_backward(x, weight)
ctx.group_list = group_list
@@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
grad_output (Tensor): Gradient with respect to the output.
Returns:
tuple: Gradients with respect to input, weight, and None for group_list.
"""
input_tensor, weight = ctx.saved_tensors
group_list = ctx.group_list
@@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
num_experts (int): Number of experts.
*args: Variable length argument list containing inputs and weights.
Returns:
tuple: The outputs of the grouped matrix multiplication.
"""
x_list = list(args[:num_experts])
weight_list = list(args[num_experts:])
@@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
*grad_outputs: Gradients with respect to the outputs.
Returns:
tuple: Gradients with respect to inputs and weights.
"""
saved_tensors = ctx.saved_tensors
num_experts = ctx.num_experts
split_sizes = ctx.split_sizes
@@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
hidden_states (Tensor): Input hidden states.
routing_weights (Tensor): Routing weights.
router_indices (Tensor): Router indices.
Returns:
Tensor: Output tensor after expert computation.
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
@@ -138,6 +208,15 @@ class NpuMoeFused:
@staticmethod
def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward pass for sparse MoE block using NPU optimization.
Args:
self: The MoE sparse block instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: The routed output.
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
router_logits = self.gate(hidden_states)
@@ -151,8 +230,19 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
hidden_states (Tensor): Input hidden states.
Returns:
tuple: A tuple containing the next states and router logits.
"""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
@@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"):
}
class NpuMoEFusedMoEKernel(MetaMoEKernel):
type = KernelType.MOE
device = DeviceType.NPU
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, model, **kwargs) -> HFModel:
if not is_torch_npu_available():
return model
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched MoE forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
archs = getattr(model.config, "architectures", [])
target_moe_mapping = None
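
For readers unfamiliar with the `torch.autograd.Function` pattern used by `GmmFunction` and `HybridGmmFunction` above, a minimal plain-torch analogue (no torch_npu required) of the save-for-backward flow is sketched below.

import torch

class ScaleFunction(torch.autograd.Function):
    """Toy custom autograd op (y = x * scale) mirroring the ctx.save_for_backward pattern above."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # dy/dx = scale; the non-tensor `scale` argument gets a None gradient.
        return grad_output * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y = ScaleFunction.apply(x, 2.0)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))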

View File

@@ -12,36 +12,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused SwiGLU kernels.
Init Phase:
1. Define SwiGLU forward functions.
2. Register NPU fused SwiGLU kernel.
"""
import re
import types
import torch
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaSwiGluKernel
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
def _npu_swiglu_forward(self, hidden_state):
try:
import torch_npu
except ImportError:
pass
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
hidden_state (Tensor): Input hidden state.
Returns:
Tensor: Output of SwiGLU.
"""
return self.down_proj(
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
)
def _npu_swiglu_glm4_forward(self, hidden_states):
import torch_npu
r"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
import torch_npu
r"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
hidden_states (Tensor): Input hidden states.
Returns:
Tensor: Output of SwiGLU.
"""
gate_proj = self.gate_proj(hidden_states)
if self.activation_sparsity > 0.0:
gate_proj = self._gaussian_topk(gate_proj)
@@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
return down_proj
class NpuSwiGluKernel(MetaSwiGluKernel):
type = KernelType.SWIGLU
device = DeviceType.NPU
kernel = _npu_swiglu_forward
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
# Don't apply the kernel to the following modules
# only the following module classes are supported
expect_modules = frozenset(
{
"Qwen3VLMoeTextMLP",
@@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
}
)
_kernel_id = "npu_fused_swiglu"
_device = DeviceType.NPU
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
if not is_torch_npu_available():
return model
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched SwiGLU forward functions.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
# Mapping of specific mlp modules to their corresponding kernel implementations
kernel_mapping = {
@@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
module.forward = types.MethodType(kernel_func, module)
return model
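
For reference, a plain-torch sketch of what the fused `torch_npu.npu_swiglu` call above is assumed to compute for these MLPs (split the packed tensor into gate/up halves, then `silu(gate) * up`); treat the exact semantics as an assumption rather than the kernel's documented contract.

import torch
import torch.nn.functional as F

def eager_swiglu(packed: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Assumed reference semantics for the fused op.
    gate, up = packed.chunk(2, dim=dim)
    return F.silu(gate) * up

packed = torch.randn(2, 8)           # e.g. torch.cat((gate_proj(x), up_proj(x)), dim=-1)
print(eager_swiglu(packed).shape)    # torch.Size([2, 4])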

View File

@@ -11,40 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RMSNorm kernels.
Init Phase:
1. Define RMSNorm forward function.
2. Register NPU fused RMSNorm kernel.
"""
import re
import types
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRMSNormKernel
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
def _npu_rms_forward(self, hidden_states):
"""NPU forward implementation for RMSNorm.
def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
hidden_states: Input hidden states tensor, same shape as the baseline.
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
Returns:
Normalized tensor consistent with the baseline RMSNorm behavior.
Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
"""
import torch_npu
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
class NpuRMSNormKernel(MetaRMSNormKernel):
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
@register_kernel
class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
type = KernelType.RMSNORM
device = DeviceType.NPU
kernel = _npu_rms_forward
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, model, **kwargs) -> HFModel:
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
replace the original `forward`.
- Do not modify weights, hyperparameters, or module structure to ensure
numerical behavior and interface consistency.
"""
if not is_torch_npu_available():
return model
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with NPU fused RMSNorm.
Raises:
RuntimeError: If torch_npu is not available.
ValueError: If the model is not provided.
"""
model = kwargs.get("model")
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
@@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
if re.search(rms_norm_pattern, module.__class__.__name__):
# Bind function as an instance method to preserve `self` semantics
# and replace the original forward
module.forward = types.MethodType(cls.kernel, module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model
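
As a numerical reference for the replacement above, the baseline RMSNorm that `torch_npu.npu_rms_norm` is expected to match is roughly the following sketch (upcasting details omitted):

import torch

def eager_rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: scale by the reciprocal root-mean-square, then by the learned weight.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return weight * hidden_states * torch.rsqrt(variance + eps)

x = torch.randn(2, 4, 16)
print(eager_rms_norm(x, torch.ones(16)).shape)   # torch.Size([2, 4, 16])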

View File

@@ -0,0 +1,146 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RoPE kernels.
Init Phase:
1. Define RoPE forward functions.
2. Register NPU fused RoPE kernel.
"""
import sys
import torch
from ......accelerator.helper import DeviceType
from ......utils.logging import get_logger
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
logger = get_logger(__name__)
try:
import torch_npu
except ImportError:
pass
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
position_ids (Tensor, optional): Position IDs. Default: ``None``.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
k (Tensor): Key tensor.
cos (Tensor): Cosine part of embedding.
sin (Tensor): Sine part of embedding.
mrope_section (Tensor): Multimodal RoPE section.
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
Returns:
tuple: (q_embed, k_embed) The embedded query and key tensors.
"""
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
Args:
**kwargs: Keyword arguments containing the model.
Returns:
HFModel: The model with patched RoPE functions.
Raises:
RuntimeError: If dependencies are not met.
ValueError: If the model is not provided.
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
_modules.add(module_name)
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
if (
getattr(target_module, "apply_multimodal_rotary_pos_emb")
is not _apply_multimodal_rotary_pos_emb_qwen25_vl
):
setattr(
target_module,
"apply_multimodal_rotary_pos_emb",
_apply_multimodal_rotary_pos_emb_qwen25_vl,
)
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model
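
For comparison, the stock `apply_rotary_pos_emb` that gets swapped out above is essentially the rotate-half formulation common in transformers modeling code; a plain-torch sketch:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def eager_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q = k = torch.randn(1, 2, 5, 8)      # (batch, heads, seq, head_dim)
cos = sin = torch.randn(1, 5, 8)
q_embed, k_embed = eager_apply_rotary_pos_emb(q, k, cos, sin)
print(q_embed.shape)                 # torch.Size([1, 2, 5, 8])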

View File

@@ -12,247 +12,86 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable
from typing import Any, Optional
"""The definition of kernel registry.
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
from .constants import KernelType
Init Phase:
1. Define kernel registry.
2. Register kernels.
"""
from typing import Optional
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
class KernelRegistry:
_instance: Optional["KernelRegistry"] = None
_initialized: bool = False
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
self._initialized = True
def register(
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
) -> None:
"""Register a kernel implementation.
Args:
kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
kernel_impl: the actual kernel function or class.
"""
if kernel_type not in self._registry:
self._registry[kernel_type] = {}
if device_type in self._registry[kernel_type]:
print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")
self._registry[kernel_type][device_type] = kernel_impl
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
return self._registry.get(kernel_type, {}).get(device_type)
__all__ = ["Registry", "register_kernel"]
KERNEL_REGISTRY = KernelRegistry()
class Registry:
r"""Registry for managing kernel implementations.
class AutoRegisterKernelMeta(ABCMeta):
"""Metaclass that automatically registers kernel classes upon creation.
This metaclass checks if a newly created class has both `type` and `device`
attributes defined. If so, it automatically registers the kernel in the
global KERNEL_REGISTRY, eliminating the need for manual registration.
To disable auto-registration for a specific class, set `auto_register = False`.
Storage structure: ``{ "kernel_id": Class }``
"""
def __new__(mcs, name, bases, namespace, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
# Check if auto-registration is disabled
auto_register = namespace.get("auto_register", True)
# Only auto-register if the class has both type and device attributes defined
# and they are not None (skip base classes like MetaKernel itself)
# and auto_register is True
kernel_type = namespace.get("type")
device_type = namespace.get("device")
if auto_register and kernel_type is not None and device_type is not None:
# Auto-register this kernel
KERNEL_REGISTRY.register(kernel_type, device_type, cls)
return cls
class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
"""Base class for all kernel implementations.
Subclasses are automatically registered when they define both `type` and `device`
attributes. To disable auto-registration, set `auto_register = False`.
Attributes:
type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
kernel: The actual kernel function or implementation.
auto_register: Set to False to disable automatic registration (default: True).
"""
type: KernelType | None = None
device: DeviceType | None = None
kernel: Callable | None = None
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
@abstractmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
"""Apply the kernel to the model.
def register(cls, kernel_cls: type[BaseKernel]):
r"""Decorator to register a kernel class.
This method should check if the kernel can be applied (e.g., dependencies
are installed, target modules exist) and perform the kernel replacement.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
Args:
model: The HuggingFace model to optimize.
**kwargs: Additional arguments for kernel application.
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
The optimized model (may be the same object with modifications).
type[BaseKernel]: The registered kernel class.
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
ValueError: If the kernel ID is missing or already registered.
"""
raise NotImplementedError
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
# Skip registration when the current accelerator's device type does not match the device type required by the kernel.
if device != get_current_accelerator().type:
return kernel_cls
if not kernel_id:
raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")
if kernel_id in cls._kernels:
raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")
cls._kernels[kernel_id] = kernel_cls
return kernel_cls
class MetaFlashAttentionKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
Returns:
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)
class MetaRMSNormKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
"""
return cls._kernels
class MetaSwiGluKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
class MetaRoPEKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
class MetaMoEKernel(MetaKernel):
@classmethod
def apply(cls, model: HFModel, **kwargs) -> HFModel:
raise NotImplementedError
def _ensure_kernels_loaded() -> None:
"""Ensure all kernel implementations are imported and registered.
This function dynamically imports all kernel implementation modules to trigger
their auto-registration. Python's module system ensures each module is only
executed once (cached in sys.modules), so repeated calls are safe and fast.
"""
# List of kernel module paths to import
kernel_modules = [
"rms_norm.npu_rms_norm",
"rope.npu_rope",
"mlp.npu_swiglu",
"mlp.npu_fused_moe",
# Add new kernel modules here as they are created
]
# Import each module to trigger kernel registration
# Python's import system caches modules, so this is fast on subsequent calls
for module_name in kernel_modules:
try:
__import__(f"{__package__}.{module_name}", fromlist=["*"])
except ImportError:
# Silently ignore import errors (e.g., missing dependencies like torch_npu)
pass
def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
"""Discover and return all kernel classes registered for the current device.
This function inspects the runtime environment (device type) and returns
all MetaKernel classes registered for that device. Each kernel's `apply()`
method is responsible for checking if it can actually be applied (e.g.,
required dependencies are installed, target modules exist in the model).
The function automatically discovers all kernels registered in KERNEL_REGISTRY
without requiring manual enumeration. On first call, it dynamically imports
all kernel implementation modules to trigger their auto-registration.
Args:
model: The HuggingFace model to apply kernels to.
TODO: implement the kernel route detection logic by model structure.
Returns:
A list of MetaKernel classes available for the current device.
"""
# Ensure all kernel modules are imported to trigger registration
_ensure_kernels_loaded()
discovered_kernels: list[type[MetaKernel]] = []
# Detect current device type
accelerator = get_current_accelerator()
try:
device_type = DeviceType(accelerator.type)
except ValueError:
# Unknown device type, return empty list
return discovered_kernels
# Skip CPU as it typically doesn't have optimized kernels
if device_type == DeviceType.CPU:
return discovered_kernels
# Iterate through registry and collect all kernels for current device
for devices in KERNEL_REGISTRY._registry.values():
kernel_cls = devices.get(device_type)
if kernel_cls is not None:
discovered_kernels.append(kernel_cls)
return discovered_kernels
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only
requirement is that `apply` returns the replaced model.
Example:
from transformers import AutoModelForCausalLM
from .rms_norm.npu_rms_norm import NpuRMSNormKernel
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if not issubclass(kernel, MetaKernel):
raise ValueError(f"{kernel} must be a MetaKernel instance.")
if kernel.device != get_current_accelerator().type:
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
"""Apply all available kernels to the model."""
for kernel in discover_kernels(model):
model = apply_kernel(model, kernel, **kwargs)
return model
# export decorator alias
register_kernel = Registry.register
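
A minimal registration sketch using the alias above; the kernel is hypothetical, reuses this file's relative imports, and only ends up in the registry when the running accelerator matches its `_device`.

from ....accelerator.helper import DeviceType   # same package layout as the import above
from .base import BaseKernel

@register_kernel
class CudaIdentityKernel(BaseKernel):  # hypothetical example
    _kernel_id = "cuda_identity"
    _device = DeviceType.CUDA

    @classmethod
    def apply(cls, **kwargs):
        return kwargs.get("model")   # no-op, for illustration only

print("cuda_identity" in Registry.get_registered_kernels())   # True only when running on CUDA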

View File

@@ -1,122 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import torch
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
import torch_npu
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
import torch_npu
mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim
)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
class NpuRoPEKernel(MetaRoPEKernel):
type = KernelType.ROPE
device = DeviceType.NPU
kernel = _apply_rotary_pos_emb
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
"""
if not is_torch_npu_available():
return model
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_rotary_pos_emb"):
if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
_modules.add(module_name)
except Exception:
pass
return model
class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
"""Qwen2-VL specific RoPE kernel - not auto-registered.
This kernel is for specific models (Qwen2-VL) and should be manually
applied when needed rather than auto-discovered.
"""
type = KernelType.ROPE
device = DeviceType.NPU
kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
auto_register = False # Disable auto-registration for model-specific kernel
@classmethod
def apply(cls, model, **kwargs) -> "HFModel":
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
`apply_rotary_pos_emb` function in that module's namespace with the
NPU-accelerated version from this file.
"""
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
module_name = module.__class__.__module__
if module_name in _modules:
continue
try:
target_module = sys.modules[module_name]
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
_modules.add(module_name)
except Exception:
pass
return model

View File

@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import pathlib
from unittest.mock import patch
from llamafactory.v1.config.arg_parser import get_args
def test_get_args_from_yaml(tmp_path: pathlib.Path):
config_yaml = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
name: "auto"
include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
peft_config:
name: "lora"
lora_rank: 0.8
quant_config: null
### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048
### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null
### sample
sample_backend: "hf"
max_new_tokens: 128
"""
config_file = tmp_path / "config.yaml"
config_file.write_text(config_yaml, encoding="utf-8")
test_argv = ["test_args_parser.py", str(config_file)]
with patch.object(sys, "argv", test_argv):
data_args, model_args, training_args, sample_args = get_args()
assert training_args.output_dir == "outputs/test_run"
assert training_args.micro_batch_size == 1
assert training_args.global_batch_size == 1
assert training_args.learning_rate == 1.0e-4
assert training_args.bf16 is False
assert training_args.dist_config is None
assert model_args.model == "llamafactory/tiny-random-qwen2.5"
assert model_args.kernel_config.name == "auto"
assert model_args.kernel_config.get("include_kernels") == "auto"
assert model_args.peft_config.name == "lora"
assert model_args.peft_config.get("lora_rank") == 0.8

View File

@@ -14,7 +14,7 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
@@ -29,5 +29,23 @@ def test_tiny_qwen():
assert model_loader.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
from transformers import Qwen2ForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
# test that the kernel plugin is applied when enabled
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
if __name__ == "__main__":
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import MagicMock, patch
import pytest
from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
@pytest.fixture(autouse=True)
@@ -29,24 +26,29 @@ def clear_accelerator_cache():
get_current_accelerator.cache_clear()
def reload_kernels():
"""Helper to reload kernel modules to respect mocked accelerator."""
# Unload kernel interface and registry
keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
for k in keys_to_remove:
del sys.modules[k]
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock):
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
# Force reload of kernels with mocked accelerator
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
@patch("torch.accelerator.current_accelerator")
@@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
# Force reload of kernels with mocked accelerator
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_available_kernels(model)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
model = apply_default_kernels(model=model, include_kernels=True)
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__