From 16735b9e35c629cfbf152b1ed645d96afb4018d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=AE=E6=A2=A6?= <46097299+frozenleaves@users.noreply.github.com> Date: Wed, 31 Dec 2025 18:26:48 +0800 Subject: [PATCH] [v1] Refactor kernel plugin (#9669) Co-authored-by: frozenleaves --- src/llamafactory/hparams/model_args.py | 2 +- src/llamafactory/model/loader.py | 4 +- src/llamafactory/v1/core/model_loader.py | 7 + .../v1/plugins/model_plugins/kernels/base.py | 87 ++++++ .../model_plugins/kernels/constants.py | 23 -- .../model_plugins/kernels/interface.py | 132 +++++++++ .../kernels/{attn => ops}/__init__.py | 0 .../kernels/{ => ops}/mlp/__init__.py | 0 .../kernels/{ => ops}/mlp/npu_fused_moe.py | 133 ++++++++- .../kernels/{ => ops}/mlp/npu_swiglu.py | 85 +++++- .../kernels/{ => ops}/rms_norm/__init__.py | 0 .../{ => ops}/rms_norm/npu_rms_norm.py | 60 ++-- .../kernels/{ => ops}/rope/__init__.py | 0 .../kernels/ops/rope/npu_rope.py | 146 +++++++++ .../plugins/model_plugins/kernels/registry.py | 279 ++++-------------- .../model_plugins/kernels/rope/npu_rope.py | 122 -------- tests_v1/config/test_args_parser.py | 71 +++++ tests_v1/core/test_model_loader.py | 20 +- .../model_plugins/test_kernel_plugin.py | 39 +-- 19 files changed, 777 insertions(+), 433 deletions(-) create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/base.py delete mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/constants.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/interface.py rename src/llamafactory/v1/plugins/model_plugins/kernels/{attn => ops}/__init__.py (100%) rename src/llamafactory/v1/plugins/model_plugins/kernels/{ => ops}/mlp/__init__.py (100%) rename src/llamafactory/v1/plugins/model_plugins/kernels/{ => ops}/mlp/npu_fused_moe.py (68%) rename src/llamafactory/v1/plugins/model_plugins/kernels/{ => ops}/mlp/npu_swiglu.py (62%) rename src/llamafactory/v1/plugins/model_plugins/kernels/{ => ops}/rms_norm/__init__.py (100%) rename src/llamafactory/v1/plugins/model_plugins/kernels/{ => ops}/rms_norm/npu_rms_norm.py (51%) rename src/llamafactory/v1/plugins/model_plugins/kernels/{ => ops}/rope/__init__.py (100%) create mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py delete mode 100644 src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py create mode 100644 tests_v1/config/test_args_parser.py diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 0d0be63e4..aaa83057a 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -173,7 +173,7 @@ class BaseModelArguments: default=True, metadata={"help": "Whether or not to use KV cache in generation."}, ) - use_v1_kernels: bool = field( + use_v1_kernels: bool | None = field( default=False, metadata={"help": "Whether or not to use high-performance kernels in training."}, ) diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index ef4f3f134..a6dbf3b07 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -216,9 +216,9 @@ def load_model( "You are try to using future feature about kernels, please note that this feature " "is not supported for all models. If get any error, please disable this feature, or report the issue." 
) - from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels + from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels - model = apply_available_kernels(model) + model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels) trainable_params, all_param = count_parameters(model) if is_trainable: diff --git a/src/llamafactory/v1/core/model_loader.py b/src/llamafactory/v1/core/model_loader.py index e77d0ae27..ef6ca9324 100644 --- a/src/llamafactory/v1/core/model_loader.py +++ b/src/llamafactory/v1/core/model_loader.py @@ -112,6 +112,13 @@ class ModelLoader: model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train) + if self.args.kernel_config is not None: + from ..plugins.model_plugins.kernels.interface import KernelPlugin + + model = KernelPlugin(self.args.kernel_config.name)( + model=model, include_kernels=self.args.kernel_config.get("include_kernels") + ) + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py new file mode 100644 index 000000000..d5cd83be6 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py @@ -0,0 +1,87 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The definition of base kernel class. + +Init Phase: +1. Define base kernel class. +2. Define abstract methods. + +""" + +from abc import ABC, abstractmethod +from typing import Any + +from ....accelerator.helper import DeviceType, get_current_accelerator +from ....utils.types import HFModel + + +class BaseKernel(ABC): + r"""Base class for all kernel implementations. + + Subclasses must implement the abstract methods and define the required class attributes. + """ + + _kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation + _device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc. + + @classmethod + def get_kernel_id(cls) -> str: + r"""Returns the unique identifier for the kernel.""" + return cls._kernel_id + + @classmethod + def get_device(cls) -> str: + r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu").""" + return cls._device + + @classmethod + def check_deps(cls) -> bool: + r"""Checks if the required dependencies for the kernel are available. + + Returns: + bool: ``True`` if dependencies are met, ``False`` otherwise. + + .. note:: + In explicit mode, if a user specifies an implementation but this check fails, + it should raise an error instead of silently switching. + Kernels can override this method to implement custom dependency checks. + """ + if cls._device != get_current_accelerator().type: + return False + return True + + @classmethod + @abstractmethod + def apply(cls, **kwargs) -> HFModel: + r"""Applies the kernel optimization to the model. + + Args: + **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration. 
+ + Returns: + HFModel: The model with the kernel applied. + + Raises: + RuntimeError: If the kernel dependencies are not met. + NotImplementedError: If the method is not implemented by the subclass. + + Example: + >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel + >>> model = HFModel(config=config) + >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe") + """ + if not cls.check_deps(): + raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.") + raise NotImplementedError diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py deleted file mode 100644 index 55a05e942..000000000 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum - - -class KernelType(str, Enum): - RMSNORM = "rmsnorm" - SWIGLU = "swiglu" - FLASH_ATTENTION = "flash_attention" - ROPE = "rope" - MOE = "moe" diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py new file mode 100644 index 000000000..19a2def19 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py @@ -0,0 +1,132 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The definition of kernel interface. + +Init Phase: +1. Scan all kernels. +2. Register default kernels. +3. Define kernel plugin. + +""" + +import importlib +from pathlib import Path + +from ....utils.logging import get_logger +from ....utils.plugin import BasePlugin +from .registry import Registry + + +logger = get_logger(__name__) + + +def scan_all_kernels(): + r"""Scan all kernels in the ``ops`` directory. + + Scans the ``ops`` directory for all ``.py`` files and attempts to import them. + Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels. + + Returns: + dict[str, type[BaseKernel]]: A dictionary of registered kernels. + + .. note:: + This function assumes that the ``ops`` directory is located in the same directory as this file. + It recursively searches for ``.py`` files and constructs the module path for import. 
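+
+    Example:
+        Illustrative only: on an NPU host the bundled kernels below are expected to
+        register, while other accelerators may yield an empty dictionary.
+
+        >>> kernels = scan_all_kernels()
+        >>> sorted(kernels)
+        ['npu_fused_moe', 'npu_fused_rmsnorm', 'npu_fused_rope', 'npu_fused_swiglu']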
+ """ + ops_path = Path(__file__).parent / "ops" + + if not ops_path.exists(): + return + + base_package = __package__ + + for file_path in ops_path.rglob("*.py"): + if file_path.name == "__init__.py": + continue + + # calculate the relative path: + # file_path = .../kernels_v2/ops/mlp/npu_swiglu.py + # rel_path = ops/mlp/npu_swiglu.py + rel_path = file_path.relative_to(Path(__file__).parent) + + # build module path: + module_name = ".".join(rel_path.parts)[:-3] + full_module_name = f"{base_package}.{module_name}" + + try: + importlib.import_module(full_module_name) + except Exception as e: + logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}") + + return Registry.get_registered_kernels() + + +default_kernels = scan_all_kernels() + + +def get_default_kernels(): + r"""Get a list of default registered kernel IDs. + + Returns: + list[str]: List of kernel IDs. + """ + return list(default_kernels.keys()) + + +def apply_kernel(kernel_id: str, **kwargs): + r"""Applies a specific kernel to the model. + + Args: + kernel_id (str): The ID of the kernel to apply. + **kwargs: Keyword arguments passed to the kernel application function. + Typically includes the model instance. + + Returns: + HFModel: The model with applied kernel. + """ + kernel = default_kernels.get(kernel_id) + if kernel is None: + raise ValueError(f"Kernel {kernel_id} not found") + kernel.apply(**kwargs) + + +class KernelPlugin(BasePlugin): + r"""Plugin for managing kernel optimizations.""" + + pass + + +@KernelPlugin("auto").register +def apply_default_kernels(**kwargs): + r"""Applies all default registered kernels to the model. + + Args: + **kwargs: Keyword arguments passed to the kernel application function. + Typically includes the model instance and the include_kernels configuration. + + Returns: + HFModel: The model with applied kernels. 
+ """ + if not kwargs.get("include_kernels"): # None/False/empty string + return kwargs.get("model") + elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto + use_kernels = default_kernels.keys() + else: + use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3" + for kernel in use_kernels: + if kernel not in default_kernels: + raise ValueError(f"Kernel {kernel} not found") + apply_kernel(kernel, **kwargs) + return kwargs.get("model") diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/attn/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/__init__.py similarity index 100% rename from src/llamafactory/v1/plugins/model_plugins/kernels/attn/__init__.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/__init__.py diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/__init__.py similarity index 100% rename from src/llamafactory/v1/plugins/model_plugins/kernels/mlp/__init__.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/__init__.py diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py similarity index 68% rename from src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py index 1fa9ef470..0d84dbec8 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py @@ -12,22 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""The definition of NPU fused MoE kernels. + +Init Phase: +1. Define GMM functions. +2. Define NPU fused MoE functions. +3. Register NPU fused MoE kernel. + +""" + import types import torch import torch.nn.functional as F -import torch_npu -from .....accelerator.helper import DeviceType, is_torch_npu_available -from .....utils.packages import is_transformers_version_greater_than -from .....utils.types import HFModel -from ..constants import KernelType -from ..registry import MetaMoEKernel + +try: + import torch_npu +except ImportError: + pass + +from ......accelerator.helper import DeviceType +from ......utils.packages import is_transformers_version_greater_than +from ......utils.types import HFModel +from ...base import BaseKernel +from ...registry import register_kernel class GmmFunction(torch.autograd.Function): + r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM).""" + @staticmethod def forward(ctx, x, weight, group_list): + r"""Performs the forward pass of Grouped Matrix Multiplication. + + Args: + ctx: Context object to save tensors for backward pass. + x (Tensor): Input tensor. + weight (Tensor): Weight tensor. + group_list (list): List of group sizes. + + Returns: + Tensor: The result of the grouped matrix multiplication. + """ ctx.save_for_backward(x, weight) ctx.group_list = group_list @@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): + r"""Performs the backward pass of Grouped Matrix Multiplication. + + Args: + ctx: Context object containing saved tensors. + grad_output (Tensor): Gradient with respect to the output. 
+ + Returns: + tuple: Gradients with respect to input, weight, and None for group_list. + """ input_tensor, weight = ctx.saved_tensors group_list = ctx.group_list @@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function): class HybridGmmFunction(torch.autograd.Function): + r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU.""" + @staticmethod def forward(ctx, num_experts, *args): + r"""Performs the forward pass of Hybrid GMM. + + Args: + ctx: Context object to save tensors. + num_experts (int): Number of experts. + *args: Variable length argument list containing inputs and weights. + + Returns: + tuple: The outputs of the grouped matrix multiplication. + """ x_list = list(args[:num_experts]) weight_list = list(args[num_experts:]) @@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function): @staticmethod def backward(ctx, *grad_outputs): + r"""Performs the backward pass of Hybrid GMM. + + Args: + ctx: Context object containing saved tensors. + *grad_outputs: Gradients with respect to the outputs. + + Returns: + tuple: Gradients with respect to inputs and weights. + """ saved_tensors = ctx.saved_tensors num_experts = ctx.num_experts split_sizes = ctx.split_sizes @@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function): class NpuMoeFused: + r"""Container for NPU fused MoE forward functions.""" + @staticmethod def npu_moe_experts_forward( self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor ) -> torch.Tensor: + r"""Forward pass for MoE experts using NPU fused operations. + + Args: + self: The MoE layer instance. + hidden_states (Tensor): Input hidden states. + routing_weights (Tensor): Routing weights. + router_indices (Tensor): Router indices. + + Returns: + Tensor: Output tensor after expert computation. + """ batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute( @@ -138,6 +208,15 @@ class NpuMoeFused: @staticmethod def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Forward pass for sparse MoE block using NPU optimization. + + Args: + self: The MoE sparse block instance. + hidden_states (Tensor): Input hidden states. + + Returns: + Tensor: The routed output. + """ batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) router_logits = self.gate(hidden_states) @@ -151,8 +230,19 @@ class NpuMoeFused: class Qwen3NpuMoeFused: + r"""Container for Qwen3 NPU fused MoE forward functions.""" + @staticmethod def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor): + r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations. + + Args: + self: The Qwen3 MoE block instance. + hidden_states (Tensor): Input hidden states. + + Returns: + tuple: A tuple containing the next states and router logits. 
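+
+        Example:
+            Illustrative call shape once the kernel's ``apply`` has installed this function
+            as the MoE block's ``forward`` (``block`` and ``hidden_states`` are assumed):
+
+            >>> next_states, router_logits = block(hidden_states)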
+ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"): } -class NpuMoEFusedMoEKernel(MetaMoEKernel): - type = KernelType.MOE - device = DeviceType.NPU +@register_kernel +class NpuFusedMoEKernel(BaseKernel): + r"""NPU Fused MoE Kernel implementation.""" + + _kernel_id = "npu_fused_moe" + _device = DeviceType.NPU @classmethod - def apply(cls, model, **kwargs) -> HFModel: - if not is_torch_npu_available(): - return model + def apply(cls, **kwargs) -> HFModel: + r"""Applies the NPU fused MoE kernel to the model. + + Args: + **kwargs: Keyword arguments containing the model. + + Returns: + HFModel: The model with patched MoE forward functions. + + Raises: + ValueError: If the model is not provided. + RuntimeError: If dependencies are not met. + """ + model = kwargs.get("model", None) + if model is None: + raise ValueError(f"HFModel instance is required for {cls.__name__}.") + + if not cls.check_deps(): + raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.") archs = getattr(model.config, "architectures", []) target_moe_mapping = None diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py similarity index 62% rename from src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py index ff115a43d..e6f82051c 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py @@ -12,36 +12,71 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""The definition of NPU fused SwiGLU kernels. + +Init Phase: +1. Define SwiGLU forward functions. +2. Register NPU fused SwiGLU kernel. + +""" + import re import types import torch -from .....accelerator.helper import DeviceType, is_torch_npu_available -from .....utils.types import HFModel -from ..constants import KernelType -from ..registry import MetaSwiGluKernel +from ......accelerator.helper import DeviceType +from ......utils.types import HFModel +from ...base import BaseKernel +from ...registry import register_kernel -def _npu_swiglu_forward(self, hidden_state): +try: import torch_npu +except ImportError: + pass + +def npu_swiglu_forward(self, hidden_state): + r"""SwiGLU forward pass for NPU. + + Args: + self: The MLP layer instance. + hidden_state (Tensor): Input hidden state. + + Returns: + Tensor: Output of SwiGLU. + """ return self.down_proj( torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1) ) def _npu_swiglu_glm4_forward(self, hidden_states): - import torch_npu + r"""SwiGLU forward pass for GLM4 on NPU. + Args: + self: The GLM4 MLP layer instance. + hidden_states (Tensor): Input hidden states. + + Returns: + Tensor: Output of SwiGLU. + """ up_states = self.gate_up_proj(hidden_states) gate, up_states = up_states.chunk(2, dim=-1) return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1)) def _npu_swiglu_gemma3ntext_forward(self, hidden_states): - import torch_npu + r"""SwiGLU forward pass for Gemma3nText on NPU. + Args: + self: The Gemma3nText MLP layer instance. + hidden_states (Tensor): Input hidden states. + + Returns: + Tensor: Output of SwiGLU. 
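+
+    Example:
+        Sketch of how the kernel's ``apply`` binds this function onto a matching MLP
+        module (``mlp`` denotes an assumed Gemma3nText MLP instance):
+
+        >>> mlp.forward = types.MethodType(_npu_swiglu_gemma3ntext_forward, mlp)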
+ """ gate_proj = self.gate_proj(hidden_states) if self.activation_sparsity > 0.0: gate_proj = self._gaussian_topk(gate_proj) @@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states): return down_proj -class NpuSwiGluKernel(MetaSwiGluKernel): - type = KernelType.SWIGLU - device = DeviceType.NPU - kernel = _npu_swiglu_forward +@register_kernel +class NpuSwiGluKernel(BaseKernel): + r"""NPU Kernel for fused SwiGLU activation.""" - # Don't apply the kernel to the following modules + # just support apply to the following module layers expect_modules = frozenset( { "Qwen3VLMoeTextMLP", @@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel): } ) + _kernel_id = "npu_fused_swiglu" + _device = DeviceType.NPU + @classmethod - def apply(cls, model, **kwargs) -> "HFModel": - if not is_torch_npu_available(): - return model + def apply(cls, **kwargs) -> "HFModel": + r"""Applies the NPU fused SwiGLU kernel to the model. + + Args: + **kwargs: Keyword arguments containing the model. + + Returns: + HFModel: The model with patched SwiGLU forward functions. + + Raises: + ValueError: If the model is not provided. + RuntimeError: If dependencies are not met. + """ + model = kwargs.get("model", None) + if model is None: + raise ValueError(f"HFModel instance is required for {cls.__name__}.") + + if not cls.check_deps(): + raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.") # Mapping of specific mlp modules to their corresponding kernel implementations kernel_mapping = { @@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel): ): # Bind function as an instance method to preserve `self` semantics # and replace the original forward - kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward) + kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward) module.forward = types.MethodType(kernel_func, module) return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/__init__.py similarity index 100% rename from src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/__init__.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/__init__.py diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py similarity index 51% rename from src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py index 7ff00898e..6ce36bb67 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py @@ -11,40 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""The definition of NPU fused RMSNorm kernels. + +Init Phase: +1. Define RMSNorm forward function. +2. Register NPU fused RMSNorm kernel. 
+ +""" + import re import types -from .....accelerator.helper import DeviceType, is_torch_npu_available -from .....utils.types import HFModel -from ..constants import KernelType -from ..registry import MetaRMSNormKernel +from ......accelerator.helper import DeviceType +from ......utils.types import HFModel +from ...base import BaseKernel +from ...registry import register_kernel -def _npu_rms_forward(self, hidden_states): - """NPU forward implementation for RMSNorm. +def npu_rms_norm_forward(self, hidden_states): + r"""NPU forward implementation for RMSNorm. Args: self: RMSNorm module instance with `weight` and `variance_epsilon`. - hidden_states: Input hidden states tensor, same shape as the baseline. + hidden_states (Tensor): Input hidden states tensor, same shape as the baseline. Returns: - Normalized tensor consistent with the baseline RMSNorm behavior. + Tensor: Normalized tensor consistent with the baseline RMSNorm behavior. """ import torch_npu return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] -class NpuRMSNormKernel(MetaRMSNormKernel): - """NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" +@register_kernel +class NpuRMSNormKernel(BaseKernel): + r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" - type = KernelType.RMSNORM - device = DeviceType.NPU - kernel = _npu_rms_forward + _kernel_id = "npu_fused_rmsnorm" + _device = DeviceType.NPU @classmethod - def apply(cls, model, **kwargs) -> HFModel: - """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. + def apply(cls, **kwargs) -> "HFModel": + r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. Key points: - Match modules whose class name contains "RMSNorm" (case-insensitive). @@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel): replace the original `forward`. - Do not modify weights, hyperparameters, or module structure to ensure numerical behavior and interface consistency. - """ - if not is_torch_npu_available(): - return model + Args: + **kwargs: Keyword arguments containing the model. + + Returns: + HFModel: The model with NPU fused RMSNorm. + + Raises: + RuntimeError: If torch_npu is not available. + ValueError: If the model is not provided. 
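+
+        Example:
+            Illustrative usage; requires an NPU accelerator (with ``torch_npu``), otherwise
+            ``check_deps`` fails and a ``RuntimeError`` is raised:
+
+            >>> model = NpuRMSNormKernel.apply(model=model)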
+ """ + model = kwargs.get("model") + if model is None: + raise ValueError(f"HFModel instance is required for {cls.__name__}.") + + if not cls.check_deps(): + raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.") rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE) for name, module in model.named_modules(): @@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel): if re.search(rms_norm_pattern, module.__class__.__name__): # Bind function as an instance method to preserve `self` semantics # and replace the original forward - module.forward = types.MethodType(cls.kernel, module) + module.forward = types.MethodType(npu_rms_norm_forward, module) return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/__init__.py similarity index 100% rename from src/llamafactory/v1/plugins/model_plugins/kernels/rope/__init__.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/__init__.py diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py new file mode 100644 index 000000000..b431b5063 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py @@ -0,0 +1,146 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The definition of NPU fused RoPE kernels. + +Init Phase: +1. Define RoPE forward functions. +2. Register NPU fused RoPE kernel. + +""" + +import sys + +import torch + +from ......accelerator.helper import DeviceType +from ......utils.logging import get_logger +from ......utils.types import HFModel +from ...base import BaseKernel +from ...registry import register_kernel + + +logger = get_logger(__name__) + +try: + import torch_npu +except ImportError: + pass + + +def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization. + + Args: + q (Tensor): Query tensor. + k (Tensor): Key tensor. + cos (Tensor): Cosine part of embedding. + sin (Tensor): Sine part of embedding. + position_ids (Tensor, optional): Position IDs. Default: ``None``. + unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1. + + Returns: + tuple: (q_embed, k_embed) The embedded query and key tensors. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU. + + Args: + q (Tensor): Query tensor. + k (Tensor): Key tensor. + cos (Tensor): Cosine part of embedding. + sin (Tensor): Sine part of embedding. + mrope_section (Tensor): Multimodal RoPE section. 
+ unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1. + + Returns: + tuple: (q_embed, k_embed) The embedded query and key tensors. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +@register_kernel +class NpuRoPEKernel(BaseKernel): + r"""NPU Kernel for Rotary Position Embedding.""" + + _kernel_id = "npu_fused_rope" + _device = DeviceType.NPU + + @classmethod + def apply(cls, **kwargs) -> "HFModel": + r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. + + This function iterates through the model's modules to find attention layers, + identifies the module where they are defined, and replaces the original + `apply_rotary_pos_emb` function in that module's namespace with the + NPU-accelerated version from this file. + + Args: + **kwargs: Keyword arguments containing the model. + + Returns: + HFModel: The model with patched RoPE functions. + + Raises: + RuntimeError: If dependencies are not met. + ValueError: If the model is not provided. + """ + if not cls.check_deps(): + raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.") + model = kwargs.get("model", None) + if model is None: + raise ValueError(f"HFModel instance is required for {cls.__name__}.") + _modules = set() + for module in model.modules(): + if "Attention" in module.__class__.__name__: + module_name = module.__class__.__module__ + if module_name in _modules: + continue + try: + target_module = sys.modules[module_name] + if hasattr(target_module, "apply_rotary_pos_emb"): + if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb: + setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb) + _modules.add(module_name) + if hasattr(target_module, "apply_multimodal_rotary_pos_emb"): + if ( + getattr(target_module, "apply_multimodal_rotary_pos_emb") + is not _apply_multimodal_rotary_pos_emb_qwen25_vl + ): + setattr( + target_module, + "apply_multimodal_rotary_pos_emb", + _apply_multimodal_rotary_pos_emb_qwen25_vl, + ) + _modules.add(module_name) + except Exception as e: + logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}") + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index 78a235074..f6c1984ae 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -12,247 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, ABCMeta, abstractmethod -from collections.abc import Callable -from typing import Any, Optional +"""The definition of kernel registry. -from ....accelerator.helper import DeviceType, get_current_accelerator -from ....utils.types import HFModel -from .constants import KernelType +Init Phase: +1. Define kernel registry. +2. Register kernels. 
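+
+Example (sketch of the registration pattern; ``MyCudaRMSNormKernel`` and its body are
+hypothetical and only illustrate the required class attributes):
+
+    >>> @register_kernel
+    ... class MyCudaRMSNormKernel(BaseKernel):
+    ...     _kernel_id = "my_cuda_rmsnorm"
+    ...     _device = DeviceType.CUDA
+    ...
+    ...     @classmethod
+    ...     def apply(cls, **kwargs):
+    ...         model = kwargs["model"]
+    ...         # patch the target modules here
+    ...         return model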
+ +""" + +from typing import Optional + +from ....accelerator.helper import get_current_accelerator +from .base import BaseKernel -class KernelRegistry: - _instance: Optional["KernelRegistry"] = None - _initialized: bool = False - - def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry": - if cls._instance is None: - cls._instance = super().__new__(cls) - - return cls._instance - - def __init__(self) -> None: - if self._initialized: - return - - self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {} - self._initialized = True - - def register( - self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None - ) -> None: - """Register a kernel implementation. - - Args: - kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION). - device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA). - kernel_impl: the actual kernel function or class. - """ - if kernel_type not in self._registry: - self._registry[kernel_type] = {} - - if device_type in self._registry[kernel_type]: - print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.") - - self._registry[kernel_type][device_type] = kernel_impl - print(f"Registered kernel {kernel_type.name} for device {device_type.name}.") - - def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None: - return self._registry.get(kernel_type, {}).get(device_type) +__all__ = ["Registry", "register_kernel"] -KERNEL_REGISTRY = KernelRegistry() +class Registry: + r"""Registry for managing kernel implementations. - -class AutoRegisterKernelMeta(ABCMeta): - """Metaclass that automatically registers kernel classes upon creation. - - This metaclass checks if a newly created class has both `type` and `device` - attributes defined. If so, it automatically registers the kernel in the - global KERNEL_REGISTRY, eliminating the need for manual registration. - - To disable auto-registration for a specific class, set `auto_register = False`. + Storage structure: ``{ "kernel_id": Class }`` """ - def __new__(mcs, name, bases, namespace, **kwargs): - cls = super().__new__(mcs, name, bases, namespace, **kwargs) - - # Check if auto-registration is disabled - auto_register = namespace.get("auto_register", True) - - # Only auto-register if the class has both type and device attributes defined - # and they are not None (skip base classes like MetaKernel itself) - # and auto_register is True - kernel_type = namespace.get("type") - device_type = namespace.get("device") - - if auto_register and kernel_type is not None and device_type is not None: - # Auto-register this kernel - KERNEL_REGISTRY.register(kernel_type, device_type, cls) - - return cls - - -class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta): - """Base class for all kernel implementations. - - Subclasses are automatically registered when they define both `type` and `device` - attributes. To disable auto-registration, set `auto_register = False`. - - Attributes: - type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses. - device: The device type (e.g., DeviceType.NPU). Must be set in subclasses. - kernel: The actual kernel function or implementation. - auto_register: Set to False to disable automatic registration (default: True). 
- """ - - type: KernelType | None = None - device: DeviceType | None = None - kernel: Callable | None = None + _kernels: dict[str, type[BaseKernel]] = {} @classmethod - @abstractmethod - def apply(cls, model: HFModel, **kwargs) -> HFModel: - """Apply the kernel to the model. + def register(cls, kernel_cls: type[BaseKernel]): + r"""Decorator to register a kernel class. - This method should check if the kernel can be applied (e.g., dependencies - are installed, target modules exist) and perform the kernel replacement. + The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes. Args: - model: The HuggingFace model to optimize. - **kwargs: Additional arguments for kernel application. + kernel_cls (type[BaseKernel]): The kernel class to register. Returns: - The optimized model (may be the same object with modifications). + type[BaseKernel]: The registered kernel class. + + Raises: + TypeError: If the class does not inherit from :class:`BaseKernel`. + ValueError: If the kernel ID is missing or already registered. """ - raise NotImplementedError + if not issubclass(kernel_cls, BaseKernel): + raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel") + kernel_id = kernel_cls.get_kernel_id() + device = kernel_cls.get_device() + # The device type of the current accelerator does not match the device type required by the kernel, skip registration + if device != get_current_accelerator().type: + return + + if not kernel_id: + raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register") + + if kernel_id in cls._kernels: + raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}") + + cls._kernels[kernel_id] = kernel_cls + return kernel_cls -class MetaFlashAttentionKernel(MetaKernel): @classmethod - def apply(cls, model: HFModel, **kwargs) -> HFModel: - raise NotImplementedError + def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]: + r"""Retrieves a registered kernel implementation by its ID. + Args: + kernel_id (str): The ID of the kernel to retrieve. + + Returns: + Optional[type[BaseKernel]]: The kernel class if found, else ``None``. + """ + return cls._kernels.get(kernel_id) -class MetaRMSNormKernel(MetaKernel): @classmethod - def apply(cls, model: HFModel, **kwargs) -> HFModel: - raise NotImplementedError + def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]: + r"""Returns a dictionary of all registered kernels. + + Returns: + dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes. + """ + return cls._kernels -class MetaSwiGluKernel(MetaKernel): - @classmethod - def apply(cls, model: HFModel, **kwargs) -> HFModel: - raise NotImplementedError - - -class MetaRoPEKernel(MetaKernel): - @classmethod - def apply(cls, model: HFModel, **kwargs) -> HFModel: - raise NotImplementedError - - -class MetaMoEKernel(MetaKernel): - @classmethod - def apply(cls, model: HFModel, **kwargs) -> HFModel: - raise NotImplementedError - - -def _ensure_kernels_loaded() -> None: - """Ensure all kernel implementations are imported and registered. - - This function dynamically imports all kernel implementation modules to trigger - their auto-registration. Python's module system ensures each module is only - executed once (cached in sys.modules), so repeated calls are safe and fast. 
- """ - # List of kernel module paths to import - kernel_modules = [ - "rms_norm.npu_rms_norm", - "rope.npu_rope", - "mlp.npu_swiglu", - "mlp.npu_fused_moe", - # Add new kernel modules here as they are created - ] - - # Import each module to trigger kernel registration - # Python's import system caches modules, so this is fast on subsequent calls - for module_name in kernel_modules: - try: - __import__(f"{__package__}.{module_name}", fromlist=["*"]) - except ImportError: - # Silently ignore import errors (e.g., missing dependencies like torch_npu) - pass - - -def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]: - """Discover and return all kernel classes registered for the current device. - - This function inspects the runtime environment (device type) and returns - all MetaKernel classes registered for that device. Each kernel's `apply()` - method is responsible for checking if it can actually be applied (e.g., - required dependencies are installed, target modules exist in the model). - - The function automatically discovers all kernels registered in KERNEL_REGISTRY - without requiring manual enumeration. On first call, it dynamically imports - all kernel implementation modules to trigger their auto-registration. - - Args: - model: The HuggingFace model to apply kernels to. - TODO: implement the kernel route detection logic by model structure. - - Returns: - A list of MetaKernel classes available for the current device. - """ - # Ensure all kernel modules are imported to trigger registration - _ensure_kernels_loaded() - - discovered_kernels: list[type[MetaKernel]] = [] - - # Detect current device type - accelerator = get_current_accelerator() - try: - device_type = DeviceType(accelerator.type) - except ValueError: - # Unknown device type, return empty list - return discovered_kernels - - # Skip CPU as it typically doesn't have optimized kernels - if device_type == DeviceType.CPU: - return discovered_kernels - - # Iterate through registry and collect all kernels for current device - for devices in KERNEL_REGISTRY._registry.values(): - kernel_cls = devices.get(device_type) - if kernel_cls is not None: - discovered_kernels.append(kernel_cls) - - return discovered_kernels - - -def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel": - """Call the MetaKernel's `apply` to perform the replacement. - - Corresponding replacement logic is maintained inside each kernel; the only - requirement is that `apply` returns the replaced model. 
- - Example: - from transformers import AutoModelForCausalLM - from .rms_norm.npu_rms_norm import NpuRMSNormKernel - model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B") - model = apply_kernel(model, NpuRMSNormKernel) - """ - if not issubclass(kernel, MetaKernel): - raise ValueError(f"{kernel} must be a MetaKernel instance.") - - if kernel.device != get_current_accelerator().type: - raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.") - - return kernel.apply(model, **kwargs) - - -def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel": - """Apply all available kernels to the model.""" - for kernel in discover_kernels(model): - model = apply_kernel(model, kernel, **kwargs) - - return model +# export decorator alias +register_kernel = Registry.register diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py deleted file mode 100644 index 82fccce70..000000000 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import torch - -from .....accelerator.helper import DeviceType, is_torch_npu_available -from .....utils.types import HFModel -from ..constants import KernelType -from ..registry import MetaRoPEKernel - - -def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors.""" - import torch_npu - - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = torch_npu.npu_rotary_mul(q, cos, sin) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin) - return q_embed, k_embed - - -def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL).""" - import torch_npu - - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - q_embed = torch_npu.npu_rotary_mul(q, cos, sin) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin) - return q_embed, k_embed - - -class NpuRoPEKernel(MetaRoPEKernel): - type = KernelType.ROPE - device = DeviceType.NPU - kernel = _apply_rotary_pos_emb - - @classmethod - def apply(cls, model, **kwargs) -> "HFModel": - """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. - - This function iterates through the model's modules to find attention layers, - identifies the module where they are defined, and replaces the original - `apply_rotary_pos_emb` function in that module's namespace with the - NPU-accelerated version from this file. 
- """ - if not is_torch_npu_available(): - return model - - _modules = set() - for module in model.modules(): - if "Attention" in module.__class__.__name__: - module_name = module.__class__.__module__ - if module_name in _modules: - continue - try: - target_module = sys.modules[module_name] - if hasattr(target_module, "apply_rotary_pos_emb"): - if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel: - setattr(target_module, "apply_rotary_pos_emb", cls.kernel) - _modules.add(module_name) - except Exception: - pass - return model - - -class NpuQwen2VLRoPEKernel(MetaRoPEKernel): - """Qwen2-VL specific RoPE kernel - not auto-registered. - - This kernel is for specific models (Qwen2-VL) and should be manually - applied when needed rather than auto-discovered. - """ - - type = KernelType.ROPE - device = DeviceType.NPU - kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl - auto_register = False # Disable auto-registration for model-specific kernel - - @classmethod - def apply(cls, model, **kwargs) -> "HFModel": - """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. - - This function iterates through the model's modules to find attention layers, - identifies the module where they are defined, and replaces the original - `apply_rotary_pos_emb` function in that module's namespace with the - NPU-accelerated version from this file. - """ - _modules = set() - for module in model.modules(): - if "Attention" in module.__class__.__name__: - module_name = module.__class__.__module__ - if module_name in _modules: - continue - try: - target_module = sys.modules[module_name] - if hasattr(target_module, "apply_multimodal_rotary_pos_emb"): - if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel: - setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel) - _modules.add(module_name) - except Exception: - pass - return model diff --git a/tests_v1/config/test_args_parser.py b/tests_v1/config/test_args_parser.py new file mode 100644 index 000000000..945e0e572 --- /dev/null +++ b/tests_v1/config/test_args_parser.py @@ -0,0 +1,71 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import pathlib +from unittest.mock import patch + +from llamafactory.v1.config.arg_parser import get_args + + +def test_get_args_from_yaml(tmp_path: pathlib.Path): + config_yaml = """ + ### model + model: "llamafactory/tiny-random-qwen2.5" + trust_remote_code: true + use_fast_processor: true + model_class: "llm" + kernel_config: + name: "auto" + include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + peft_config: + name: "lora" + lora_rank: 0.8 + quant_config: null + + ### data + dataset: "llamafactory/tiny-supervised-dataset" + cutoff_len: 2048 + + ### training + output_dir: "outputs/test_run" + micro_batch_size: 1 + global_batch_size: 1 + learning_rate: 1.0e-4 + bf16: false + dist_config: null + + ### sample + sample_backend: "hf" + max_new_tokens: 128 + """ + + config_file = tmp_path / "config.yaml" + config_file.write_text(config_yaml, encoding="utf-8") + + test_argv = ["test_args_parser.py", str(config_file)] + + with patch.object(sys, "argv", test_argv): + data_args, model_args, training_args, sample_args = get_args() + assert training_args.output_dir == "outputs/test_run" + assert training_args.micro_batch_size == 1 + assert training_args.global_batch_size == 1 + assert training_args.learning_rate == 1.0e-4 + assert training_args.bf16 is False + assert training_args.dist_config is None + assert model_args.model == "llamafactory/tiny-random-qwen2.5" + assert model_args.kernel_config.name == "auto" + assert model_args.kernel_config.get("include_kernels") == "auto" + assert model_args.peft_config.name == "lora" + assert model_args.peft_config.get("lora_rank") == 0.8 diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py index cee2e9f79..fa038229e 100644 --- a/tests_v1/core/test_model_loader.py +++ b/tests_v1/core/test_model_loader.py @@ -14,7 +14,7 @@ import torch -from llamafactory.v1.config.model_args import ModelArguments +from llamafactory.v1.config.model_args import ModelArguments, PluginConfig from llamafactory.v1.core.model_loader import ModelLoader @@ -29,5 +29,23 @@ def test_tiny_qwen(): assert model_loader.model.dtype == torch.bfloat16 +def test_tiny_qwen_with_kernel_plugin(): + from transformers import Qwen2ForCausalLM + + from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward + + model_args = ModelArguments( + model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto") + ) + model_loader = ModelLoader(model_args) + # test enable apply kernel plugin + if hasattr(torch, "npu"): + assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__ + else: + assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ + assert isinstance(model_loader.model, Qwen2ForCausalLM) + + if __name__ == "__main__": test_tiny_qwen() + test_tiny_qwen_with_kernel_plugin() diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py index 04f99a757..f087a822f 100644 --- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py +++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import sys from unittest.mock import MagicMock, patch import pytest from transformers import AutoModelForCausalLM from llamafactory.v1.accelerator.helper import get_current_accelerator -from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu -from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel -from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm -from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope @pytest.fixture(autouse=True) @@ -29,24 +26,29 @@ def clear_accelerator_cache(): get_current_accelerator.cache_clear() +def reload_kernels(): + """Helper to reload kernel modules to respect mocked accelerator.""" + # Unload kernel interface and registry + keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")] + for k in keys_to_remove: + del sys.modules[k] + + @patch("torch.accelerator.current_accelerator") def test_apply_kernel(mock_get_accelerator: MagicMock): mock_device = MagicMock() setattr(mock_device, "type", "npu") mock_get_accelerator.return_value = mock_device + # Force reload of kernels with mocked accelerator + reload_kernels() + from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") - original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_swiglu_forward = model.model.layers[0].mlp.forward - - apply_kernel(model, npu_rope.NpuRoPEKernel) - - model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel) - assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward - - model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel) - assert model.model.layers[0].mlp.forward is not original_swiglu_forward + model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm") + assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__ + assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__ @patch("torch.accelerator.current_accelerator") @@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock): setattr(mock_device, "type", "npu") mock_get_accelerator.return_value = mock_device + # Force reload of kernels with mocked accelerator + reload_kernels() + from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels + model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_swiglu_forward = model.model.layers[0].mlp.forward - model = apply_available_kernels(model) - - assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward - assert model.model.layers[0].mlp.forward is not original_swiglu_forward + model = apply_default_kernels(model=model, include_kernels=True) + assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__ + assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
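
Usage sketch based on the interface added above (all names come from this patch; the tiny
model checkpoint is the one used in the new tests):

```python
from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.interface import (
    apply_default_kernels,
    get_default_kernels,
)

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

# Kernel IDs registered for the current accelerator (may be empty on unsupported devices).
print(get_default_kernels())

# "auto" or True applies every registered kernel; a comma-separated string selects a subset;
# None or False returns the model unchanged.
model = apply_default_kernels(model=model, include_kernels="auto")
```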