Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-29 03:34:18 +08:00)

[v1] support automatic discovery of registered kernels. (#9509)

Co-authored-by: frozenleaves <frozen@Mac.local>

parent 591fc9ed02
commit f17efde693
@@ -20,7 +20,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel
+from ..registry import MetaSwiGluKernel


 def _npu_swiglu_forward(self, hidden_state):
@@ -31,25 +31,85 @@ def _npu_swiglu_forward(self, hidden_state):
     )


+def _npu_swiglu_glm4_forward(self, hidden_states):
+    import torch_npu
+
+    up_states = self.gate_up_proj(hidden_states)
+    gate, up_states = up_states.chunk(2, dim=-1)
+    return self.down_proj(torch_npu.npu_swiglu(torch.cat((gate, up_states), dim=-1), dim=-1))
+
+
+def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
+    import torch_npu
+
+    gate_proj = self.gate_proj(hidden_states)
+    if self.activation_sparsity > 0.0:
+        gate_proj = self._gaussian_topk(gate_proj)
+    down_proj = self.down_proj(
+        torch_npu.npu_swiglu(torch.cat((gate_proj, self.up_proj(hidden_states)), dim=-1), dim=-1)
+    )
+    return down_proj
+
+
 class NpuSwiGluKernel(MetaSwiGluKernel):
+    type = KernelType.SWIGLU
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+    # Don't apply the kernel to the following modules
+    expect_modules = frozenset(
+        {
+            "Qwen3VLMoeTextMLP",
+            "Qwen3VLTextMLP",
+            "Qwen3OmniMoeThinkerTextMLP",
+            "Qwen3OmniMoeMLP",
+            "Qwen3OmniMoeTalkerTextMLP",
+            "Qwen3OmniMoeCode2WavMlp",
+            "Qwen3NextMLP",
+            "Qwen3MoeMLP",
+            "Qwen3MLP",
+            "Qwen2MLP",
+            "Qwen2MoeMLP",
+            "Qwen2_5_VLMLP",
+            "Qwen2_5OmniMLP",
+            "Llama4TextMLP",
+            "LlamaMLP",
+            "Glm4MLP",
+            "Glm4MoeMLP",
+            "Glm4vMoeTextMLP",
+            "Gemma3MLP",
+            "Gemma2MLP",
+            "Gemma3nTextMLP",
+            "Phi3MLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+            "SeedOssMLP",
+        }
+    )

     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model

+        # Mapping of specific mlp modules to their corresponding kernel implementations
+        kernel_mapping = {
+            "Glm4MLP": _npu_swiglu_glm4_forward,
+            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
+            "Phi3MLP": _npu_swiglu_glm4_forward,
+            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
+        }
+
         swiglu_pattern = re.compile("MLP", re.IGNORECASE)
         for name, module in model.named_modules():
-            # Match any module whose class name contains "RMSNorm"
-            if re.search(swiglu_pattern, module.__class__.__name__):
+            # Match any module whose class name contains "MLP"
+            if (
+                re.search(swiglu_pattern, module.__class__.__name__)
+                and module.__class__.__name__ in cls.expect_modules
+            ):
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
-                module.forward = types.MethodType(cls.kernel, module)
+                kernel_func = kernel_mapping.get(module.__class__.__name__, _npu_swiglu_forward)
+                module.forward = types.MethodType(kernel_func, module)

         return model
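The rewritten NpuSwiGluKernel.apply() above picks a forward implementation per MLP class from kernel_mapping, falls back to the generic _npu_swiglu_forward, and binds it with types.MethodType. Below is a minimal, self-contained sketch of that dispatch pattern only; the class names are illustrative stand-ins, not the real transformers modules, and the returned strings stand in for the torch_npu calls.

import types


# Illustrative stand-ins for transformers MLP modules; not the real classes.
class LlamaLikeMLP:
    def forward(self, x):
        return f"{self.__class__.__name__}: eager({x})"


class Glm4LikeMLP(LlamaLikeMLP):
    pass


def _generic_swiglu(self, x):
    # the real kernel calls torch_npu.npu_swiglu on the gate/up projections
    return f"{self.__class__.__name__}: npu_swiglu({x})"


def _glm4_swiglu(self, x):
    # fused gate_up_proj variant, mirroring _npu_swiglu_glm4_forward
    return f"{self.__class__.__name__}: npu_swiglu_glm4({x})"


# Per-class overrides with a generic fallback, like kernel_mapping.get(name, default).
kernel_mapping = {"Glm4LikeMLP": _glm4_swiglu}

for module in (LlamaLikeMLP(), Glm4LikeMLP()):
    kernel_func = kernel_mapping.get(module.__class__.__name__, _generic_swiglu)
    # Bind as an instance method so `self` resolves to the module being patched.
    module.forward = types.MethodType(kernel_func, module)
    print(module.forward("x"))

Binding with types.MethodType, rather than assigning the bare function, keeps `self` pointing at the patched module instance, so the replacement forward can still reach its gate/up/down projections.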
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional

 from ....extras.types import HFModel
@@ -61,18 +61,67 @@ class KernelRegistry:
 KERNEL_REGISTRY = KernelRegistry()


-class MetaKernel(ABC):
+class AutoRegisterKernelMeta(ABCMeta):
+    """Metaclass that automatically registers kernel classes upon creation.
+
+    This metaclass checks if a newly created class has both `type` and `device`
+    attributes defined. If so, it automatically registers the kernel in the
+    global KERNEL_REGISTRY, eliminating the need for manual registration.
+
+    To disable auto-registration for a specific class, set `auto_register = False`.
+    """
+
+    def __new__(mcs, name, bases, namespace, **kwargs):
+        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
+
+        # Check if auto-registration is disabled
+        auto_register = namespace.get("auto_register", True)
+
+        # Only auto-register if the class has both type and device attributes defined
+        # and they are not None (skip base classes like MetaKernel itself)
+        # and auto_register is True
+        kernel_type = namespace.get("type")
+        device_type = namespace.get("device")
+
+        if auto_register and kernel_type is not None and device_type is not None:
+            # Auto-register this kernel
+            KERNEL_REGISTRY.register(kernel_type, device_type, cls)
+
+        return cls
+
+
+class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
+    """Base class for all kernel implementations.
+
+    Subclasses are automatically registered when they define both `type` and `device`
+    attributes. To disable auto-registration, set `auto_register = False`.
+
+    Attributes:
+        type: The kernel type (e.g., KernelType.RMSNORM). Must be set in subclasses.
+        device: The device type (e.g., DeviceType.NPU). Must be set in subclasses.
+        kernel: The actual kernel function or implementation.
+        auto_register: Set to False to disable automatic registration (default: True).
+    """
+
     type: Optional[KernelType] = None
     device: Optional[DeviceType] = None
     kernel: Optional[Callable] = None

-    @classmethod
-    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     @abstractmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
+        """Apply the kernel to the model.
+
+        This method should check if the kernel can be applied (e.g., dependencies
+        are installed, target modules exist) and perform the kernel replacement.
+
+        Args:
+            model: The HuggingFace model to optimize.
+            **kwargs: Additional arguments for kernel application.
+
+        Returns:
+            The optimized model (may be the same object with modifications).
+        """
         raise NotImplementedError

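The heart of this commit is AutoRegisterKernelMeta above: any concrete MetaKernel subclass that declares both `type` and `device` in its own body lands in KERNEL_REGISTRY at class-creation time, and `auto_register = False` opts a class out. Here is a stripped-down sketch of the same pattern, assuming a plain dict in place of KERNEL_REGISTRY and illustrative class names.

from abc import ABC, ABCMeta

REGISTRY: dict = {}  # stand-in for KERNEL_REGISTRY


class AutoRegisterMeta(ABCMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        # Register only classes that declare both identifiers in their own body
        # and have not opted out via auto_register = False.
        if namespace.get("auto_register", True):
            kernel_type = namespace.get("type")
            device = namespace.get("device")
            if kernel_type is not None and device is not None:
                REGISTRY[(kernel_type, device)] = cls
        return cls


class Kernel(ABC, metaclass=AutoRegisterMeta):
    type = None
    device = None


class SwiGluLikeKernel(Kernel):  # registered automatically at class-creation time
    type, device = "swiglu", "npu"


class ManualOnlyKernel(Kernel):  # opts out; must be applied manually
    type, device = "rope", "npu"
    auto_register = False


print(REGISTRY)  # {('swiglu', 'npu'): <class '...SwiGluLikeKernel'>}

Because __new__ reads the class namespace rather than resolved attributes, base classes such as MetaKernel itself (where both attributes stay None) are never registered, and a subclass must restate `type` and `device` in its own body to be picked up.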
@@ -106,16 +155,75 @@ class MetaMoEKernel(MetaKernel):
         raise NotImplementedError


-def discover_kernels(model: HFModel) -> list[MetaKernel]:
-    """Discover and construct MetaKernel instances for the current model/device.
-
-    This is a placeholder to be implemented: it should inspect the runtime
-    environment (device type, available extensions, model architecture) and
-    return an ordered list of MetaKernel instances to be applied. Each returned
-    MetaKernel must encapsulate its own replacement logic in `apply`.
-    """
-    # TODO: Implement auto discovery logic based on registry and device capabilities.
-    return []
+def _ensure_kernels_loaded() -> None:
+    """Ensure all kernel implementations are imported and registered.
+
+    This function dynamically imports all kernel implementation modules to trigger
+    their auto-registration. Python's module system ensures each module is only
+    executed once (cached in sys.modules), so repeated calls are safe and fast.
+    """
+    # List of kernel module paths to import
+    kernel_modules = [
+        "rms_norm.npu_rms_norm",
+        "rope.npu_rope",
+        "mlp.npu_swiglu",
+        "mlp.npu_fused_moe",
+        # Add new kernel modules here as they are created
+    ]
+
+    # Import each module to trigger kernel registration
+    # Python's import system caches modules, so this is fast on subsequent calls
+    for module_name in kernel_modules:
+        try:
+            __import__(f"{__package__}.{module_name}", fromlist=["*"])
+        except ImportError:
+            # Silently ignore import errors (e.g., missing dependencies like torch_npu)
+            pass
+
+
+def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
+    """Discover and return all kernel classes registered for the current device.
+
+    This function inspects the runtime environment (device type) and returns
+    all MetaKernel classes registered for that device. Each kernel's `apply()`
+    method is responsible for checking if it can actually be applied (e.g.,
+    required dependencies are installed, target modules exist in the model).
+
+    The function automatically discovers all kernels registered in KERNEL_REGISTRY
+    without requiring manual enumeration. On first call, it dynamically imports
+    all kernel implementation modules to trigger their auto-registration.
+
+    Args:
+        model: The HuggingFace model to apply kernels to.
+            TODO: implement the kernel route detection logic by model structure.
+
+    Returns:
+        A list of MetaKernel classes available for the current device.
+    """
+    # Ensure all kernel modules are imported to trigger registration
+    _ensure_kernels_loaded()
+
+    discovered_kernels: list[type[MetaKernel]] = []
+
+    # Detect current device type
+    accelerator = get_available_accelerator()
+    try:
+        device_type = DeviceType(accelerator.type)
+    except ValueError:
+        # Unknown device type, return empty list
+        return discovered_kernels
+
+    # Skip CPU as it typically doesn't have optimized kernels
+    if device_type == DeviceType.CPU:
+        return discovered_kernels
+
+    # Iterate through registry and collect all kernels for current device
+    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
+        kernel_cls = devices.get(device_type)
+        if kernel_cls is not None:
+            discovered_kernels.append(kernel_cls)

+    return discovered_kernels


 def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
@@ -136,3 +244,10 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         raise ValueError(
             f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
         )
+
+
+def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
+    """Apply all available kernels to the model."""
+    for kernel in discover_kernels(model):
+        model = apply_kernel(model, kernel, **kwargs)
+    return model
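With registration handled by the metaclass, discover_kernels() plus apply_available_kernels() replace the old per-kernel register_kernel() calls with a single entry point. A hedged usage sketch, assuming an Ascend NPU environment with torch_npu installed; the import path and the tiny checkpoint follow the test added at the end of this diff.

from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.registry import (
    apply_available_kernels,
    discover_kernels,
)

# Tiny checkpoint used by the test in this commit; any HF causal LM works here.
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

# discover_kernels() imports the kernel modules, reads KERNEL_REGISTRY for the
# current accelerator, and apply_available_kernels() applies each hit in turn.
print([kernel.__name__ for kernel in discover_kernels(model)])
model = apply_available_kernels(model)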
@@ -17,7 +17,7 @@ import types
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel
+from ..registry import MetaRMSNormKernel


 def _npu_rms_forward(self, hidden_states):
@@ -38,14 +38,10 @@ def _npu_rms_forward(self, hidden_states):
 class NpuRMSNormKernel(MetaRMSNormKernel):
     """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

+    type = KernelType.RMSNORM
     device = DeviceType.NPU
     kernel = _npu_rms_forward

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
-        """Register the NPU RMSNorm forward implementation to the global registry."""
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> HFModel:
         """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
@@ -19,7 +19,7 @@ import torch
 from .....extras.types import HFModel
 from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
-from ..registry import KERNEL_REGISTRY, MetaRoPEKernel
+from ..registry import MetaRoPEKernel


 def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -51,13 +51,10 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):


 class NpuRoPEKernel(MetaRoPEKernel):
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_rotary_pos_emb

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
@@ -88,12 +85,16 @@ class NpuRoPEKernel(MetaRoPEKernel):


 class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
+    """Qwen2-VL specific RoPE kernel - not auto-registered.
+
+    This kernel is for specific models (Qwen2-VL) and should be manually
+    applied when needed rather than auto-discovered.
+    """
+
+    type = KernelType.ROPE
     device = DeviceType.NPU
     kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl
+    auto_register = False  # Disable auto-registration for model-specific kernel

-    @classmethod
-    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
-        KERNEL_REGISTRY.register(kernel_type, device_type, cls)
-
     @classmethod
     def apply(cls, model, **kwargs) -> "HFModel":
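NpuQwen2VLRoPEKernel sets auto_register = False, so discover_kernels() will skip it and it has to be applied explicitly. A hedged sketch of doing that: the rope module path is inferred from the kernel_modules list in registry.py and the checkpoint name is illustrative, and apply_kernel() will raise unless the current accelerator is an NPU, since the kernel declares device = DeviceType.NPU.

from transformers import Qwen2VLForConditionalGeneration

from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rope.npu_rope import NpuQwen2VLRoPEKernel  # path assumed

# Illustrative Qwen2-VL checkpoint; any model using the multimodal RoPE would do.
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Explicit application, since this kernel is excluded from auto-discovery.
model = apply_kernel(model, NpuQwen2VLRoPEKernel)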
@@ -42,3 +42,23 @@ class TestKernelPlugin(unittest.TestCase):

         model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
         assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+class Test_Use_V1_Kernels(unittest.TestCase):
+    @patch("torch.accelerator.current_accelerator")
+    def test_use_v1_kernels(self, mock_get_accelerator):
+        mock_device = MagicMock()
+        mock_device.type = "npu"
+        mock_get_accelerator.return_value = mock_device
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+
+        model = apply_available_kernels(model)
+
+        assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
+        assert model.model.layers[0].mlp.forward is not original_swiglu_forward