[v1] support automatic discovery of registered kernels. (#9509)

Co-authored-by: frozenleaves <frozen@Mac.local>
浮梦 2025-11-27 01:47:22 +08:00 committed by GitHub
parent 591fc9ed02
commit f17efde693
5 changed files with 228 additions and 36 deletions

View File

@@ -20,7 +20,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+from ..registry import MetaSwiGluKernel


 def _npu_swiglu_forward(self, hidden_state):
@@ -31,25 +31,85 @@ def _npu_swiglu_forward(self, hidden_state):
     )


+def _npu_swiglu_glm4_forward(self, hidden_states):
+    import torch_npu
+
+    up_states = self.gate_up_proj(hidden_states)
+    gate, up_states = up_states.chunk(2, dim=-1)
+    return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
+
+
+def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
+    import torch_npu
+
+    gate_proj = self.gate_proj(hidden_states)
+    if self.activation_sparsity > 0.0:
+        gate_proj = self._gaussian_topk(gate_proj)
+    down_proj = self.down_proj(
+        torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
+    )
+    return down_proj
+
+
 class NpuSwiGluKernel(MetaSwiGluKernel):
+    type = KernelType.SWIGLU
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    # Don't apply the kernel to the following modules
+    expect_modules = frozenset(
+        {
+            "Qwen3VLMoeTextMLP",
+            "Qwen3VLTextMLP",
+            "Qwen3OmniMoeThinkerTextMLP",
+            "Qwen3OmniMoeMLP",
+            "Qwen3OmniMoeTalkerTextMLP",
+            "Qwen3OmniMoeCode2WavMlp",
+            "Qwen3NextMLP",
+            "Qwen3MoeMLP",
+            "Qwen3MLP",
+            "Qwen2MLP",
+            "Qwen2MoeMLP",
+            "Qwen2_5_VLMLP",
+            "Qwen2_5OmniMLP",
+            "Llama4TextMLP",
+            "LlamaMLP",
+            "Glm4MLP",
+            "Glm4MoeMLP",
+            "Glm4vMoeTextMLP",
+            "Gemma3MLP",
+            "Gemma2MLP",
+            "Gemma3nTextMLP",
+            "Phi3MLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+            "SeedOssMLP",
+        }
+    )

     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model

+        # Mapping of specific mlp modules to their corresponding kernel implementations
+        kernel_mapping = {
+            "Glm4MLP": _npu_swiglu_glm4_forward,
+            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
+            "Phi3MLP": _npu_swiglu_glm4_forward,
+            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
+        }
+
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
         for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
-            if re.search(swiglu_pattern, module.__class__.__name__):
+            # Match any module whose class name contains "MLP"
+            if (
+                re.search(swiglu_pattern, module.__class__.__name__)
+                and module.__class__.__name__ in cls.expect_modules
+            ):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+                module.forward = types.MethodType(kernel_func, module)
         return model
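
The kernel replacement above relies on binding a plain function as an instance method so that `self` inside the new forward resolves to the patched module. A minimal, self-contained sketch of that pattern (not part of this commit; the toy module and replacement forward below are illustrative stand-ins for a real transformers MLP and the NPU kernel):

import types

import torch
from torch import nn


class ToyMLP(nn.Module):
    """Illustrative stand-in for an MLP module matched by the kernel."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.gate_proj = nn.Linear(dim, dim)
        self.up_proj = nn.Linear(dim, dim)
        self.down_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


def _replacement_forward(self, x):
    # Hypothetical drop-in forward; a real NPU kernel would call
    # torch_npu.npu_swiglu here instead of the eager ops.
    return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


mlp = ToyMLP()
# Bind the free function as an instance method so `self` resolves to `mlp`,
# then overwrite the original forward -- the same pattern NpuSwiGluKernel.apply uses.
mlp.forward = types.MethodType(_replacement_forward, mlp)
print(mlp(torch.randn(2, 8)).shape)

Because the bound method is set as an instance attribute, it shadows the class-level forward only on the patched module and leaves other instances untouched.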

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional

 from ....extras.types import HFModel
@@ -61,18 +61,67 @@ class KernelRegistry:
 KERNEL_REGISTRY = KernelRegistry()


-class MetaKernel(ABC):
+class AutoRegisterKernelMeta(ABCMeta):
+    """Metaclass that automatically registers kernel classes upon creation.
+
+    This metaclass checks if a newly created class has both `type` and `device`
+    attributes defined. If so, it automatically registers the kernel in the
+    global KERNEL_REGISTRY, eliminating the need for manual registration.
+
+    To disable auto-registration for a specific class, set `auto_register = False`.
+    """
+
+    def __new__(mcs, name, bases, namespace, **kwargs):
+        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
+
+        # Check if auto-registration is disabled
+        auto_register = namespace.get("auto_register", True)
+
+        # Only auto-register if the class has both type and device attributes defined
+        # and they are not None (skip base classes like MetaKernel itself)
+        # and auto_register is True
+        kernel_type = namespace.get("type")
+        device_type = namespace.get("device")
+
+        if auto_register and kernel_type is not None and device_type is not None:
+            # Auto-register this kernel
+            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+        return cls
+
+
+class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
+    """Base class for all kernel implementations.
+
+    Subclasses are automatically registered when they define both `type` and `device`
+    attributes. To disable auto-registration, set `auto_register = False`.
+
+    Attributes:
+        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
+        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
+        kernel: The actual kernel function or implementation.
+        auto_register: Set to False to disable automatic registration (default: True).
+    """
+
     type: Optional[KernelType] = None
     device: Optional[DeviceType] = None
     kernel: Optional[Callable] = None

-    @classmethod
-    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     @abstractmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        """Apply the kernel to the model.
+
+        This method should check if the kernel can be applied (e.g., dependencies
+        are installed, target modules exist) and perform the kernel replacement.
+
+        Args:
+            model: The HuggingFace model to optimize.
+            **kwargs: Additional arguments for kernel application.
+
+        Returns:
+            The optimized model (may be the same object with modifications).
+        """
         raise NotImplementedError
@@ -106,16 +155,75 @@ class MetaMoEKernel(MetaKernel):
         raise NotImplementedError


-def discover_kernels(model: HFModel) -> list[MetaKernel]:
-    """Discover and construct MetaKernel instances for the current model/device.
-
-    This is a placeholder to be implemented: it should inspect the runtime
-    environment (device type, available extensions, model architecture) and
-    return an ordered list of MetaKernel instances to be applied. Each returned
-    MetaKernel must encapsulate its own replacement logic in `apply`.
-    """
-    # TODO: Implement auto discovery logic based on registry and device capabilities.
-    return []
+def _ensure_kernels_loaded() -> None:
+    """Ensure all kernel implementations are imported and registered.
+
+    This function dynamically imports all kernel implementation modules to trigger
+    their auto-registration. Python's module system ensures each module is only
+    executed once (cached in sys.modules), so repeated calls are safe and fast.
+    """
+    # List of kernel module paths to import
+    kernel_modules = [
+        "rms_norm.npu_rms_norm",
+        "rope.npu_rope",
+        "mlp.npu_swiglu",
+        "mlp.npu_fused_moe",
+        # Add new kernel modules here as they are created
+    ]
+
+    # Import each module to trigger kernel registration
+    # Python's import system caches modules, so this is fast on subsequent calls
+    for module_name in kernel_modules:
+        try:
+            __import__(f"{__package__}.{module_name}", fromlist=["*"])
+        except ImportError:
+            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
+            pass
+
+
+def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
+    """Discover and return all kernel classes registered for the current device.
+
+    This function inspects the runtime environment (device type) and returns
+    all MetaKernel classes registered for that device. Each kernel's `apply()`
+    method is responsible for checking if it can actually be applied (e.g.,
+    required dependencies are installed, target modules exist in the model).
+
+    The function automatically discovers all kernels registered in KERNEL_REGISTRY
+    without requiring manual enumeration. On first call, it dynamically imports
+    all kernel implementation modules to trigger their auto-registration.
+
+    Args:
+        model: The HuggingFace model to apply kernels to.
+            TODO: implement the kernel route detection logic by model structure.
+
+    Returns:
+        A list of MetaKernel classes available for the current device.
+    """
+    # Ensure all kernel modules are imported to trigger registration
+    _ensure_kernels_loaded()
+
+    discovered_kernels: list[type[MetaKernel]] = []
+
+    # Detect current device type
+    accelerator = get_available_accelerator()
+    try:
+        device_type = DeviceType(accelerator.type)
+    except ValueError:
+        # Unknown device type, return empty list
+        return discovered_kernels
+
+    # Skip CPU as it typically doesn't have optimized kernels
+    if device_type == DeviceType.CPU:
+        return discovered_kernels
+
+    # Iterate through registry and collect all kernels for current device
+    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+        kernel_cls = devices.get(device_type)
+        if kernel_cls is not None:
+            discovered_kernels.append(kernel_cls)
+
+    return discovered_kernels


 def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         raise ValueError(
             f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
         )
+
+
+def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
+    """Apply all available kernels to the model."""
+    for kernel in discover_kernels(model):
+        model = apply_kernel(model, kernel, **kwargs)
+    return model
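
The net effect of the registry changes is that defining a concrete kernel class is enough to register it; no explicit register_kernel call is needed. A simplified, self-contained sketch of that flow (not part of this commit; the toy enums, plain dict, and class names below stand in for the real KernelType, DeviceType, KernelRegistry, and kernel classes):

from abc import ABC, ABCMeta
from enum import Enum


class KernelType(Enum):
    RMSNORM = "rmsnorm"


class DeviceType(Enum):
    NPU = "npu"


REGISTRY: dict = {}  # {KernelType: {DeviceType: kernel class}}


class AutoRegisterMeta(ABCMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        # Register only concrete kernels that define both attributes and opt in.
        if namespace.get("auto_register", True) and namespace.get("type") and namespace.get("device"):
            REGISTRY.setdefault(namespace["type"], {})[namespace["device"]] = cls
        return cls


class Kernel(ABC, metaclass=AutoRegisterMeta):
    type = None    # base class defines neither attribute, so it is never registered
    device = None


class ToyNpuRMSNormKernel(Kernel):   # registered automatically when the class is created
    type = KernelType.RMSNORM
    device = DeviceType.NPU


class ToyManualKernel(Kernel):       # opts out, mirroring NpuQwen2VLRoPEKernel
    auto_register = False
    type = KernelType.RMSNORM
    device = DeviceType.NPU


assert REGISTRY[KernelType.RMSNORM][DeviceType.NPU] is ToyNpuRMSNormKernel
assert ToyManualKernel not in REGISTRY[KernelType.RMSNORM].values()

Because registration happens in the metaclass's __new__, merely importing a kernel module registers its kernels, which is exactly what _ensure_kernels_loaded relies on before discover_kernels reads the registry.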

View File

@@ -17,7 +17,7 @@ import types
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
+from ..registry import MetaRMSNormKernel


 def _npu_rms_forward(self, hidden_states):
@@ -38,14 +38,10 @@ def _npu_rms_forward(self, hidden_states):
 class NpuRMSNormKernel(MetaRMSNormKernel):
     """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

+    type = KernelType.RMSNORM
     device = DeviceType.NPU
     kernel = _npu_rms_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
-        """Register the NPU RMSNorm forward implementation to the global registry."""
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.

View File

@@ -19,7 +19,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
+from ..registry import MetaRoPEKernel


 def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -51,13 +51,10 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
 class NpuRoPEKernel(MetaRoPEKernel):
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_rotary_pos_emb

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
@@ -88,12 +85,16 @@ class NpuRoPEKernel(MetaRoPEKernel):
 class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
+    """Qwen2-VL specific RoPE kernel - not auto-registered.
+
+    This kernel is for specific models (Qwen2-VL) and should be manually
+    applied when needed rather than auto-discovered.
+    """
+
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
+    auto_register = False  # Disable auto-registration for model-specific kernel

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
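
Since NpuQwen2VLRoPEKernel sets auto_register = False, apply_available_kernels will not pick it up; callers that need it apply it explicitly through apply_kernel. A hedged sketch of that call, assuming the module layout implied by the imports in this commit and an already-loaded Qwen2-VL model on an NPU host (the model-loading line is a placeholder, not runnable as-is):

from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope

model = load_qwen2_vl_model()  # placeholder: any loaded Qwen2-VL model instance

# Opted-out, model-specific kernels are not returned by discover_kernels(),
# so the caller applies them explicitly.
model = apply_kernel(model, npu_rope.NpuQwen2VLRoPEKernel)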

View File

@@ -42,3 +42,23 @@ class TestKernelPlugin(unittest.TestCase):
         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)

         assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+class Test_Use_V1_Kernels(unittest.TestCase):
+    @patch("torch.accelerator.current_accelerator")
+    def test_use_v1_kernels(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = "npu"
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+        assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward