[v1] Refactor kernel plugin (#9669)
Co-authored-by: frozenleaves <frozen@Mac.local>
@@ -173,7 +173,7 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
-    use_v1_kernels: bool = field(
+    use_v1_kernels: bool | None = field(
         default=False,
         metadata={"help": "Whether or not to use high-performance kernels in training."},
     )
@@ -216,9 +216,9 @@ def load_model(
             "You are trying to use an upcoming kernels feature; please note that this feature "
             "is not supported for all models. If you get any error, please disable this feature or report the issue."
         )
-        from ..v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+        from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
 
-        model = apply_available_kernels(model)
+        model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
 
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
@@ -112,6 +112,13 @@ class ModelLoader:
 
         model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
 
+        if self.args.kernel_config is not None:
+            from ..plugins.model_plugins.kernels.interface import KernelPlugin
+
+            model = KernelPlugin(self.args.kernel_config.name)(
+                model=model, include_kernels=self.args.kernel_config.get("include_kernels")
+            )
+
         return model
 
 
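Note: a hedged sketch (not the actual ModelLoader) of how this new branch is driven; SimpleNamespace stands in for the parsed kernel_config, whose real shape (name plus include_kernels) appears in the new test YAML at the bottom of this diff.

# Hedged sketch: only the kernel dispatch added in the hunk above.
from types import SimpleNamespace

from llamafactory.v1.plugins.model_plugins.kernels.interface import KernelPlugin

kernel_config = SimpleNamespace(name="auto")  # hypothetical stand-in for the parsed config
include_kernels = "npu_fused_rmsnorm,npu_fused_swiglu"  # or True / "auto" / None


def attach_kernels(model):
    if kernel_config is not None:
        # KernelPlugin("auto") resolves to the apply_default_kernels function
        # registered in the new interface module and calls it like a function.
        model = KernelPlugin(kernel_config.name)(model=model, include_kernels=include_kernels)
    return model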
src/llamafactory/v1/plugins/model_plugins/kernels/base.py (new file, +87)
@@ -0,0 +1,87 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of base kernel class.
+
+Init Phase:
+1. Define base kernel class.
+2. Define abstract methods.
+
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from ....accelerator.helper import DeviceType, get_current_accelerator
+from ....utils.types import HFModel
+
+
+class BaseKernel(ABC):
+    r"""Base class for all kernel implementations.
+
+    Subclasses must implement the abstract methods and define the required class attributes.
+    """
+
+    _kernel_id: Any = ""  # kernel ID, any hashable value to identify a kernel implementation
+    _device: DeviceType = DeviceType.CPU  # "cuda", "npu", "cpu", etc.
+
+    @classmethod
+    def get_kernel_id(cls) -> str:
+        r"""Returns the unique identifier for the kernel."""
+        return cls._kernel_id
+
+    @classmethod
+    def get_device(cls) -> str:
+        r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
+        return cls._device
+
+    @classmethod
+    def check_deps(cls) -> bool:
+        r"""Checks if the required dependencies for the kernel are available.
+
+        Returns:
+            bool: ``True`` if dependencies are met, ``False`` otherwise.
+
+        .. note::
+            In explicit mode, if a user specifies an implementation but this check fails,
+            it should raise an error instead of silently switching.
+            Kernels can override this method to implement custom dependency checks.
+        """
+        if cls._device != get_current_accelerator().type:
+            return False
+        return True
+
+    @classmethod
+    @abstractmethod
+    def apply(cls, **kwargs) -> HFModel:
+        r"""Applies the kernel optimization to the model.
+
+        Args:
+            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
+
+        Returns:
+            HFModel: The model with the kernel applied.
+
+        Raises:
+            RuntimeError: If the kernel dependencies are not met.
+            NotImplementedError: If the method is not implemented by the subclass.
+
+        Example:
+            >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
+            >>> model = HFModel(config=config)
+            >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
+        """
+        if not cls.check_deps():
+            raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")
+        raise NotImplementedError
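Note: a hedged sketch of the subclass contract this base class defines; the kernel class and ID below are hypothetical, and only _kernel_id, _device, check_deps, and apply(**kwargs) come from BaseKernel itself.

from llamafactory.v1.accelerator.helper import DeviceType  # assumed absolute path of the helper
from llamafactory.v1.plugins.model_plugins.kernels.base import BaseKernel


class IdentityKernel(BaseKernel):  # hypothetical example kernel
    _kernel_id = "identity"  # any hashable value identifying the implementation
    _device = DeviceType.CUDA  # compared against the live accelerator by check_deps()

    @classmethod
    def apply(cls, **kwargs):
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
        if not cls.check_deps():  # inherited: compares _device to the current accelerator
            raise RuntimeError(f"{cls.__name__} is not available on this accelerator.")
        return model  # a real kernel would patch module forwards here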
src/llamafactory/v1/plugins/model_plugins/kernels/constants.py (deleted, -23)
@@ -1,23 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from enum import Enum
-
-
-class KernelType(str, Enum):
-    RMSNORM = "rmsnorm"
-    SWIGLU = "swiglu"
-    FLASH_ATTENTION = "flash_attention"
-    ROPE = "rope"
-    MOE = "moe"
src/llamafactory/v1/plugins/model_plugins/kernels/interface.py (new file, +132)
@@ -0,0 +1,132 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of kernel interface.
+
+Init Phase:
+1. Scan all kernels.
+2. Register default kernels.
+3. Define kernel plugin.
+
+"""
+
+import importlib
+from pathlib import Path
+
+from ....utils.logging import get_logger
+from ....utils.plugin import BasePlugin
+from .registry import Registry
+
+
+logger = get_logger(__name__)
+
+
+def scan_all_kernels():
+    r"""Scan all kernels in the ``ops`` directory.
+
+    Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
+    Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
+
+    Returns:
+        dict[str, type[BaseKernel]]: A dictionary of registered kernels.
+
+    .. note::
+        This function assumes that the ``ops`` directory is located in the same directory as this file.
+        It recursively searches for ``.py`` files and constructs the module path for import.
+    """
+    ops_path = Path(__file__).parent / "ops"
+
+    if not ops_path.exists():
+        return
+
+    base_package = __package__
+
+    for file_path in ops_path.rglob("*.py"):
+        if file_path.name == "__init__.py":
+            continue
+
+        # calculate the relative path:
+        # file_path = .../kernels_v2/ops/mlp/npu_swiglu.py
+        # rel_path = ops/mlp/npu_swiglu.py
+        rel_path = file_path.relative_to(Path(__file__).parent)
+
+        # build module path:
+        module_name = ".".join(rel_path.parts)[:-3]
+        full_module_name = f"{base_package}.{module_name}"
+
+        try:
+            importlib.import_module(full_module_name)
+        except Exception as e:
+            logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")
+
+    return Registry.get_registered_kernels()
+
+
+default_kernels = scan_all_kernels()
+
+
+def get_default_kernels():
+    r"""Get a list of default registered kernel IDs.
+
+    Returns:
+        list[str]: List of kernel IDs.
+    """
+    return list(default_kernels.keys())
+
+
+def apply_kernel(kernel_id: str, **kwargs):
+    r"""Applies a specific kernel to the model.
+
+    Args:
+        kernel_id (str): The ID of the kernel to apply.
+        **kwargs: Keyword arguments passed to the kernel application function.
+            Typically includes the model instance.
+
+    Returns:
+        HFModel: The model with applied kernel.
+    """
+    kernel = default_kernels.get(kernel_id)
+    if kernel is None:
+        raise ValueError(f"Kernel {kernel_id} not found")
+    kernel.apply(**kwargs)
+
+
+class KernelPlugin(BasePlugin):
+    r"""Plugin for managing kernel optimizations."""
+
+    pass
+
+
+@KernelPlugin("auto").register
+def apply_default_kernels(**kwargs):
+    r"""Applies all default registered kernels to the model.
+
+    Args:
+        **kwargs: Keyword arguments passed to the kernel application function.
+            Typically includes the model instance and the include_kernels configuration.
+
+    Returns:
+        HFModel: The model with applied kernels.
+    """
+    if not kwargs.get("include_kernels"):  # None/False/empty string
+        return kwargs.get("model")
+    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
+        use_kernels = default_kernels.keys()
+    else:
+        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
+
+    for kernel in use_kernels:
+        if kernel not in default_kernels:
+            raise ValueError(f"Kernel {kernel} not found")
+        apply_kernel(kernel, **kwargs)
+
+    return kwargs.get("model")
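Note: the include_kernels switch accepts several shapes; a hedged usage sketch (the kernel IDs are the NPU ones added in this PR, and they only resolve on a machine where the scan registered them).

from llamafactory.v1.plugins.model_plugins.kernels.interface import (
    apply_default_kernels,
    get_default_kernels,
)

# `model` is assumed to be an already-loaded transformers model.
print(get_default_kernels())  # IDs that survived the ops/ scan on this accelerator

# None / False / "" -> no-op, the model is returned unchanged
model = apply_default_kernels(model=model, include_kernels=None)

# True or "auto" -> apply every registered default kernel
model = apply_default_kernels(model=model, include_kernels="auto")

# comma-separated IDs -> apply exactly these; unknown IDs raise ValueError
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm,npu_fused_swiglu")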
src/llamafactory/v1/plugins/model_plugins/kernels/{mlp => ops/mlp}/npu_fused_moe.py
@@ -12,22 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""The definition of NPU fused MoE kernels.
+
+Init Phase:
+1. Define GMM functions.
+2. Define NPU fused MoE functions.
+3. Register NPU fused MoE kernel.
+
+"""
+
 import types
 
 import torch
 import torch.nn.functional as F
-import torch_npu
 
-from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....utils.packages import is_transformers_version_greater_than
-from .....utils.types import HFModel
-from ..constants import KernelType
-from ..registry import MetaMoEKernel
+try:
+    import torch_npu
+except ImportError:
+    pass
+
+from ......accelerator.helper import DeviceType
+from ......utils.packages import is_transformers_version_greater_than
+from ......utils.types import HFModel
+from ...base import BaseKernel
+from ...registry import register_kernel
 
 
 class GmmFunction(torch.autograd.Function):
+    r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
+
     @staticmethod
     def forward(ctx, x, weight, group_list):
+        r"""Performs the forward pass of Grouped Matrix Multiplication.
+
+        Args:
+            ctx: Context object to save tensors for backward pass.
+            x (Tensor): Input tensor.
+            weight (Tensor): Weight tensor.
+            group_list (list): List of group sizes.
+
+        Returns:
+            Tensor: The result of the grouped matrix multiplication.
+        """
         ctx.save_for_backward(x, weight)
         ctx.group_list = group_list
@@ -38,6 +65,15 @@ class GmmFunction(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
+        r"""Performs the backward pass of Grouped Matrix Multiplication.
+
+        Args:
+            ctx: Context object containing saved tensors.
+            grad_output (Tensor): Gradient with respect to the output.
+
+        Returns:
+            tuple: Gradients with respect to input, weight, and None for group_list.
+        """
         input_tensor, weight = ctx.saved_tensors
         group_list = ctx.group_list
@@ -58,8 +94,20 @@ class GmmFunction(torch.autograd.Function):
 
 
 class HybridGmmFunction(torch.autograd.Function):
+    r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
+
     @staticmethod
     def forward(ctx, num_experts, *args):
+        r"""Performs the forward pass of Hybrid GMM.
+
+        Args:
+            ctx: Context object to save tensors.
+            num_experts (int): Number of experts.
+            *args: Variable length argument list containing inputs and weights.
+
+        Returns:
+            tuple: The outputs of the grouped matrix multiplication.
+        """
         x_list = list(args[:num_experts])
         weight_list = list(args[num_experts:])
@@ -76,6 +124,15 @@ class HybridGmmFunction(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, *grad_outputs):
+        r"""Performs the backward pass of Hybrid GMM.
+
+        Args:
+            ctx: Context object containing saved tensors.
+            *grad_outputs: Gradients with respect to the outputs.
+
+        Returns:
+            tuple: Gradients with respect to inputs and weights.
+        """
         saved_tensors = ctx.saved_tensors
         num_experts = ctx.num_experts
         split_sizes = ctx.split_sizes
@@ -119,10 +176,23 @@ class HybridGmmFunction(torch.autograd.Function):
 
 
 class NpuMoeFused:
+    r"""Container for NPU fused MoE forward functions."""
+
     @staticmethod
     def npu_moe_experts_forward(
         self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
     ) -> torch.Tensor:
+        r"""Forward pass for MoE experts using NPU fused operations.
+
+        Args:
+            self: The MoE layer instance.
+            hidden_states (Tensor): Input hidden states.
+            routing_weights (Tensor): Routing weights.
+            router_indices (Tensor): Router indices.
+
+        Returns:
+            Tensor: Output tensor after expert computation.
+        """
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
         permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
@@ -138,6 +208,15 @@ class NpuMoeFused:
 
     @staticmethod
     def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""Forward pass for sparse MoE block using NPU optimization.
+
+        Args:
+            self: The MoE sparse block instance.
+            hidden_states (Tensor): Input hidden states.
+
+        Returns:
+            Tensor: The routed output.
+        """
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
         router_logits = self.gate(hidden_states)
@@ -151,8 +230,19 @@ class NpuMoeFused:
 
 
 class Qwen3NpuMoeFused:
+    r"""Container for Qwen3 NPU fused MoE forward functions."""
+
     @staticmethod
     def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
+        r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
+
+        Args:
+            self: The Qwen3 MoE block instance.
+            hidden_states (Tensor): Input hidden states.
+
+        Returns:
+            tuple: A tuple containing the next states and router logits.
+        """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -206,14 +296,33 @@ if not is_transformers_version_greater_than("5.0.0"):
     }
 
 
-class NpuMoEFusedMoEKernel(MetaMoEKernel):
-    type = KernelType.MOE
-    device = DeviceType.NPU
+@register_kernel
+class NpuFusedMoEKernel(BaseKernel):
+    r"""NPU Fused MoE Kernel implementation."""
+
+    _kernel_id = "npu_fused_moe"
+    _device = DeviceType.NPU
 
     @classmethod
-    def apply(cls, model, **kwargs) -> HFModel:
-        if not is_torch_npu_available():
-            return model
+    def apply(cls, **kwargs) -> HFModel:
+        r"""Applies the NPU fused MoE kernel to the model.
+
+        Args:
+            **kwargs: Keyword arguments containing the model.
+
+        Returns:
+            HFModel: The model with patched MoE forward functions.
+
+        Raises:
+            ValueError: If the model is not provided.
+            RuntimeError: If dependencies are not met.
+        """
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+
+        if not cls.check_deps():
+            raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
+
         archs = getattr(model.config, "architectures", [])
         target_moe_mapping = None
src/llamafactory/v1/plugins/model_plugins/kernels/{mlp => ops/mlp}/npu_swiglu.py
@@ -12,36 +12,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""The definition of NPU fused SwiGLU kernels.
+
+Init Phase:
+1. Define SwiGLU forward functions.
+2. Register NPU fused SwiGLU kernel.
+
+"""
+
 import re
 import types
 
 import torch
 
-from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....utils.types import HFModel
-from ..constants import KernelType
-from ..registry import MetaSwiGluKernel
+from ......accelerator.helper import DeviceType
+from ......utils.types import HFModel
+from ...base import BaseKernel
+from ...registry import register_kernel
 
 
-def _npu_swiglu_forward(self, hidden_state):
-    import torch_npu
-
+try:
+    import torch_npu
+except ImportError:
+    pass
+
+
+def npu_swiglu_forward(self, hidden_state):
+    r"""SwiGLU forward pass for NPU.
+
+    Args:
+        self: The MLP layer instance.
+        hidden_state (Tensor): Input hidden state.
+
+    Returns:
+        Tensor: Output of SwiGLU.
+    """
     return self.down_proj(
         torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
     )
 
 
 def _npu_swiglu_glm4_forward(self, hidden_states):
-    import torch_npu
-
+    r"""SwiGLU forward pass for GLM4 on NPU.
+
+    Args:
+        self: The GLM4 MLP layer instance.
+        hidden_states (Tensor): Input hidden states.
+
+    Returns:
+        Tensor: Output of SwiGLU.
+    """
     up_states = self.gate_up_proj(hidden_states)
     gate, up_states = up_states.chunk(2, dim=-1)
     return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
 
 
 def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
-    import torch_npu
-
+    r"""SwiGLU forward pass for Gemma3nText on NPU.
+
+    Args:
+        self: The Gemma3nText MLP layer instance.
+        hidden_states (Tensor): Input hidden states.
+
+    Returns:
+        Tensor: Output of SwiGLU.
+    """
     gate_proj = self.gate_proj(hidden_states)
     if self.activation_sparsity > 0.0:
         gate_proj = self._gaussian_topk(gate_proj)
@@ -51,12 +86,11 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
     return down_proj
 
 
-class NpuSwiGluKernel(MetaSwiGluKernel):
-    type = KernelType.SWIGLU
-    device = DeviceType.NPU
-    kernel = _npu_swiglu_forward
+@register_kernel
+class NpuSwiGluKernel(BaseKernel):
+    r"""NPU Kernel for fused SwiGLU activation."""
 
-    # Don't apply the kernel to the following modules
+    # Only supports applying the kernel to the following module layers
     expect_modules = frozenset(
         {
             "Qwen3VLMoeTextMLP",
@@ -87,10 +121,29 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         }
     )
 
+    _kernel_id = "npu_fused_swiglu"
+    _device = DeviceType.NPU
+
     @classmethod
-    def apply(cls, model, **kwargs) -> "HFModel":
-        if not is_torch_npu_available():
-            return model
+    def apply(cls, **kwargs) -> "HFModel":
+        r"""Applies the NPU fused SwiGLU kernel to the model.
+
+        Args:
+            **kwargs: Keyword arguments containing the model.
+
+        Returns:
+            HFModel: The model with patched SwiGLU forward functions.
+
+        Raises:
+            ValueError: If the model is not provided.
+            RuntimeError: If dependencies are not met.
+        """
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+
+        if not cls.check_deps():
+            raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")
+
         # Mapping of specific mlp modules to their corresponding kernel implementations
         kernel_mapping = {
@@ -109,7 +162,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         ):
             # Bind function as an instance method to preserve `self` semantics
             # and replace the original forward
-            kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+            kernel_func = kernel_mapping.get(module.__class__.__name__, npu_swiglu_forward)
             module.forward = types.MethodType(kernel_func, module)
 
         return model
src/llamafactory/v1/plugins/model_plugins/kernels/{rms_norm => ops/rms_norm}/npu_rms_norm.py
@@ -11,40 +11,49 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""The definition of NPU fused RMSNorm kernels.
+
+Init Phase:
+1. Define RMSNorm forward function.
+2. Register NPU fused RMSNorm kernel.
+
+"""
+
 import re
 import types
 
-from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....utils.types import HFModel
-from ..constants import KernelType
-from ..registry import MetaRMSNormKernel
+from ......accelerator.helper import DeviceType
+from ......utils.types import HFModel
+from ...base import BaseKernel
+from ...registry import register_kernel
 
 
-def _npu_rms_forward(self, hidden_states):
-    """NPU forward implementation for RMSNorm.
+def npu_rms_norm_forward(self, hidden_states):
+    r"""NPU forward implementation for RMSNorm.
 
     Args:
         self: RMSNorm module instance with `weight` and `variance_epsilon`.
-        hidden_states: Input hidden states tensor, same shape as the baseline.
+        hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.
 
     Returns:
-        Normalized tensor consistent with the baseline RMSNorm behavior.
+        Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
     """
     import torch_npu
 
    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
 
 
-class NpuRMSNormKernel(MetaRMSNormKernel):
-    """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
+@register_kernel
+class NpuRMSNormKernel(BaseKernel):
+    r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
 
-    type = KernelType.RMSNORM
-    device = DeviceType.NPU
-    kernel = _npu_rms_forward
+    _kernel_id = "npu_fused_rmsnorm"
+    _device = DeviceType.NPU
 
     @classmethod
-    def apply(cls, model, **kwargs) -> HFModel:
-        """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
+    def apply(cls, **kwargs) -> "HFModel":
+        r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
 
         Key points:
         - Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -52,10 +61,23 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
           replace the original `forward`.
         - Do not modify weights, hyperparameters, or module structure to ensure
           numerical behavior and interface consistency.
-        """
-        if not is_torch_npu_available():
-            return model
+
+        Args:
+            **kwargs: Keyword arguments containing the model.
+
+        Returns:
+            HFModel: The model with NPU fused RMSNorm.
+
+        Raises:
+            RuntimeError: If torch_npu is not available.
+            ValueError: If the model is not provided.
+        """
+        model = kwargs.get("model")
+        if model is None:
+            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+
+        if not cls.check_deps():
+            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
         rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
 
         for name, module in model.named_modules():
@@ -63,6 +85,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
             if re.search(rms_norm_pattern, module.__class__.__name__):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                module.forward = types.MethodType(npu_rms_norm_forward, module)
 
         return model
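Note: the patching trick shared by the RMSNorm and SwiGLU kernels is per-instance method binding. A self-contained sketch of the same mechanism on a plain nn.Module (the fast path here is a stand-in for torch_npu.npu_rms_norm):

import types

import torch
import torch.nn as nn


def fast_forward(self, x):
    # Stand-in for the fused NPU op; same signature as the original forward.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * self.weight


class TinyRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x  # placeholder baseline


norm = TinyRMSNorm(4)
# Bind fast_forward as a bound method so `self` resolves to this one instance;
# weights and module structure stay untouched, mirroring NpuRMSNormKernel.apply.
norm.forward = types.MethodType(fast_forward, norm)
print(norm(torch.randn(2, 4)).shape)  # torch.Size([2, 4])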
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py (new file, +146)
@@ -0,0 +1,146 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of NPU fused RoPE kernels.
+
+Init Phase:
+1. Define RoPE forward functions.
+2. Register NPU fused RoPE kernel.
+
+"""
+
+import sys
+
+import torch
+
+from ......accelerator.helper import DeviceType
+from ......utils.logging import get_logger
+from ......utils.types import HFModel
+from ...base import BaseKernel
+from ...registry import register_kernel
+
+
+logger = get_logger(__name__)
+
+try:
+    import torch_npu
+except ImportError:
+    pass
+
+
+def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
+
+    Args:
+        q (Tensor): Query tensor.
+        k (Tensor): Key tensor.
+        cos (Tensor): Cosine part of embedding.
+        sin (Tensor): Sine part of embedding.
+        position_ids (Tensor, optional): Position IDs. Default: ``None``.
+        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
+
+    Returns:
+        tuple: (q_embed, k_embed) The embedded query and key tensors.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    return q_embed, k_embed
+
+
+def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
+
+    Args:
+        q (Tensor): Query tensor.
+        k (Tensor): Key tensor.
+        cos (Tensor): Cosine part of embedding.
+        sin (Tensor): Sine part of embedding.
+        mrope_section (Tensor): Multimodal RoPE section.
+        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.
+
+    Returns:
+        tuple: (q_embed, k_embed) The embedded query and key tensors.
+    """
+    mrope_section = mrope_section * 2
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+        unsqueeze_dim
+    )
+
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    return q_embed, k_embed
+
+
+@register_kernel
+class NpuRoPEKernel(BaseKernel):
+    r"""NPU Kernel for Rotary Position Embedding."""
+
+    _kernel_id = "npu_fused_rope"
+    _device = DeviceType.NPU
+
+    @classmethod
+    def apply(cls, **kwargs) -> "HFModel":
+        r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
+
+        This function iterates through the model's modules to find attention layers,
+        identifies the module where they are defined, and replaces the original
+        `apply_rotary_pos_emb` function in that module's namespace with the
+        NPU-accelerated version from this file.
+
+        Args:
+            **kwargs: Keyword arguments containing the model.
+
+        Returns:
+            HFModel: The model with patched RoPE functions.
+
+        Raises:
+            RuntimeError: If dependencies are not met.
+            ValueError: If the model is not provided.
+        """
+        if not cls.check_deps():
+            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
+        model = kwargs.get("model", None)
+        if model is None:
+            raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+        _modules = set()
+        for module in model.modules():
+            if "Attention" in module.__class__.__name__:
+                module_name = module.__class__.__module__
+                if module_name in _modules:
+                    continue
+                try:
+                    target_module = sys.modules[module_name]
+                    if hasattr(target_module, "apply_rotary_pos_emb"):
+                        if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
+                            setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
+                            _modules.add(module_name)
+                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
+                        if (
+                            getattr(target_module, "apply_multimodal_rotary_pos_emb")
+                            is not _apply_multimodal_rotary_pos_emb_qwen25_vl
+                        ):
+                            setattr(
+                                target_module,
+                                "apply_multimodal_rotary_pos_emb",
+                                _apply_multimodal_rotary_pos_emb_qwen25_vl,
+                            )
+                            _modules.add(module_name)
+                except Exception as e:
+                    logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
+        return model
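Note: unlike the RMSNorm/SwiGLU kernels, RoPE is a module-level function inside each transformers modeling file, so this kernel swaps it in the defining module's namespace via sys.modules. A minimal self-contained sketch of that mechanism (the module and functions are fabricated for illustration):

import sys
import types

# Stand-in for a transformers modeling module that defines apply_rotary_pos_emb.
mod = types.ModuleType("fake_modeling")
mod.apply_rotary_pos_emb = lambda q, k: ("slow", q, k)
sys.modules["fake_modeling"] = mod


def fast_rope(q, k):
    return ("fast", q, k)


# The kernel finds this module through an attention class's __module__ attribute,
# then rebinds the function so every call site in that module picks up the patch.
target = sys.modules["fake_modeling"]
if getattr(target, "apply_rotary_pos_emb") is not fast_rope:
    setattr(target, "apply_rotary_pos_emb", fast_rope)

print(target.apply_rotary_pos_emb(1, 2))  # ("fast", 1, 2)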
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -12,247 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import ABC, ABCMeta, abstractmethod
-from collections.abc import Callable
-from typing import Any, Optional
-
-from ....accelerator.helper import DeviceType, get_current_accelerator
-from ....utils.types import HFModel
-from .constants import KernelType
-
-
-class KernelRegistry:
-    _instance: Optional["KernelRegistry"] = None
-    _initialized: bool = False
-
-    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-
-        return cls._instance
-
-    def __init__(self) -> None:
-        if self._initialized:
-            return
-
-        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
-        self._initialized = True
-
-    def register(
-        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
-    ) -> None:
-        """Register a kernel implementation.
-
-        Args:
-            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
-            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
-            kernel_impl: the actual kernel function or class.
-        """
-        if kernel_type not in self._registry:
-            self._registry[kernel_type] = {}
-
-        if device_type in self._registry[kernel_type]:
-            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")
-
-        self._registry[kernel_type][device_type] = kernel_impl
-        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
-
-    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
-        return self._registry.get(kernel_type, {}).get(device_type)
-
-
-KERNEL_REGISTRY = KernelRegistry()
-
-
-class AutoRegisterKernelMeta(ABCMeta):
-    """Metaclass that automatically registers kernel classes upon creation.
-
-    This metaclass checks if a newly created class has both `type` and `device`
-    attributes defined. If so, it automatically registers the kernel in the
-    global KERNEL_REGISTRY, eliminating the need for manual registration.
-
-    To disable auto-registration for a specific class, set `auto_register = False`.
-    """
-
-    def __new__(mcs, name, bases, namespace, **kwargs):
-        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
-
-        # Check if auto-registration is disabled
-        auto_register = namespace.get("auto_register", True)
-
-        # Only auto-register if the class has both type and device attributes defined
-        # and they are not None (skip base classes like MetaKernel itself)
-        # and auto_register is True
-        kernel_type = namespace.get("type")
-        device_type = namespace.get("device")
-
-        if auto_register and kernel_type is not None and device_type is not None:
-            # Auto-register this kernel
-            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
-        return cls
-
-
-class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
-    """Base class for all kernel implementations.
-
-    Subclasses are automatically registered when they define both `type` and `device`
-    attributes. To disable auto-registration, set `auto_register = False`.
-
-    Attributes:
-        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
-        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
-        kernel: The actual kernel function or implementation.
-        auto_register: Set to False to disable automatic registration (default: True).
-    """
-
-    type: KernelType | None = None
-    device: DeviceType | None = None
-    kernel: Callable | None = None
-
-    @classmethod
-    @abstractmethod
-    def apply(cls, model: HFModel, **kwargs) -> HFModel:
-        """Apply the kernel to the model.
-
-        This method should check if the kernel can be applied (e.g., dependencies
-        are installed, target modules exist) and perform the kernel replacement.
-
-        Args:
-            model: The HuggingFace model to optimize.
-            **kwargs: Additional arguments for kernel application.
-
-        Returns:
-            The optimized model (may be the same object with modifications).
-        """
-        raise NotImplementedError
-
-
-class MetaFlashAttentionKernel(MetaKernel):
-    @classmethod
-    def apply(cls, model: HFModel, **kwargs) -> HFModel:
-        raise NotImplementedError
-
-
-class MetaRMSNormKernel(MetaKernel):
-    @classmethod
-    def apply(cls, model: HFModel, **kwargs) -> HFModel:
-        raise NotImplementedError
-
-
-class MetaSwiGluKernel(MetaKernel):
-    @classmethod
-    def apply(cls, model: HFModel, **kwargs) -> HFModel:
-        raise NotImplementedError
-
-
-class MetaRoPEKernel(MetaKernel):
-    @classmethod
-    def apply(cls, model: HFModel, **kwargs) -> HFModel:
-        raise NotImplementedError
-
-
-class MetaMoEKernel(MetaKernel):
-    @classmethod
-    def apply(cls, model: HFModel, **kwargs) -> HFModel:
-        raise NotImplementedError
-
-
-def _ensure_kernels_loaded() -> None:
-    """Ensure all kernel implementations are imported and registered.
-
-    This function dynamically imports all kernel implementation modules to trigger
-    their auto-registration. Python's module system ensures each module is only
-    executed once (cached in sys.modules), so repeated calls are safe and fast.
-    """
-    # List of kernel module paths to import
-    kernel_modules = [
-        "rms_norm.npu_rms_norm",
-        "rope.npu_rope",
-        "mlp.npu_swiglu",
-        "mlp.npu_fused_moe",
-        # Add new kernel modules here as they are created
-    ]
-
-    # Import each module to trigger kernel registration
-    # Python's import system caches modules, so this is fast on subsequent calls
-    for module_name in kernel_modules:
-        try:
-            __import__(f"{__package__}.{module_name}", fromlist=["*"])
-        except ImportError:
-            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
-            pass
-
-
-def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
-    """Discover and return all kernel classes registered for the current device.
-
-    This function inspects the runtime environment (device type) and returns
-    all MetaKernel classes registered for that device. Each kernel's `apply()`
-    method is responsible for checking if it can actually be applied (e.g.,
-    required dependencies are installed, target modules exist in the model).
-
-    The function automatically discovers all kernels registered in KERNEL_REGISTRY
-    without requiring manual enumeration. On first call, it dynamically imports
-    all kernel implementation modules to trigger their auto-registration.
-
-    Args:
-        model: The HuggingFace model to apply kernels to.
-            TODO: implement the kernel route detection logic by model structure.
-
-    Returns:
-        A list of MetaKernel classes available for the current device.
-    """
-    # Ensure all kernel modules are imported to trigger registration
-    _ensure_kernels_loaded()
-
-    discovered_kernels: list[type[MetaKernel]] = []
-
-    # Detect current device type
-    accelerator = get_current_accelerator()
-    try:
-        device_type = DeviceType(accelerator.type)
-    except ValueError:
-        # Unknown device type, return empty list
-        return discovered_kernels
-
-    # Skip CPU as it typically doesn't have optimized kernels
-    if device_type == DeviceType.CPU:
-        return discovered_kernels
-
-    # Iterate through registry and collect all kernels for current device
-    for devices in KERNEL_REGISTRY._registry.values():
-        kernel_cls = devices.get(device_type)
-        if kernel_cls is not None:
-            discovered_kernels.append(kernel_cls)
-
-    return discovered_kernels
-
-
-def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
-    """Call the MetaKernel's `apply` to perform the replacement.
-
-    Corresponding replacement logic is maintained inside each kernel; the only
-    requirement is that `apply` returns the replaced model.
-
-    Example:
-        from transformers import AutoModelForCausalLM
-        from .rms_norm.npu_rms_norm import NpuRMSNormKernel
-        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
-        model = apply_kernel(model, NpuRMSNormKernel)
-    """
-    if not issubclass(kernel, MetaKernel):
-        raise ValueError(f"{kernel} must be a MetaKernel instance.")
-
-    if kernel.device != get_current_accelerator().type:
-        raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
-
-    return kernel.apply(model, **kwargs)
-
-
-def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
-    """Apply all available kernels to the model."""
-    for kernel in discover_kernels(model):
-        model = apply_kernel(model, kernel, **kwargs)
-
-    return model
+"""The definition of kernel registry.
+
+Init Phase:
+1. Define kernel registry.
+2. Register kernels.
+
+"""
+
+from typing import Optional
+
+from ....accelerator.helper import get_current_accelerator
+from .base import BaseKernel
+
+
+__all__ = ["Registry", "register_kernel"]
+
+
+class Registry:
+    r"""Registry for managing kernel implementations.
+
+    Storage structure: ``{ "kernel_id": Class }``
+    """
+
+    _kernels: dict[str, type[BaseKernel]] = {}
+
+    @classmethod
+    def register(cls, kernel_cls: type[BaseKernel]):
+        r"""Decorator to register a kernel class.
+
+        The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
+
+        Args:
+            kernel_cls (type[BaseKernel]): The kernel class to register.
+
+        Returns:
+            type[BaseKernel]: The registered kernel class.
+
+        Raises:
+            TypeError: If the class does not inherit from :class:`BaseKernel`.
+            ValueError: If the kernel ID is missing or already registered.
+        """
+        if not issubclass(kernel_cls, BaseKernel):
+            raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
+        kernel_id = kernel_cls.get_kernel_id()
+        device = kernel_cls.get_device()
+
+        # The device type of the current accelerator does not match the device type required by the kernel, skip registration
+        if device != get_current_accelerator().type:
+            return
+
+        if not kernel_id:
+            raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")
+
+        if kernel_id in cls._kernels:
+            raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")
+
+        cls._kernels[kernel_id] = kernel_cls
+        return kernel_cls
+
+    @classmethod
+    def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
+        r"""Retrieves a registered kernel implementation by its ID.
+
+        Args:
+            kernel_id (str): The ID of the kernel to retrieve.
+
+        Returns:
+            Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
+        """
+        return cls._kernels.get(kernel_id)
+
+    @classmethod
+    def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
+        r"""Returns a dictionary of all registered kernels.
+
+        Returns:
+            dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
+        """
+        return cls._kernels
+
+
+# export decorator alias
+register_kernel = Registry.register
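Note: a hedged sketch of the new registration flow in isolation; DemoKernel is hypothetical, while the skip-on-device-mismatch and duplicate-ID behaviors follow Registry.register above.

from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.base import BaseKernel
from llamafactory.v1.plugins.model_plugins.kernels.registry import Registry, register_kernel


@register_kernel
class DemoKernel(BaseKernel):  # hypothetical
    _kernel_id = "demo"
    _device = get_current_accelerator().type  # match the live device so registration is not skipped

    @classmethod
    def apply(cls, **kwargs):
        return kwargs.get("model")


assert Registry.get("demo") is DemoKernel
print(list(Registry.get_registered_kernels()))  # contains "demo" plus any scanned defaults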
@@ -1,122 +0,0 @@
|
|||||||
# Copyright 2025 the LlamaFactory team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
|
||||||
from .....utils.types import HFModel
|
|
||||||
from ..constants import KernelType
|
|
||||||
from ..registry import MetaRoPEKernel
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
||||||
"""Applies Rotary Position Embedding to the query and key tensors."""
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
cos = cos.unsqueeze(unsqueeze_dim)
|
|
||||||
sin = sin.unsqueeze(unsqueeze_dim)
|
|
||||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
|
||||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
|
||||||
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
mrope_section = mrope_section * 2
|
|
||||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
|
||||||
unsqueeze_dim
|
|
||||||
)
|
|
||||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
|
||||||
unsqueeze_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
|
||||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
class NpuRoPEKernel(MetaRoPEKernel):
|
|
||||||
type = KernelType.ROPE
|
|
||||||
device = DeviceType.NPU
|
|
||||||
kernel = _apply_rotary_pos_emb
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def apply(cls, model, **kwargs) -> "HFModel":
|
|
||||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
|
||||||
|
|
||||||
This function iterates through the model's modules to find attention layers,
|
|
||||||
identifies the module where they are defined, and replaces the original
|
|
||||||
`apply_rotary_pos_emb` function in that module's namespace with the
|
|
||||||
NPU-accelerated version from this file.
|
|
||||||
"""
|
|
||||||
if not is_torch_npu_available():
|
|
||||||
return model
|
|
||||||
|
|
||||||
_modules = set()
|
|
||||||
for module in model.modules():
|
|
||||||
if "Attention" in module.__class__.__name__:
|
|
||||||
module_name = module.__class__.__module__
|
|
||||||
if module_name in _modules:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
target_module = sys.modules[module_name]
|
|
||||||
if hasattr(target_module, "apply_rotary_pos_emb"):
|
|
||||||
if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
|
|
||||||
setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
|
|
||||||
_modules.add(module_name)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    """Qwen2-VL specific RoPE kernel - not auto-registered.

    This kernel is for specific models (Qwen2-VL) and should be manually
    applied when needed rather than auto-discovered.
    """

    type = KernelType.ROPE
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
    auto_register = False  # Disable auto-registration for model-specific kernel

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                        _modules.add(module_name)
                except Exception:
                    # skip modules that cannot be resolved or patched
                    pass
        return model
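# Illustrative usage, not part of this patch: because `auto_register = False`, this
# kernel is skipped by kernel auto-discovery, so a Qwen2-VL caller would apply it
# explicitly after loading the model:
#
#     model = NpuQwen2VLRoPEKernel.apply(model)
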
71
tests_v1/config/test_args_parser.py
Normal file
@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import sys
from unittest.mock import patch

from llamafactory.v1.config.arg_parser import get_args


def test_get_args_from_yaml(tmp_path: pathlib.Path):
    config_yaml = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
  name: "auto"
  include_kernels: "auto"  # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
peft_config:
  name: "lora"
  lora_rank: 0.8
quant_config: null

### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048

### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null

### sample
sample_backend: "hf"
max_new_tokens: 128
"""

    config_file = tmp_path / "config.yaml"
    config_file.write_text(config_yaml, encoding="utf-8")

    test_argv = ["test_args_parser.py", str(config_file)]

    with patch.object(sys, "argv", test_argv):
        data_args, model_args, training_args, sample_args = get_args()

    assert training_args.output_dir == "outputs/test_run"
    assert training_args.micro_batch_size == 1
    assert training_args.global_batch_size == 1
    assert training_args.learning_rate == 1.0e-4
    assert training_args.bf16 is False
    assert training_args.dist_config is None
    assert model_args.model == "llamafactory/tiny-random-qwen2.5"
    assert model_args.kernel_config.name == "auto"
    assert model_args.kernel_config.get("include_kernels") == "auto"
    assert model_args.peft_config.name == "lora"
    assert model_args.peft_config.get("lora_rank") == 0.8
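Illustrative usage, not part of the patch: per the inline comment in the YAML above,
`include_kernels` also accepts an explicit comma-separated list of kernel ids instead
of "auto" (only "npu_fused_rmsnorm" is confirmed by this diff; any other id would be
hypothetical):

    kernel_config:
      name: "auto"
      include_kernels: "npu_fused_rmsnorm"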
@@ -14,7 +14,7 @@

 import torch

-from llamafactory.v1.config.model_args import ModelArguments
+from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
 from llamafactory.v1.core.model_loader import ModelLoader

@@ -29,5 +29,23 @@ def test_tiny_qwen():
     assert model_loader.model.dtype == torch.bfloat16


+def test_tiny_qwen_with_kernel_plugin():
+    from transformers import Qwen2ForCausalLM
+
+    from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
+
+    model_args = ModelArguments(
+        model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
+    )
+    model_loader = ModelLoader(model_args)
+    # test enable apply kernel plugin
+    if hasattr(torch, "npu"):
+        assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
+    else:
+        assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
+    assert isinstance(model_loader.model, Qwen2ForCausalLM)
+
+
 if __name__ == "__main__":
     test_tiny_qwen()
+    test_tiny_qwen_with_kernel_plugin()
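Illustrative aside, not part of the patch: the test compares `forward.__code__` rather
than the bound methods themselves, since attribute access on an instance creates a fresh
bound-method object on every lookup:

    import torch

    m = torch.nn.Linear(2, 2)
    assert m.forward is not m.forward                 # new bound-method object per access
    assert m.forward.__code__ is m.forward.__code__   # underlying code object is stable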
@@ -12,16 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import sys
 from unittest.mock import MagicMock, patch

 import pytest
 from transformers import AutoModelForCausalLM

 from llamafactory.v1.accelerator.helper import get_current_accelerator
-from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
-from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
-from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
-from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope


 @pytest.fixture(autouse=True)
@@ -29,24 +26,29 @@ def clear_accelerator_cache():
     get_current_accelerator.cache_clear()


+def reload_kernels():
+    """Helper to reload kernel modules to respect mocked accelerator."""
+    # Unload kernel interface and registry
+    keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
+    for k in keys_to_remove:
+        del sys.modules[k]
+
+
 @patch("torch.accelerator.current_accelerator")
 def test_apply_kernel(mock_get_accelerator: MagicMock):
     mock_device = MagicMock()
     setattr(mock_device, "type", "npu")
     mock_get_accelerator.return_value = mock_device
+    # Force reload of kernels with mocked accelerator
+    reload_kernels()
+    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

     model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

     original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
     original_swiglu_forward = model.model.layers[0].mlp.forward

-    apply_kernel(model, npu_rope.NpuRoPEKernel)
-    model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
-    assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
-    model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
-    assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
+    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
+    assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__


 @patch("torch.accelerator.current_accelerator")
@@ -56,12 +58,15 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
     setattr(mock_device, "type", "npu")
     mock_get_accelerator.return_value = mock_device

+    # Force reload of kernels with mocked accelerator
+    reload_kernels()
+    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels

     model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

     original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
     original_swiglu_forward = model.model.layers[0].mlp.forward

-    model = apply_available_kernels(model)
-    assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
-    assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+    model = apply_default_kernels(model=model, include_kernels=True)
+    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
+    assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
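Illustrative aside, not part of the patch: `reload_kernels` exists because, per its
docstring, the kernel modules must be re-imported to respect the mocked accelerator;
evicting them from `sys.modules` forces the next import to run under the patched
`torch.accelerator.current_accelerator` (sketch, assuming `patch` from `unittest.mock`):

    with patch("torch.accelerator.current_accelerator") as mock_acc:
        mock_acc.return_value.type = "npu"
        reload_kernels()  # next kernel import sees the mocked NPU device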