From f17efde6933d751d0d8363b6de2a9871aebe28c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B5=AE=E6=A2=A6?= <46097299+frozenleaves@users.noreply.github.com>
Date: Thu, 27 Nov 2025 01:47:22 +0800
Subject: [PATCH] [v1] support automatic discovery of registered kernels. (#9509)

Co-authored-by: frozenleaves
---
 .../model_plugins/kernels/mlp/npu_swiglu.py   |  74 ++++++++-
 .../plugins/model_plugins/kernels/registry.py | 143 ++++++++++++++++--
 .../kernels/rms_norm/npu_rms_norm.py          |   8 +-
 .../model_plugins/kernels/rope/npu_rope.py    |  19 +--
 .../model_plugins/test_kernel_plugin.py       |  20 +++
 5 files changed, 228 insertions(+), 36 deletions(-)

diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
index 702d27bc..e396a496 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
@@ -20,7 +20,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+from ..registry import MetaSwiGluKernel
 
 
 def _npu_swiglu_forward(self, hidden_state):
@@ -31,25 +31,85 @@ def _npu_swiglu_forward(self, hidden_state):
     )
 
 
+def _npu_swiglu_glm4_forward(self, hidden_states):
+    import torch_npu
+
+    up_states = self.gate_up_proj(hidden_states)
+    gate, up_states = up_states.chunk(2, dim=-1)
+    return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
+
+
+def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
+    import torch_npu
+
+    gate_proj = self.gate_proj(hidden_states)
+    if self.activation_sparsity > 0.0:
+        gate_proj = self._gaussian_topk(gate_proj)
+    down_proj = self.down_proj(
+        torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
+    )
+    return down_proj
+
+
 class NpuSwiGluKernel(MetaSwiGluKernel):
+    type = KernelType.SWIGLU
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward
 
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    # Only apply the kernel to modules whose class names appear in this allowlist
+    expect_modules = frozenset(
+        {
+            "Qwen3VLMoeTextMLP",
+            "Qwen3VLTextMLP",
+            "Qwen3OmniMoeThinkerTextMLP",
+            "Qwen3OmniMoeMLP",
+            "Qwen3OmniMoeTalkerTextMLP",
+            "Qwen3OmniMoeCode2WavMlp",
+            "Qwen3NextMLP",
+            "Qwen3MoeMLP",
+            "Qwen3MLP",
+            "Qwen2MLP",
+            "Qwen2MoeMLP",
+            "Qwen2_5_VLMLP",
+            "Qwen2_5OmniMLP",
+            "Llama4TextMLP",
+            "LlamaMLP",
+            "Glm4MLP",
+            "Glm4MoeMLP",
+            "Glm4vMoeTextMLP",
+            "Gemma3MLP",
+            "Gemma2MLP",
+            "Gemma3nTextMLP",
+            "Phi3MLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+            "SeedOssMLP",
+        }
+    )
 
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model
 
+        # Mapping of specific MLP modules to their corresponding kernel implementations
+        kernel_mapping = {
+            "Glm4MLP": _npu_swiglu_glm4_forward,
+            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
+            "Phi3MLP": _npu_swiglu_glm4_forward,
+            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
+        }
+
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
         for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
-            if re.search(swiglu_pattern, module.__class__.__name__):
+            # Match any module whose class name contains "MLP"
+            if (
+                re.search(swiglu_pattern, module.__class__.__name__)
+                and module.__class__.__name__ in cls.expect_modules
+            ):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+                module.forward = types.MethodType(kernel_func, module)
 
     return model
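For reviewers unfamiliar with the forward-rebinding pattern used above, here is a minimal, self-contained sketch of the same mechanism outside the patch. DummyMLP and _plain_swiglu are hypothetical stand-ins (they are not part of this change, and no torch_npu is involved); the point is only that types.MethodType binds the replacement as an instance method, so `self` still refers to the module, and that only allowlisted module classes are touched.

    # Illustration only -- DummyMLP and _plain_swiglu are stand-ins, not shipped code.
    import re
    import types

    import torch
    from torch import nn


    def _plain_swiglu(self, hidden_state):
        # Same computation shape as the NPU kernel, expressed with stock torch ops.
        return self.down_proj(nn.functional.silu(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


    class DummyMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.gate_proj = nn.Linear(4, 8)
            self.up_proj = nn.Linear(4, 8)
            self.down_proj = nn.Linear(8, 4)


    allowlist = {"DummyMLP"}
    pattern = re.compile("MLP", re.IGNORECASE)
    mlp = DummyMLP()
    if re.search(pattern, mlp.__class__.__name__) and mlp.__class__.__name__ in allowlist:
        # Rebound on the instance, not the class, mirroring NpuSwiGluKernel.apply()
        mlp.forward = types.MethodType(_plain_swiglu, mlp)
    print(mlp.forward(torch.randn(2, 4)).shape)  # torch.Size([2, 4])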
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
index 45e14a09..b677b8a6 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional
 
 from ....extras.types import HFModel
@@ -61,18 +61,67 @@ class KernelRegistry:
 
 KERNEL_REGISTRY = KernelRegistry()
 
 
-class MetaKernel(ABC):
+class AutoRegisterKernelMeta(ABCMeta):
+    """Metaclass that automatically registers kernel classes upon creation.
+
+    This metaclass checks if a newly created class has both `type` and `device`
+    attributes defined. If so, it automatically registers the kernel in the
+    global KERNEL_REGISTRY, eliminating the need for manual registration.
+
+    To disable auto-registration for a specific class, set `auto_register = False`.
+    """
+
+    def __new__(mcs, name, bases, namespace, **kwargs):
+        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
+
+        # Check if auto-registration is disabled
+        auto_register = namespace.get("auto_register", True)
+
+        # Only auto-register if the class has both type and device attributes defined
+        # and they are not None (skip base classes like MetaKernel itself)
+        # and auto_register is True
+        kernel_type = namespace.get("type")
+        device_type = namespace.get("device")
+
+        if auto_register and kernel_type is not None and device_type is not None:
+            # Auto-register this kernel
+            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+        return cls
+
+
+class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
+    """Base class for all kernel implementations.
+
+    Subclasses are automatically registered when they define both `type` and `device`
+    attributes. To disable auto-registration, set `auto_register = False`.
+
+    Attributes:
+        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
+        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
+        kernel: The actual kernel function or implementation.
+        auto_register: Set to False to disable automatic registration (default: True).
+    """
+
     type: Optional[KernelType] = None
     device: Optional[DeviceType] = None
     kernel: Optional[Callable] = None
 
-    @classmethod
-    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     @abstractmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        """Apply the kernel to the model.
+
+        This method should check if the kernel can be applied (e.g., dependencies
+        are installed, target modules exist) and perform the kernel replacement.
+
+        Args:
+            model: The HuggingFace model to optimize.
+            **kwargs: Additional arguments for kernel application.
+
+        Returns:
+            The optimized model (may be the same object with modifications).
+        """
         raise NotImplementedError
 
 
@@ -106,16 +155,75 @@ class MetaMoEKernel(MetaKernel):
         raise NotImplementedError
 
 
-def discover_kernels(model: HFModel) -> list[MetaKernel]:
-    """Discover and construct MetaKernel instances for the current model/device.
+def _ensure_kernels_loaded() -> None:
+    """Ensure all kernel implementations are imported and registered.
 
-    This is a placeholder to be implemented: it should inspect the runtime
-    environment (device type, available extensions, model architecture) and
-    return an ordered list of MetaKernel instances to be applied. Each returned
-    MetaKernel must encapsulate its own replacement logic in `apply`.
+    This function dynamically imports all kernel implementation modules to trigger
+    their auto-registration. Python's module system ensures each module is only
+    executed once (cached in sys.modules), so repeated calls are safe and fast.
     """
-    # TODO: Implement auto discovery logic based on registry and device capabilities.
-    return []
+    # List of kernel module paths to import
+    kernel_modules = [
+        "rms_norm.npu_rms_norm",
+        "rope.npu_rope",
+        "mlp.npu_swiglu",
+        "mlp.npu_fused_moe",
+        # Add new kernel modules here as they are created
+    ]
+
+    # Import each module to trigger kernel registration
+    # Python's import system caches modules, so this is fast on subsequent calls
+    for module_name in kernel_modules:
+        try:
+            __import__(f"{__package__}.{module_name}", fromlist=["*"])
+        except ImportError:
+            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
+            pass
+
+
+def discover_kernels(model: Optional[HFModel] = None) -> list[type[MetaKernel]]:
+    """Discover and return all kernel classes registered for the current device.
+
+    This function inspects the runtime environment (device type) and returns
+    all MetaKernel classes registered for that device. Each kernel's `apply()`
+    method is responsible for checking whether it can actually be applied (e.g.,
+    required dependencies are installed, target modules exist in the model).
+
+    The function automatically discovers all kernels registered in KERNEL_REGISTRY
+    without requiring manual enumeration. On first call, it dynamically imports
+    all kernel implementation modules to trigger their auto-registration.
+
+    Args:
+        model: The HuggingFace model to apply kernels to.
+            TODO: route kernel selection based on the model structure.
+
+    Returns:
+        A list of MetaKernel classes available for the current device.
+    """
+    # Ensure all kernel modules are imported to trigger registration
+    _ensure_kernels_loaded()
+
+    discovered_kernels: list[type[MetaKernel]] = []
+
+    # Detect current device type
+    accelerator = get_available_accelerator()
+    try:
+        device_type = DeviceType(accelerator.type)
+    except ValueError:
+        # Unknown device type, return empty list
+        return discovered_kernels
+
+    # Skip CPU as it typically doesn't have optimized kernels
+    if device_type == DeviceType.CPU:
+        return discovered_kernels
+
+    # Iterate through registry and collect all kernels for current device
+    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+        kernel_cls = devices.get(device_type)
+        if kernel_cls is not None:
+            discovered_kernels.append(kernel_cls)
+
+    return discovered_kernels
 
 
 def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
         raise ValueError(
             f"{kernel} must be a MetaKernel subclass, or the kernel doesn't match the device type. Got {kernel.device} and {get_available_accelerator().type} instead."
         )
+
+
+def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
+    """Apply all available kernels to the model."""
+    for kernel in discover_kernels(model):
+        model = apply_kernel(model, kernel, **kwargs)
+    return model
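Because registration now happens in AutoRegisterKernelMeta.__new__, adding a kernel for a new device reduces to subclassing the right Meta*Kernel and setting `type` and `device`. A minimal sketch, assuming a DeviceType.CUDA member exists in constants and using an illustrative class name that is not part of this patch:

    # Hypothetical example -- MyCudaRMSNormKernel is not shipped by this change.
    from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType
    from llamafactory.v1.plugins.model_plugins.kernels.registry import (
        KERNEL_REGISTRY,
        MetaRMSNormKernel,
        apply_available_kernels,
    )


    class MyCudaRMSNormKernel(MetaRMSNormKernel):
        type = KernelType.RMSNORM
        device = DeviceType.CUDA  # assumption: DeviceType defines a CUDA member

        @classmethod
        def apply(cls, model, **kwargs):
            # A real kernel would rebind the forward of matched RMSNorm modules here.
            return model


    # The metaclass registered the class at definition time; no register_kernel() call is needed.
    # (Peeking at the private _registry here purely for illustration.)
    assert KERNEL_REGISTRY._registry[KernelType.RMSNORM][DeviceType.CUDA] is MyCudaRMSNormKernel

On a machine whose accelerator maps to DeviceType.CUDA, discover_kernels() would then return this class alongside any other CUDA kernels, and apply_available_kernels(model) would invoke its apply() in turn. Kernels that ship with the package only need to be listed in _ensure_kernels_loaded() so the import side effect fires before discovery.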
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
index d6f032b9..ba51f332 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
@@ -17,7 +17,7 @@ import types
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
+from ..registry import MetaRMSNormKernel
 
 
 def _npu_rms_forward(self, hidden_states):
@@ -38,14 +38,10 @@ def _npu_rms_forward(self, hidden_states):
 class NpuRMSNormKernel(MetaRMSNormKernel):
     """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
 
+    type = KernelType.RMSNORM
     device = DeviceType.NPU
     kernel = _npu_rms_forward
 
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
-        """Register the NPU RMSNorm forward implementation to the global registry."""
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
index 8cb40575..5e877f0a 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
@@ -19,7 +19,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
+from ..registry import MetaRoPEKernel
 
 
 def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -51,13 +51,10 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
 
 
 class NpuRoPEKernel(MetaRoPEKernel):
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_rotary_pos_emb
 
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
@@ -88,12 +85,16 @@ class NpuRoPEKernel(MetaRoPEKernel):
 
 
 class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
+    """Qwen2-VL specific RoPE kernel - not auto-registered.
+
+    This kernel targets a specific model family (Qwen2-VL) and should be applied
+    manually when needed rather than auto-discovered.
+    """
+
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
-
-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    auto_register = False  # Disable auto-registration for this model-specific kernel
 
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
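Since NpuQwen2VLRoPEKernel opts out with auto_register = False, apply_available_kernels() will never pick it up; it has to be wired in explicitly. A sketch of that manual path on an Ascend NPU host (the loader class and checkpoint name below are illustrative assumptions, not taken from this patch):

    # Sketch: manually applying a kernel that is excluded from auto-discovery.
    # Requires torch_npu; the checkpoint and loader are assumptions for illustration.
    from transformers import Qwen2VLForConditionalGeneration

    from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
    from llamafactory.v1.plugins.model_plugins.kernels.rope.npu_rope import NpuQwen2VLRoPEKernel

    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    # apply_kernel() verifies that the kernel's device matches the current accelerator
    # before calling NpuQwen2VLRoPEKernel.apply(model).
    model = apply_kernel(model, NpuQwen2VLRoPEKernel)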
diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
index 2830d8c5..06276c33 100644
--- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
@@ -42,3 +42,23 @@ class TestKernelPlugin(unittest.TestCase):
         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
 
         assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+class Test_Use_V1_Kernels(unittest.TestCase):
+    @patch("torch.accelerator.current_accelerator")
+    def test_use_v1_kernels(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = "npu"
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward