mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-08 23:20:36 +08:00
[v1] Refactor kernel plugin (#9669)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
@@ -173,7 +173,7 @@ class BaseModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use KV cache in generation."},
|
||||
)
|
||||
use_v1_kernels: bool = field(
|
||||
use_v1_kernels: bool | None = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use high-performance kernels in training."},
|
||||
)
|
||||
|
||||
@@ -216,9 +216,9 @@ def load_model(
|
||||
"You are try to using future feature about kernels, please note that this feature "
|
||||
"is not supported for all models. If get any error, please disable this feature, or report the issue."
|
||||
)
|
||||
from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
|
||||
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = apply_available_kernels(model)
|
||||
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
if is_trainable:
|
||||
|
||||
@@ -112,6 +112,13 @@ class ModelLoader:
|
||||
|
||||
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
|
||||
|
||||
if self.args.kernel_config is not None:
|
||||
from ..plugins.model_plugins.kernels.interface import KernelPlugin
|
||||
|
||||
model = KernelPlugin(self.args.kernel_config.name)(
|
||||
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
87
src/llamafactory/v1/plugins/model_plugins/kernels/base.py
Normal file
87
src/llamafactory/v1/plugins/model_plugins/kernels/base.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of base kernel class.
|
||||
|
||||
Init Phase:
|
||||
1. Define base kernel class.
|
||||
2. Define abstract methods.
|
||||
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.types import HFModel
|
||||
|
||||
|
||||
class BaseKernel(ABC):
|
||||
r"""Base class for all kernel implementations.
|
||||
|
||||
Subclasses must implement the abstract methods and define the required class attributes.
|
||||
"""
|
||||
|
||||
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
|
||||
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
|
||||
|
||||
@classmethod
|
||||
def get_kernel_id(cls) -> str:
|
||||
r"""Returns the unique identifier for the kernel."""
|
||||
return cls._kernel_id
|
||||
|
||||
@classmethod
|
||||
def get_device(cls) -> str:
|
||||
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
return cls._device
|
||||
|
||||
@classmethod
|
||||
def check_deps(cls) -> bool:
|
||||
r"""Checks if the required dependencies for the kernel are available.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if dependencies are met, ``False`` otherwise.
|
||||
|
||||
.. note::
|
||||
In explicit mode, if a user specifies an implementation but this check fails,
|
||||
it should raise an error instead of silently switching.
|
||||
Kernels can override this method to implement custom dependency checks.
|
||||
"""
|
||||
if cls._device != get_current_accelerator().type:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the kernel optimization to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with the kernel applied.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the kernel dependencies are not met.
|
||||
NotImplementedError: If the method is not implemented by the subclass.
|
||||
|
||||
Example:
|
||||
>>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
|
||||
>>> model = HFModel(config=config)
|
||||
>>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
|
||||
"""
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")
|
||||
raise NotImplementedError
|
||||
@@ -1,23 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class KernelType(str, Enum):
|
||||
RMSNORM = "rmsnorm"
|
||||
SWIGLU = "swiglu"
|
||||
FLASH_ATTENTION = "flash_attention"
|
||||
ROPE = "rope"
|
||||
MOE = "moe"
|
||||
132
src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
Normal file
132
src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of kernel interface.
|
||||
|
||||
Init Phase:
|
||||
1. Scan all kernels.
|
||||
2. Register default kernels.
|
||||
3. Define kernel plugin.
|
||||
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.plugin import BasePlugin
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def scan_all_kernels():
|
||||
r"""Scan all kernels in the ``ops`` directory.
|
||||
|
||||
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
|
||||
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
|
||||
|
||||
Returns:
|
||||
dict[str, type[BaseKernel]]: A dictionary of registered kernels.
|
||||
|
||||
.. note::
|
||||
This function assumes that the ``ops`` directory is located in the same directory as this file.
|
||||
It recursively searches for ``.py`` files and constructs the module path for import.
|
||||
"""
|
||||
ops_path = Path(__file__).parent / "ops"
|
||||
|
||||
if not ops_path.exists():
|
||||
return
|
||||
|
||||
base_package = __package__
|
||||
|
||||
for file_path in ops_path.rglob("*.py"):
|
||||
if file_path.name == "__init__.py":
|
||||
continue
|
||||
|
||||
# calculate the relative path:
|
||||
# file_path = .../kernels_v2/ops/mlp/npu_swiglu.py
|
||||
# rel_path = ops/mlp/npu_swiglu.py
|
||||
rel_path = file_path.relative_to(Path(__file__).parent)
|
||||
|
||||
# build module path:
|
||||
module_name = ".".join(rel_path.parts)[:-3]
|
||||
full_module_name = f"{base_package}.{module_name}"
|
||||
|
||||
try:
|
||||
importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")
|
||||
|
||||
return Registry.get_registered_kernels()
|
||||
|
||||
|
||||
default_kernels = scan_all_kernels()
|
||||
|
||||
|
||||
def get_default_kernels():
|
||||
r"""Get a list of default registered kernel IDs.
|
||||
|
||||
Returns:
|
||||
list[str]: List of kernel IDs.
|
||||
"""
|
||||
return list(default_kernels.keys())
|
||||
|
||||
|
||||
def apply_kernel(kernel_id: str, **kwargs):
|
||||
r"""Applies a specific kernel to the model.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to apply.
|
||||
**kwargs: Keyword arguments passed to the kernel application function.
|
||||
Typically includes the model instance.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with applied kernel.
|
||||
"""
|
||||
kernel = default_kernels.get(kernel_id)
|
||||
if kernel is None:
|
||||
raise ValueError(f"Kernel {kernel_id} not found")
|
||||
kernel.apply(**kwargs)
|
||||
|
||||
|
||||
class KernelPlugin(BasePlugin):
|
||||
r"""Plugin for managing kernel optimizations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@KernelPlugin("auto").register
|
||||
def apply_default_kernels(**kwargs):
|
||||
r"""Applies all default registered kernels to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to the kernel application function.
|
||||
Typically includes the model instance and the include_kernels configuration.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with applied kernels.
|
||||
"""
|
||||
if not kwargs.get("include_kernels"): # None/False/empty string
|
||||
return kwargs.get("model")
|
||||
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
|
||||
use_kernels = default_kernels.keys()
|
||||
else:
|
||||
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
||||
for kernel in use_kernels:
|
||||
if kernel not in default_kernels:
|
||||
raise ValueError(f"Kernel {kernel} not found")
|
||||
apply_kernel(kernel, **kwargs)
|
||||
return kwargs.get("model")
|
||||
@@ -12,22 +12,49 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of NPU fused MoE kernels.
|
||||
|
||||
Init Phase:
|
||||
1. Define GMM functions.
|
||||
2. Define NPU fused MoE functions.
|
||||
3. Register NPU fused MoE kernel.
|
||||
|
||||
"""
|
||||
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.packages import is_transformers_version_greater_than
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaMoEKernel
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.packages import is_transformers_version_greater_than
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
class GmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, group_list):
|
||||
r"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors for backward pass.
|
||||
x (Tensor): Input tensor.
|
||||
weight (Tensor): Weight tensor.
|
||||
group_list (list): List of group sizes.
|
||||
|
||||
Returns:
|
||||
Tensor: The result of the grouped matrix multiplication.
|
||||
"""
|
||||
ctx.save_for_backward(x, weight)
|
||||
ctx.group_list = group_list
|
||||
|
||||
@@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
r"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
grad_output (Tensor): Gradient with respect to the output.
|
||||
|
||||
Returns:
|
||||
tuple: Gradients with respect to input, weight, and None for group_list.
|
||||
"""
|
||||
input_tensor, weight = ctx.saved_tensors
|
||||
group_list = ctx.group_list
|
||||
|
||||
@@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class HybridGmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, num_experts, *args):
|
||||
r"""Performs the forward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors.
|
||||
num_experts (int): Number of experts.
|
||||
*args: Variable length argument list containing inputs and weights.
|
||||
|
||||
Returns:
|
||||
tuple: The outputs of the grouped matrix multiplication.
|
||||
"""
|
||||
x_list = list(args[:num_experts])
|
||||
weight_list = list(args[num_experts:])
|
||||
|
||||
@@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
r"""Performs the backward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
*grad_outputs: Gradients with respect to the outputs.
|
||||
|
||||
Returns:
|
||||
tuple: Gradients with respect to inputs and weights.
|
||||
"""
|
||||
saved_tensors = ctx.saved_tensors
|
||||
num_experts = ctx.num_experts
|
||||
split_sizes = ctx.split_sizes
|
||||
@@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class NpuMoeFused:
|
||||
r"""Container for NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_experts_forward(
|
||||
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
r"""Forward pass for MoE experts using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The MoE layer instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
routing_weights (Tensor): Routing weights.
|
||||
router_indices (Tensor): Router indices.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor after expert computation.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_size)
|
||||
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
|
||||
@@ -138,6 +208,15 @@ class NpuMoeFused:
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward pass for sparse MoE block using NPU optimization.
|
||||
|
||||
Args:
|
||||
self: The MoE sparse block instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor: The routed output.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_size)
|
||||
router_logits = self.gate(hidden_states)
|
||||
@@ -151,8 +230,19 @@ class NpuMoeFused:
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
r"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
|
||||
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The Qwen3 MoE block instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the next states and router logits.
|
||||
"""
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
@@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"):
|
||||
}
|
||||
|
||||
|
||||
class NpuMoEFusedMoEKernel(MetaMoEKernel):
|
||||
type = KernelType.MOE
|
||||
device = DeviceType.NPU
|
||||
@register_kernel
|
||||
class NpuFusedMoEKernel(BaseKernel):
|
||||
r"""NPU Fused MoE Kernel implementation."""
|
||||
|
||||
_kernel_id = "npu_fused_moe"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> HFModel:
|
||||
if not is_torch_npu_available():
|
||||
return model
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the NPU fused MoE kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with patched MoE forward functions.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not provided.
|
||||
RuntimeError: If dependencies are not met.
|
||||
"""
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
|
||||
|
||||
archs = getattr(model.config, "architectures", [])
|
||||
target_moe_mapping = None
|
||||
@@ -12,36 +12,71 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of NPU fused SwiGLU kernels.
|
||||
|
||||
Init Phase:
|
||||
1. Define SwiGLU forward functions.
|
||||
2. Register NPU fused SwiGLU kernel.
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
import types
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaSwiGluKernel
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
def _npu_swiglu_forward(self, hidden_state):
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def npu_swiglu_forward(self, hidden_state):
|
||||
r"""SwiGLU forward pass for NPU.
|
||||
|
||||
Args:
|
||||
self: The MLP layer instance.
|
||||
hidden_state (Tensor): Input hidden state.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of SwiGLU.
|
||||
"""
|
||||
return self.down_proj(
|
||||
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
|
||||
)
|
||||
|
||||
|
||||
def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||
import torch_npu
|
||||
r"""SwiGLU forward pass for GLM4 on NPU.
|
||||
|
||||
Args:
|
||||
self: The GLM4 MLP layer instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of SwiGLU.
|
||||
"""
|
||||
up_states = self.gate_up_proj(hidden_states)
|
||||
gate, up_states = up_states.chunk(2, dim=-1)
|
||||
return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
|
||||
|
||||
|
||||
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
import torch_npu
|
||||
r"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||
|
||||
Args:
|
||||
self: The Gemma3nText MLP layer instance.
|
||||
hidden_states (Tensor): Input hidden states.
|
||||
|
||||
Returns:
|
||||
Tensor: Output of SwiGLU.
|
||||
"""
|
||||
gate_proj = self.gate_proj(hidden_states)
|
||||
if self.activation_sparsity > 0.0:
|
||||
gate_proj = self._gaussian_topk(gate_proj)
|
||||
@@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
return down_proj
|
||||
|
||||
|
||||
class NpuSwiGluKernel(MetaSwiGluKernel):
|
||||
type = KernelType.SWIGLU
|
||||
device = DeviceType.NPU
|
||||
kernel = _npu_swiglu_forward
|
||||
@register_kernel
|
||||
class NpuSwiGluKernel(BaseKernel):
|
||||
r"""NPU Kernel for fused SwiGLU activation."""
|
||||
|
||||
# Don't apply the kernel to the following modules
|
||||
# just support apply to the following module layers
|
||||
expect_modules = frozenset(
|
||||
{
|
||||
"Qwen3VLMoeTextMLP",
|
||||
@@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
|
||||
}
|
||||
)
|
||||
|
||||
_kernel_id = "npu_fused_swiglu"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> "HFModel":
|
||||
if not is_torch_npu_available():
|
||||
return model
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Applies the NPU fused SwiGLU kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with patched SwiGLU forward functions.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not provided.
|
||||
RuntimeError: If dependencies are not met.
|
||||
"""
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
|
||||
|
||||
# Mapping of specific mlp modules to their corresponding kernel implementations
|
||||
kernel_mapping = {
|
||||
@@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
|
||||
):
|
||||
# Bind function as an instance method to preserve `self` semantics
|
||||
# and replace the original forward
|
||||
kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
|
||||
kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
|
||||
module.forward = types.MethodType(kernel_func, module)
|
||||
|
||||
return model
|
||||
@@ -11,40 +11,49 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of NPU fused RMSNorm kernels.
|
||||
|
||||
Init Phase:
|
||||
1. Define RMSNorm forward function.
|
||||
2. Register NPU fused RMSNorm kernel.
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
import types
|
||||
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaRMSNormKernel
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
def _npu_rms_forward(self, hidden_states):
|
||||
"""NPU forward implementation for RMSNorm.
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
r"""NPU forward implementation for RMSNorm.
|
||||
|
||||
Args:
|
||||
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
||||
hidden_states: Input hidden states tensor, same shape as the baseline.
|
||||
hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
|
||||
|
||||
Returns:
|
||||
Normalized tensor consistent with the baseline RMSNorm behavior.
|
||||
Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
|
||||
"""
|
||||
import torch_npu
|
||||
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
|
||||
class NpuRMSNormKernel(MetaRMSNormKernel):
|
||||
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||
@register_kernel
|
||||
class NpuRMSNormKernel(BaseKernel):
|
||||
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||
|
||||
type = KernelType.RMSNORM
|
||||
device = DeviceType.NPU
|
||||
kernel = _npu_rms_forward
|
||||
_kernel_id = "npu_fused_rmsnorm"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> HFModel:
|
||||
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
|
||||
Key points:
|
||||
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
||||
@@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
|
||||
replace the original `forward`.
|
||||
- Do not modify weights, hyperparameters, or module structure to ensure
|
||||
numerical behavior and interface consistency.
|
||||
"""
|
||||
if not is_torch_npu_available():
|
||||
return model
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with NPU fused RMSNorm.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If torch_npu is not available.
|
||||
ValueError: If the model is not provided.
|
||||
"""
|
||||
model = kwargs.get("model")
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
@@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
|
||||
if re.search(rms_norm_pattern, module.__class__.__name__):
|
||||
# Bind function as an instance method to preserve `self` semantics
|
||||
# and replace the original forward
|
||||
module.forward = types.MethodType(cls.kernel, module)
|
||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||
|
||||
return model
|
||||
@@ -0,0 +1,146 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of NPU fused RoPE kernels.
|
||||
|
||||
Init Phase:
|
||||
1. Define RoPE forward functions.
|
||||
2. Register NPU fused RoPE kernel.
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.logging import get_logger
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
k (Tensor): Key tensor.
|
||||
cos (Tensor): Cosine part of embedding.
|
||||
sin (Tensor): Sine part of embedding.
|
||||
position_ids (Tensor, optional): Position IDs. Default: ``None``.
|
||||
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
|
||||
|
||||
Returns:
|
||||
tuple: (q_embed, k_embed) The embedded query and key tensors.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
k (Tensor): Key tensor.
|
||||
cos (Tensor): Cosine part of embedding.
|
||||
sin (Tensor): Sine part of embedding.
|
||||
mrope_section (Tensor): Multimodal RoPE section.
|
||||
unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
|
||||
|
||||
Returns:
|
||||
tuple: (q_embed, k_embed) The embedded query and key tensors.
|
||||
"""
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
@register_kernel
|
||||
class NpuRoPEKernel(BaseKernel):
|
||||
r"""NPU Kernel for Rotary Position Embedding."""
|
||||
|
||||
_kernel_id = "npu_fused_rope"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
|
||||
This function iterates through the model's modules to find attention layers,
|
||||
identifies the module where they are defined, and replaces the original
|
||||
`apply_rotary_pos_emb` function in that module's namespace with the
|
||||
NPU-accelerated version from this file.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
HFModel: The model with patched RoPE functions.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If dependencies are not met.
|
||||
ValueError: If the model is not provided.
|
||||
"""
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
_modules = set()
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
module_name = module.__class__.__module__
|
||||
if module_name in _modules:
|
||||
continue
|
||||
try:
|
||||
target_module = sys.modules[module_name]
|
||||
if hasattr(target_module, "apply_rotary_pos_emb"):
|
||||
if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
|
||||
setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
|
||||
_modules.add(module_name)
|
||||
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
|
||||
if (
|
||||
getattr(target_module, "apply_multimodal_rotary_pos_emb")
|
||||
is not _apply_multimodal_rotary_pos_emb_qwen25_vl
|
||||
):
|
||||
setattr(
|
||||
target_module,
|
||||
"apply_multimodal_rotary_pos_emb",
|
||||
_apply_multimodal_rotary_pos_emb_qwen25_vl,
|
||||
)
|
||||
_modules.add(module_name)
|
||||
except Exception as e:
|
||||
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
|
||||
return model
|
||||
@@ -12,247 +12,86 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
"""The definition of kernel registry.
|
||||
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.types import HFModel
|
||||
from .constants import KernelType
|
||||
Init Phase:
|
||||
1. Define kernel registry.
|
||||
2. Register kernels.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from .base import BaseKernel
|
||||
|
||||
|
||||
class KernelRegistry:
|
||||
_instance: Optional["KernelRegistry"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
|
||||
self._initialized = True
|
||||
|
||||
def register(
|
||||
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
|
||||
) -> None:
|
||||
"""Register a kernel implementation.
|
||||
|
||||
Args:
|
||||
kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
|
||||
device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
|
||||
kernel_impl: the actual kernel function or class.
|
||||
"""
|
||||
if kernel_type not in self._registry:
|
||||
self._registry[kernel_type] = {}
|
||||
|
||||
if device_type in self._registry[kernel_type]:
|
||||
print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")
|
||||
|
||||
self._registry[kernel_type][device_type] = kernel_impl
|
||||
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
|
||||
|
||||
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
|
||||
return self._registry.get(kernel_type, {}).get(device_type)
|
||||
__all__ = ["Registry", "register_kernel"]
|
||||
|
||||
|
||||
KERNEL_REGISTRY = KernelRegistry()
|
||||
class Registry:
|
||||
r"""Registry for managing kernel implementations.
|
||||
|
||||
|
||||
class AutoRegisterKernelMeta(ABCMeta):
|
||||
"""Metaclass that automatically registers kernel classes upon creation.
|
||||
|
||||
This metaclass checks if a newly created class has both `type` and `device`
|
||||
attributes defined. If so, it automatically registers the kernel in the
|
||||
global KERNEL_REGISTRY, eliminating the need for manual registration.
|
||||
|
||||
To disable auto-registration for a specific class, set `auto_register = False`.
|
||||
Storage structure: ``{ "kernel_id": Class }``
|
||||
"""
|
||||
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
|
||||
|
||||
# Check if auto-registration is disabled
|
||||
auto_register = namespace.get("auto_register", True)
|
||||
|
||||
# Only auto-register if the class has both type and device attributes defined
|
||||
# and they are not None (skip base classes like MetaKernel itself)
|
||||
# and auto_register is True
|
||||
kernel_type = namespace.get("type")
|
||||
device_type = namespace.get("device")
|
||||
|
||||
if auto_register and kernel_type is not None and device_type is not None:
|
||||
# Auto-register this kernel
|
||||
KERNEL_REGISTRY.register(kernel_type, device_type, cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
|
||||
"""Base class for all kernel implementations.
|
||||
|
||||
Subclasses are automatically registered when they define both `type` and `device`
|
||||
attributes. To disable auto-registration, set `auto_register = False`.
|
||||
|
||||
Attributes:
|
||||
type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
|
||||
device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
|
||||
kernel: The actual kernel function or implementation.
|
||||
auto_register: Set to False to disable automatic registration (default: True).
|
||||
"""
|
||||
|
||||
type: KernelType | None = None
|
||||
device: DeviceType | None = None
|
||||
kernel: Callable | None = None
|
||||
_kernels: dict[str, type[BaseKernel]] = {}
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def apply(cls, model: HFModel, **kwargs) -> HFModel:
|
||||
"""Apply the kernel to the model.
|
||||
def register(cls, kernel_cls: type[BaseKernel]):
|
||||
r"""Decorator to register a kernel class.
|
||||
|
||||
This method should check if the kernel can be applied (e.g., dependencies
|
||||
are installed, target modules exist) and perform the kernel replacement.
|
||||
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
|
||||
|
||||
Args:
|
||||
model: The HuggingFace model to optimize.
|
||||
**kwargs: Additional arguments for kernel application.
|
||||
kernel_cls (type[BaseKernel]): The kernel class to register.
|
||||
|
||||
Returns:
|
||||
The optimized model (may be the same object with modifications).
|
||||
type[BaseKernel]: The registered kernel class.
|
||||
|
||||
Raises:
|
||||
TypeError: If the class does not inherit from :class:`BaseKernel`.
|
||||
ValueError: If the kernel ID is missing or already registered.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if not issubclass(kernel_cls, BaseKernel):
|
||||
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
|
||||
kernel_id = kernel_cls.get_kernel_id()
|
||||
device = kernel_cls.get_device()
|
||||
|
||||
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
|
||||
if device != get_current_accelerator().type:
|
||||
return
|
||||
|
||||
if not kernel_id:
|
||||
raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")
|
||||
|
||||
if kernel_id in cls._kernels:
|
||||
raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")
|
||||
|
||||
cls._kernels[kernel_id] = kernel_cls
|
||||
return kernel_cls
|
||||
|
||||
class MetaFlashAttentionKernel(MetaKernel):
|
||||
@classmethod
|
||||
def apply(cls, model: HFModel, **kwargs) -> HFModel:
|
||||
raise NotImplementedError
|
||||
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
|
||||
r"""Retrieves a registered kernel implementation by its ID.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
|
||||
"""
|
||||
return cls._kernels.get(kernel_id)
|
||||
|
||||
class MetaRMSNormKernel(MetaKernel):
|
||||
@classmethod
|
||||
def apply(cls, model: HFModel, **kwargs) -> HFModel:
|
||||
raise NotImplementedError
|
||||
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
|
||||
r"""Returns a dictionary of all registered kernels.
|
||||
|
||||
Returns:
|
||||
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
|
||||
"""
|
||||
return cls._kernels
|
||||
|
||||
|
||||
class MetaSwiGluKernel(MetaKernel):
|
||||
@classmethod
|
||||
def apply(cls, model: HFModel, **kwargs) -> HFModel:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MetaRoPEKernel(MetaKernel):
|
||||
@classmethod
|
||||
def apply(cls, model: HFModel, **kwargs) -> HFModel:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MetaMoEKernel(MetaKernel):
|
||||
@classmethod
|
||||
def apply(cls, model: HFModel, **kwargs) -> HFModel:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _ensure_kernels_loaded() -> None:
|
||||
"""Ensure all kernel implementations are imported and registered.
|
||||
|
||||
This function dynamically imports all kernel implementation modules to trigger
|
||||
their auto-registration. Python's module system ensures each module is only
|
||||
executed once (cached in sys.modules), so repeated calls are safe and fast.
|
||||
"""
|
||||
# List of kernel module paths to import
|
||||
kernel_modules = [
|
||||
"rms_norm.npu_rms_norm",
|
||||
"rope.npu_rope",
|
||||
"mlp.npu_swiglu",
|
||||
"mlp.npu_fused_moe",
|
||||
# Add new kernel modules here as they are created
|
||||
]
|
||||
|
||||
# Import each module to trigger kernel registration
|
||||
# Python's import system caches modules, so this is fast on subsequent calls
|
||||
for module_name in kernel_modules:
|
||||
try:
|
||||
__import__(f"{__package__}.{module_name}", fromlist=["*"])
|
||||
except ImportError:
|
||||
# Silently ignore import errors (e.g., missing dependencies like torch_npu)
|
||||
pass
|
||||
|
||||
|
||||
def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
"""Discover and return all kernel classes registered for the current device.
|
||||
|
||||
This function inspects the runtime environment (device type) and returns
|
||||
all MetaKernel classes registered for that device. Each kernel's `apply()`
|
||||
method is responsible for checking if it can actually be applied (e.g.,
|
||||
required dependencies are installed, target modules exist in the model).
|
||||
|
||||
The function automatically discovers all kernels registered in KERNEL_REGISTRY
|
||||
without requiring manual enumeration. On first call, it dynamically imports
|
||||
all kernel implementation modules to trigger their auto-registration.
|
||||
|
||||
Args:
|
||||
model: The HuggingFace model to apply kernels to.
|
||||
TODO: implement the kernel route detection logic by model structure.
|
||||
|
||||
Returns:
|
||||
A list of MetaKernel classes available for the current device.
|
||||
"""
|
||||
# Ensure all kernel modules are imported to trigger registration
|
||||
_ensure_kernels_loaded()
|
||||
|
||||
discovered_kernels: list[type[MetaKernel]] = []
|
||||
|
||||
# Detect current device type
|
||||
accelerator = get_current_accelerator()
|
||||
try:
|
||||
device_type = DeviceType(accelerator.type)
|
||||
except ValueError:
|
||||
# Unknown device type, return empty list
|
||||
return discovered_kernels
|
||||
|
||||
# Skip CPU as it typically doesn't have optimized kernels
|
||||
if device_type == DeviceType.CPU:
|
||||
return discovered_kernels
|
||||
|
||||
# Iterate through registry and collect all kernels for current device
|
||||
for devices in KERNEL_REGISTRY._registry.values():
|
||||
kernel_cls = devices.get(device_type)
|
||||
if kernel_cls is not None:
|
||||
discovered_kernels.append(kernel_cls)
|
||||
|
||||
return discovered_kernels
|
||||
|
||||
|
||||
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
|
||||
"""Call the MetaKernel's `apply` to perform the replacement.
|
||||
|
||||
Corresponding replacement logic is maintained inside each kernel; the only
|
||||
requirement is that `apply` returns the replaced model.
|
||||
|
||||
Example:
|
||||
from transformers import AutoModelForCausalLM
|
||||
from .rms_norm.npu_rms_norm import NpuRMSNormKernel
|
||||
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
|
||||
model = apply_kernel(model, NpuRMSNormKernel)
|
||||
"""
|
||||
if not issubclass(kernel, MetaKernel):
|
||||
raise ValueError(f"{kernel} must be a MetaKernel instance.")
|
||||
|
||||
if kernel.device != get_current_accelerator().type:
|
||||
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
|
||||
|
||||
return kernel.apply(model, **kwargs)
|
||||
|
||||
|
||||
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
|
||||
"""Apply all available kernels to the model."""
|
||||
for kernel in discover_kernels(model):
|
||||
model = apply_kernel(model, kernel, **kwargs)
|
||||
|
||||
return model
|
||||
# export decorator alias
|
||||
register_kernel = Registry.register
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaRoPEKernel
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors."""
|
||||
import torch_npu
|
||||
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
|
||||
import torch_npu
|
||||
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
)
|
||||
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class NpuRoPEKernel(MetaRoPEKernel):
|
||||
type = KernelType.ROPE
|
||||
device = DeviceType.NPU
|
||||
kernel = _apply_rotary_pos_emb
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> "HFModel":
|
||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
|
||||
This function iterates through the model's modules to find attention layers,
|
||||
identifies the module where they are defined, and replaces the original
|
||||
`apply_rotary_pos_emb` function in that module's namespace with the
|
||||
NPU-accelerated version from this file.
|
||||
"""
|
||||
if not is_torch_npu_available():
|
||||
return model
|
||||
|
||||
_modules = set()
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
module_name = module.__class__.__module__
|
||||
if module_name in _modules:
|
||||
continue
|
||||
try:
|
||||
target_module = sys.modules[module_name]
|
||||
if hasattr(target_module, "apply_rotary_pos_emb"):
|
||||
if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
|
||||
setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
|
||||
_modules.add(module_name)
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
||||
|
||||
class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
|
||||
"""Qwen2-VL specific RoPE kernel - not auto-registered.
|
||||
|
||||
This kernel is for specific models (Qwen2-VL) and should be manually
|
||||
applied when needed rather than auto-discovered.
|
||||
"""
|
||||
|
||||
type = KernelType.ROPE
|
||||
device = DeviceType.NPU
|
||||
kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
|
||||
auto_register = False # Disable auto-registration for model-specific kernel
|
||||
|
||||
@classmethod
|
||||
def apply(cls, model, **kwargs) -> "HFModel":
|
||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
|
||||
This function iterates through the model's modules to find attention layers,
|
||||
identifies the module where they are defined, and replaces the original
|
||||
`apply_rotary_pos_emb` function in that module's namespace with the
|
||||
NPU-accelerated version from this file.
|
||||
"""
|
||||
_modules = set()
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
module_name = module.__class__.__module__
|
||||
if module_name in _modules:
|
||||
continue
|
||||
try:
|
||||
target_module = sys.modules[module_name]
|
||||
if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
|
||||
if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
|
||||
setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
|
||||
_modules.add(module_name)
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
71
tests_v1/config/test_args_parser.py
Normal file
71
tests_v1/config/test_args_parser.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import pathlib
|
||||
from unittest.mock import patch
|
||||
|
||||
from llamafactory.v1.config.arg_parser import get_args
|
||||
|
||||
|
||||
def test_get_args_from_yaml(tmp_path: pathlib.Path):
|
||||
config_yaml = """
|
||||
### model
|
||||
model: "llamafactory/tiny-random-qwen2.5"
|
||||
trust_remote_code: true
|
||||
use_fast_processor: true
|
||||
model_class: "llm"
|
||||
kernel_config:
|
||||
name: "auto"
|
||||
include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
peft_config:
|
||||
name: "lora"
|
||||
lora_rank: 0.8
|
||||
quant_config: null
|
||||
|
||||
### data
|
||||
dataset: "llamafactory/tiny-supervised-dataset"
|
||||
cutoff_len: 2048
|
||||
|
||||
### training
|
||||
output_dir: "outputs/test_run"
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 1
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
dist_config: null
|
||||
|
||||
### sample
|
||||
sample_backend: "hf"
|
||||
max_new_tokens: 128
|
||||
"""
|
||||
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml, encoding="utf-8")
|
||||
|
||||
test_argv = ["test_args_parser.py", str(config_file)]
|
||||
|
||||
with patch.object(sys, "argv", test_argv):
|
||||
data_args, model_args, training_args, sample_args = get_args()
|
||||
assert training_args.output_dir == "outputs/test_run"
|
||||
assert training_args.micro_batch_size == 1
|
||||
assert training_args.global_batch_size == 1
|
||||
assert training_args.learning_rate == 1.0e-4
|
||||
assert training_args.bf16 is False
|
||||
assert training_args.dist_config is None
|
||||
assert model_args.model == "llamafactory/tiny-random-qwen2.5"
|
||||
assert model_args.kernel_config.name == "auto"
|
||||
assert model_args.kernel_config.get("include_kernels") == "auto"
|
||||
assert model_args.peft_config.name == "lora"
|
||||
assert model_args.peft_config.get("lora_rank") == 0.8
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from llamafactory.v1.config.model_args import ModelArguments
|
||||
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
|
||||
from llamafactory.v1.core.model_loader import ModelLoader
|
||||
|
||||
|
||||
@@ -29,5 +29,23 @@ def test_tiny_qwen():
|
||||
assert model_loader.model.dtype == torch.bfloat16
|
||||
|
||||
|
||||
def test_tiny_qwen_with_kernel_plugin():
|
||||
from transformers import Qwen2ForCausalLM
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
|
||||
|
||||
model_args = ModelArguments(
|
||||
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
|
||||
)
|
||||
model_loader = ModelLoader(model_args)
|
||||
# test enable apply kernel plugin
|
||||
if hasattr(torch, "npu"):
|
||||
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
|
||||
else:
|
||||
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||
assert isinstance(model_loader.model, Qwen2ForCausalLM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tiny_qwen()
|
||||
test_tiny_qwen_with_kernel_plugin()
|
||||
|
||||
@@ -12,16 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from llamafactory.v1.accelerator.helper import get_current_accelerator
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -29,24 +26,29 @@ def clear_accelerator_cache():
|
||||
get_current_accelerator.cache_clear()
|
||||
|
||||
|
||||
def reload_kernels():
|
||||
"""Helper to reload kernel modules to respect mocked accelerator."""
|
||||
# Unload kernel interface and registry
|
||||
keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
|
||||
for k in keys_to_remove:
|
||||
del sys.modules[k]
|
||||
|
||||
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_apply_kernel(mock_get_accelerator: MagicMock):
|
||||
mock_device = MagicMock()
|
||||
setattr(mock_device, "type", "npu")
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
# Force reload of kernels with mocked accelerator
|
||||
reload_kernels()
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
apply_kernel(model, npu_rope.NpuRoPEKernel)
|
||||
|
||||
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
|
||||
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
|
||||
|
||||
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
|
||||
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
|
||||
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
|
||||
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
|
||||
assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
|
||||
|
||||
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
@@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
|
||||
setattr(mock_device, "type", "npu")
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
|
||||
# Force reload of kernels with mocked accelerator
|
||||
reload_kernels()
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
model = apply_available_kernels(model)
|
||||
|
||||
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
|
||||
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
|
||||
model = apply_default_kernels(model=model, include_kernels=True)
|
||||
assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
|
||||
assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
|
||||
|
||||
Reference in New Issue
Block a user