Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-28 19:24:20 +08:00)
[v1] support automatic discovery of registered kernels. (#9509)
Co-authored-by: frozenleaves <frozen@Mac.local>
parent 591fc9ed02 · commit f17efde693
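In short, this change makes kernel subclasses register themselves at class-creation time via a metaclass, and adds `discover_kernels` / `apply_available_kernels` so callers no longer enumerate kernels by hand. A minimal usage sketch (the import path and the tiny checkpoint name follow the test added at the end of this diff; on a CPU-only machine discovery returns an empty list and the model is returned unchanged):

    from transformers import AutoModelForCausalLM

    from llamafactory.v1.plugins.model_plugins.kernels.registry import (
        apply_available_kernels,
        discover_kernels,
    )

    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

    # Kernel classes registered for the current accelerator (empty on CPU).
    print([k.__name__ for k in discover_kernels(model)])

    # Patch every applicable module in place; each kernel's apply() re-checks
    # its own runtime requirements (e.g. torch_npu being importable).
    model = apply_available_kernels(model)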
@@ -20,7 +20,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+from ..registry import MetaSwiGluKernel


 def _npu_swiglu_forward(self, hidden_state):
@@ -31,25 +31,85 @@ def _npu_swiglu_forward(self, hidden_state):
     )


+def _npu_swiglu_glm4_forward(self, hidden_states):
+    import torch_npu
+
+    up_states = self.gate_up_proj(hidden_states)
+    gate, up_states = up_states.chunk(2, dim=-1)
+    return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
+
+
+def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
+    import torch_npu
+
+    gate_proj = self.gate_proj(hidden_states)
+    if self.activation_sparsity > 0.0:
+        gate_proj = self._gaussian_topk(gate_proj)
+    down_proj = self.down_proj(
+        torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
+    )
+    return down_proj
+
+
 class NpuSwiGluKernel(MetaSwiGluKernel):
     type = KernelType.SWIGLU
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    # Only apply the kernel to the following modules (allowlist)
+    expect_modules = frozenset(
+        {
+            "Qwen3VLMoeTextMLP",
+            "Qwen3VLTextMLP",
+            "Qwen3OmniMoeThinkerTextMLP",
+            "Qwen3OmniMoeMLP",
+            "Qwen3OmniMoeTalkerTextMLP",
+            "Qwen3OmniMoeCode2WavMlp",
+            "Qwen3NextMLP",
+            "Qwen3MoeMLP",
+            "Qwen3MLP",
+            "Qwen2MLP",
+            "Qwen2MoeMLP",
+            "Qwen2_5_VLMLP",
+            "Qwen2_5OmniMLP",
+            "Llama4TextMLP",
+            "LlamaMLP",
+            "Glm4MLP",
+            "Glm4MoeMLP",
+            "Glm4vMoeTextMLP",
+            "Gemma3MLP",
+            "Gemma2MLP",
+            "Gemma3nTextMLP",
+            "Phi3MLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+            "SeedOssMLP",
+        }
+    )

     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model

+        # Mapping of specific MLP modules to their corresponding kernel implementations
+        kernel_mapping = {
+            "Glm4MLP": _npu_swiglu_glm4_forward,
+            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
+            "Phi3MLP": _npu_swiglu_glm4_forward,
+            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
+        }
+
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
         for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
-            if re.search(swiglu_pattern, module.__class__.__name__):
+            # Match any module whose class name contains "MLP" and is allowlisted
+            if (
+                re.search(swiglu_pattern, module.__class__.__name__)
+                and module.__class__.__name__ in cls.expect_modules
+            ):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+                module.forward = types.MethodType(kernel_func, module)

         return model
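The patching technique above is plain method binding: `types.MethodType` turns a free function into a method bound to one module instance, so `self` resolves to that module and only the patched instances change behavior. A self-contained illustration with toy classes (no NPU required):

    import types


    class ToyMLP:
        def forward(self, x):
            return x + 1


    def fused_forward(self, x):
        # Stand-in for an NPU-fused implementation.
        return x * 2


    mlp = ToyMLP()
    mlp.forward = types.MethodType(fused_forward, mlp)  # bind to this instance only

    assert mlp.forward(3) == 6
    assert ToyMLP().forward(3) == 4  # other instances keep the original forward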
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional

 from ....extras.types import HFModel
@@ -61,18 +61,67 @@ class KernelRegistry:
 KERNEL_REGISTRY = KernelRegistry()


-class MetaKernel(ABC):
+class AutoRegisterKernelMeta(ABCMeta):
+    """Metaclass that automatically registers kernel classes upon creation.
+
+    This metaclass checks if a newly created class has both `type` and `device`
+    attributes defined. If so, it automatically registers the kernel in the
+    global KERNEL_REGISTRY, eliminating the need for manual registration.
+
+    To disable auto-registration for a specific class, set `auto_register = False`.
+    """
+
+    def __new__(mcs, name, bases, namespace, **kwargs):
+        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
+
+        # Check if auto-registration is disabled
+        auto_register = namespace.get("auto_register", True)
+
+        # Only auto-register if the class defines both type and device and they
+        # are not None (this skips base classes like MetaKernel itself)
+        kernel_type = namespace.get("type")
+        device_type = namespace.get("device")
+
+        if auto_register and kernel_type is not None and device_type is not None:
+            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+        return cls
+
+
+class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
+    """Base class for all kernel implementations.
+
+    Subclasses are automatically registered when they define both `type` and `device`
+    attributes. To disable auto-registration, set `auto_register = False`.
+
+    Attributes:
+        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
+        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
+        kernel: The actual kernel function or implementation.
+        auto_register: Set to False to disable automatic registration (default: True).
+    """
+
     type: Optional[KernelType] = None
     device: Optional[DeviceType] = None
     kernel: Optional[Callable] = None

     @classmethod
     def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)

     @classmethod
     @abstractmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         """Apply the kernel to the model.

         This method should check if the kernel can be applied (e.g., dependencies
         are installed, target modules exist) and perform the kernel replacement.

         Args:
             model: The HuggingFace model to optimize.
             **kwargs: Additional arguments for kernel application.

         Returns:
             The optimized model (may be the same object with modifications).
         """
         raise NotImplementedError
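The auto-registration pattern is standard metaclass bookkeeping: class creation runs the metaclass's `__new__`, which inspects the class namespace and files the class in a registry. A stripped-down sketch of the same mechanism, using a toy registry and string-valued type/device in place of the real enums:

    from abc import ABC, ABCMeta

    REGISTRY: dict[tuple[str, str], type] = {}


    class AutoRegisterMeta(ABCMeta):
        def __new__(mcs, name, bases, namespace, **kwargs):
            cls = super().__new__(mcs, name, bases, namespace, **kwargs)
            # Register only concrete classes that opted in and define both keys.
            if namespace.get("auto_register", True) and namespace.get("type") and namespace.get("device"):
                REGISTRY[(namespace["type"], namespace["device"])] = cls
            return cls


    class Kernel(ABC, metaclass=AutoRegisterMeta):
        type = None    # base class stays unregistered
        device = None


    class NpuSwiGlu(Kernel):  # registered at class-creation time
        type, device = "swiglu", "npu"


    class PrivateKernel(Kernel):  # opts out, like NpuQwen2VLRoPEKernel below
        auto_register = False
        type, device = "rope", "npu"


    assert REGISTRY[("swiglu", "npu")] is NpuSwiGlu
    assert ("rope", "npu") not in REGISTRY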
@@ -106,16 +155,75 @@ class MetaMoEKernel(MetaKernel):
         raise NotImplementedError


-def discover_kernels(model: HFModel) -> list[MetaKernel]:
-    """Discover and construct MetaKernel instances for the current model/device.
+def _ensure_kernels_loaded() -> None:
+    """Ensure all kernel implementations are imported and registered.

-    This is a placeholder to be implemented: it should inspect the runtime
-    environment (device type, available extensions, model architecture) and
-    return an ordered list of MetaKernel instances to be applied. Each returned
-    MetaKernel must encapsulate its own replacement logic in `apply`.
+    This function dynamically imports all kernel implementation modules to trigger
+    their auto-registration. Python's module system ensures each module is only
+    executed once (cached in sys.modules), so repeated calls are safe and fast.
     """
-    # TODO: Implement auto discovery logic based on registry and device capabilities.
-    return []
+    # List of kernel module paths to import
+    kernel_modules = [
+        "rms_norm.npu_rms_norm",
+        "rope.npu_rope",
+        "mlp.npu_swiglu",
+        "mlp.npu_fused_moe",
+        # Add new kernel modules here as they are created
+    ]
+
+    # Import each module to trigger kernel registration
+    # Python's import system caches modules, so this is fast on subsequent calls
+    for module_name in kernel_modules:
+        try:
+            __import__(f"{__package__}.{module_name}", fromlist=["*"])
+        except ImportError:
+            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
+            pass
+
+
+def discover_kernels(model: Optional[HFModel] = None) -> list[type[MetaKernel]]:
+    """Discover and return all kernel classes registered for the current device.
+
+    This function inspects the runtime environment (device type) and returns
+    all MetaKernel classes registered for that device. Each kernel's `apply()`
+    method is responsible for checking whether it can actually be applied (e.g.,
+    required dependencies are installed, target modules exist in the model).
+
+    The function automatically discovers all kernels registered in KERNEL_REGISTRY
+    without requiring manual enumeration. On first call, it dynamically imports
+    all kernel implementation modules to trigger their auto-registration.
+
+    Args:
+        model: The HuggingFace model to apply kernels to.
+            TODO: implement the kernel route detection logic by model structure.
+
+    Returns:
+        A list of MetaKernel classes available for the current device.
+    """
+    # Ensure all kernel modules are imported to trigger registration
+    _ensure_kernels_loaded()
+
+    discovered_kernels: list[type[MetaKernel]] = []
+
+    # Detect current device type
+    accelerator = get_available_accelerator()
+    try:
+        device_type = DeviceType(accelerator.type)
+    except ValueError:
+        # Unknown device type, return empty list
+        return discovered_kernels
+
+    # Skip CPU as it typically doesn't have optimized kernels
+    if device_type == DeviceType.CPU:
+        return discovered_kernels
+
+    # Iterate through the registry and collect all kernels for the current device
+    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+        kernel_cls = devices.get(device_type)
+        if kernel_cls is not None:
+            discovered_kernels.append(kernel_cls)
+
+    return discovered_kernels


 def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
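`_ensure_kernels_loaded` leans on a guarantee of Python's import machinery: a module's top-level code (and therefore the class definitions that auto-register kernels) runs once; subsequent imports are `sys.modules` lookups. A quick demonstration with a stdlib module:

    import importlib
    import sys

    first = importlib.import_module("json")   # executes json's top level (if not already loaded)
    second = importlib.import_module("json")  # cache hit: no re-execution

    assert first is second
    assert sys.modules["json"] is first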
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         raise ValueError(
             f"{kernel} must be a MetaKernel subclass, or the kernel doesn't match the device type; got {kernel.device} and {get_available_accelerator().type} instead."
         )
+
+
+def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
+    """Apply all available kernels to the model."""
+    for kernel in discover_kernels(model):
+        model = apply_kernel(model, kernel, **kwargs)
+    return model
@@ -17,7 +17,7 @@ import types
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
+from ..registry import MetaRMSNormKernel


 def _npu_rms_forward(self, hidden_states):
@@ -38,14 +38,10 @@ def _npu_rms_forward(self, hidden_states):
 class NpuRMSNormKernel(MetaRMSNormKernel):
     """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

     type = KernelType.RMSNORM
     device = DeviceType.NPU
     kernel = _npu_rms_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
-        """Register the NPU RMSNorm forward implementation to the global registry."""
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
@@ -19,7 +19,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
+from ..registry import MetaRoPEKernel


 def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -51,13 +51,10 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):


 class NpuRoPEKernel(MetaRoPEKernel):
     type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_rotary_pos_emb

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
@@ -88,12 +85,16 @@ class NpuRoPEKernel(MetaRoPEKernel):


 class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
+    """Qwen2-VL specific RoPE kernel - not auto-registered.
+
+    This kernel is for specific models (Qwen2-VL) and should be manually
+    applied when needed rather than auto-discovered.
+    """
+
     type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    auto_register = False  # Disable auto-registration for model-specific kernel

     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
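Because `auto_register = False` keeps this kernel out of KERNEL_REGISTRY, discovery will never return it; a caller targeting Qwen2-VL would invoke it directly. A sketch, assuming an NPU environment and an already-loaded Qwen2-VL model (the module path is inferred from the `rope.npu_rope` entry in `_ensure_kernels_loaded` and the registry import path used in the tests):

    from llamafactory.v1.plugins.model_plugins.kernels.rope.npu_rope import NpuQwen2VLRoPEKernel

    # apply() re-checks its own runtime requirements before patching.
    model = NpuQwen2VLRoPEKernel.apply(model)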
@@ -42,3 +42,23 @@ class TestKernelPlugin(unittest.TestCase):

         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
         assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+class TestUseV1Kernels(unittest.TestCase):
+    @patch("torch.accelerator.current_accelerator")
+    def test_use_v1_kernels(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = "npu"
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
+        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward