mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

[v1] kernel plugin (#9274)

Co-authored-by: frozenleaves <frozen@Mac.local>

This commit is contained in:
parent d9d67ba62d
commit 2c6aded5d4
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
+from typing import TYPE_CHECKING, Literal, TypedDict, Union
+
+from typing_extensions import NotRequired
 
 
 if TYPE_CHECKING:
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 
-from typing import Callable, NotRequired, TypedDict
+from typing import Callable, TypedDict
+
+from typing_extensions import NotRequired
 
 
 from ...extras.types import Sample, SFTSample
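Both hunks make the same portability fix: `NotRequired` only joined the stdlib `typing` module in Python 3.11, so importing it from `typing_extensions` keeps the type hints working on older interpreters. A minimal sketch (the field names are illustrative, not from this commit):

    from typing_extensions import NotRequired, TypedDict


    class ExampleSample(TypedDict):
        text: str
        label: NotRequired[int]  # optional key; resolves on Python < 3.11 via typing_extensions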
src/llamafactory/v1/plugins/model_plugins/kernels/constants.py (new file)
@@ -0,0 +1,30 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class KernelType(str, Enum):
    RMSNORM = "rmsnorm"
    SWIGLU = "swiglu"
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"


class DeviceType(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"
    XPU = "xpu"
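Both enums subclass `str`, so members compare equal to plain device strings; this is what later lets `apply_kernel` compare `kernel.device` directly against `torch.accelerator.current_accelerator().type`. A small sketch, assuming the package path implied by the registry file below:

    from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType

    assert DeviceType.NPU == "npu"        # str mixin: the enum member equals the raw string
    assert DeviceType.NPU.value == "npu"  # equivalent explicit form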
@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py (new file)
@@ -0,0 +1,59 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaSwiGluKernel


def _npu_swiglu_forward(self, hidden_state):
    import torch_npu

    return self.down_proj(
        torch_npu.npu_swiglu(
            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
        )
    )


class NpuSwiGluKernel(MetaSwiGluKernel):
    device = DeviceType.NPU
    kernel = _npu_swiglu_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.SWIGLU, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        if not is_torch_npu_available():
            return model

        swiglu_pattern = re.compile("MLP", re.IGNORECASE)
        for name, module in model.named_modules():
            # Match any module whose class name contains "MLP"
            if re.search(swiglu_pattern, module.__class__.__name__):
                # Bind the function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
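For orientation, a plain-PyTorch sketch of the unfused computation the fused call replaces, assuming `torch_npu.npu_swiglu(x, dim=-1)` splits `x` in half along `dim` and returns `silu(first_half) * second_half` (the standard SwiGLU used by Llama-style MLPs):

    import torch.nn.functional as F


    def _reference_swiglu_forward(self, hidden_state):
        # Assumed unfused equivalent: silu(gate) * up, then the down projection.
        return self.down_proj(F.silu(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))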
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType


class KernelRegistry:
    _instance: Optional["KernelRegistry"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True

    def register(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
        kernel_impl: Optional[Callable[..., Any]],
    ) -> None:
        """Register a kernel implementation.

        Args:
            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
            kernel_impl: the actual kernel function or class.
        """
        if kernel_type not in self._registry:
            self._registry[kernel_type] = {}

        if device_type in self._registry[kernel_type]:
            print(f"Warning: Overwriting kernel for {kernel_type.name} on {device_type.name}.")

        self._registry[kernel_type][device_type] = kernel_impl
        print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
        return self._registry.get(kernel_type, {}).get(device_type)


KERNEL_REGISTRY = KernelRegistry()


class MetaKernel(ABC):
    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable] = None

    @classmethod
    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaFlashAttentionKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRMSNormKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaSwiGluKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRoPEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaMoEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


def discover_kernels(model: HFModel) -> list[MetaKernel]:
    """Discover and construct MetaKernel instances for the current model/device.

    This is a placeholder to be implemented: it should inspect the runtime
    environment (device type, available extensions, model architecture) and
    return an ordered list of MetaKernel instances to be applied. Each returned
    MetaKernel must encapsulate its own replacement logic in `apply`.
    """
    # TODO: Implement auto discovery logic based on the registry and device capabilities.
    return []


def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> HFModel:
    """Call the MetaKernel's `apply` to perform the replacement.

    The replacement logic is maintained inside each kernel; the only
    requirement is that `apply` returns the replaced model.

    Example:
        from transformers import AutoModelForCausalLM
        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
        return kernel.apply(model, **kwargs)

    raise ValueError(
        f"{kernel} must be a MetaKernel subclass whose device matches the current accelerator; "
        f"got {kernel.device} and {get_available_accelerator().type} instead."
    )
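A usage sketch of the registry round trip, assuming the kernel module paths used by the test file below:

    from llamafactory.v1.plugins.model_plugins.kernels.constants import DeviceType, KernelType
    from llamafactory.v1.plugins.model_plugins.kernels.mlp.npu_swiglu import NpuSwiGluKernel
    from llamafactory.v1.plugins.model_plugins.kernels.registry import KERNEL_REGISTRY

    NpuSwiGluKernel.register_kernel()  # defaults to (KernelType.SWIGLU, DeviceType.NPU)
    kernel_cls = KERNEL_REGISTRY.get_kernel(KernelType.SWIGLU, DeviceType.NPU)
    assert kernel_cls is NpuSwiGluKernel  # the class itself is stored as the implementation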
src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py (new file)
@@ -0,0 +1,73 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRMSNormKernel


def _npu_rms_forward(self, hidden_states):
    """NPU forward implementation for RMSNorm.

    Args:
        self: RMSNorm module instance with `weight` and `variance_epsilon`.
        hidden_states: Input hidden states tensor, same shape as the baseline.

    Returns:
        Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    import torch_npu

    return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]


class NpuRMSNormKernel(MetaRMSNormKernel):
    """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

    device = DeviceType.NPU
    kernel = _npu_rms_forward

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.RMSNORM, device_type=DeviceType.NPU):
        """Register the NPU RMSNorm forward implementation to the global registry."""
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        """Iterate the model and apply the NPU-optimized forward to matched RMSNorm modules.

        Key points:
        - Match modules whose class name contains "RMSNorm" (case-insensitive).
        - Bind `_npu_rms_forward` as an instance method via `types.MethodType` to
          replace the original `forward`.
        - Do not modify weights, hyperparameters, or module structure, to ensure
          numerical behavior and interface consistency.
        """
        if not is_torch_npu_available():
            return model

        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)

        for name, module in model.named_modules():
            # Match any module whose class name contains "RMSNorm"
            if re.search(rms_norm_pattern, module.__class__.__name__):
                # Bind the function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(cls.kernel, module)

        return model
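For reference, a sketch of the unfused computation the fused call replaces, matching the Transformers RMSNorm baseline (modulo the float32 upcasting the baseline performs):

    import torch


    def _reference_rms_forward(self, hidden_states):
        # Assumed equivalent of torch_npu.npu_rms_norm(hidden_states, self.weight,
        # epsilon=self.variance_epsilon)[0]: scale by the reciprocal RMS, then by the weight.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * (hidden_states * torch.rsqrt(variance + self.variance_epsilon))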
src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py (new file)
@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import KERNEL_REGISTRY, MetaRoPEKernel


def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    import torch_npu

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL)."""
    import torch_npu

    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
    return q_embed, k_embed


class NpuRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_rotary_pos_emb

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This method iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.
        """
        if not is_torch_npu_available():
            return model

        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        if getattr(target_module, "apply_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_rotary_pos_emb", cls.kernel)
                            _modules.add(module_name)
                except Exception:
                    pass
        return model


class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
    device = DeviceType.NPU
    kernel = _apply_multimodal_rotary_pos_emb_qwen25_vl

    @classmethod
    def register_kernel(cls, kernel_type=KernelType.ROPE, device_type=DeviceType.NPU):
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        """Apply multimodal RoPE acceleration by monkey-patching `apply_multimodal_rotary_pos_emb`.

        This method iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_multimodal_rotary_pos_emb` function in that module's namespace
        with the NPU-accelerated version from this file.
        """
        _modules = set()
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue
                try:
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if getattr(target_module, "apply_multimodal_rotary_pos_emb") is not cls.kernel:
                            setattr(target_module, "apply_multimodal_rotary_pos_emb", cls.kernel)
                            _modules.add(module_name)
                except Exception:
                    pass
        return model
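For orientation, a plain-PyTorch sketch of what the fused call computes, assuming `torch_npu.npu_rotary_mul` implements the standard rotate-half convention the Transformers baseline uses:

    import torch


    def _rotate_half(x):
        # Rotates half the hidden dims of the input, as in the HF baseline.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def _reference_rotary_mul(q, cos, sin):
        # Assumed unfused equivalent of torch_npu.npu_rotary_mul(q, cos, sin).
        return q * cos + _rotate_half(q) * sin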
src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py (new file)
@@ -0,0 +1,47 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache

import torch


def get_available_accelerator():
    """Get the available accelerator in the current environment.

    Note: this API requires torch>=2.7.0; torch 2.6 or lower raises an
    AttributeError or RuntimeError.
    """
    accelerator = torch.accelerator.current_accelerator()
    if accelerator is None:
        return torch.device("cpu")
    return accelerator


@lru_cache
def is_torch_npu_available():
    return get_available_accelerator().type == "npu"


@lru_cache
def is_torch_cuda_available():
    return get_available_accelerator().type == "cuda"


@lru_cache
def is_torch_xpu_available():
    return get_available_accelerator().type == "xpu"


@lru_cache
def is_torch_mps_available():
    return get_available_accelerator().type == "mps"
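A quick sanity check of these helpers (module path inferred from the relative imports above; requires torch>=2.7.0):

    from llamafactory.v1.plugins.trainer_plugins.distributed.accelerate import (
        get_available_accelerator,
        is_torch_npu_available,
    )

    print(get_available_accelerator())  # e.g. device(type='npu') on Ascend, device(type='cpu') otherwise
    print(is_torch_npu_available())     # True only when the detected accelerator type is "npu"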
tests_v1/plugins/model_plugins/test_kernel_plugin.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import MagicMock, patch

from transformers import AutoModelForCausalLM


class TestKernelPlugin(unittest.TestCase):
    @patch("torch.accelerator.current_accelerator")
    def test_apply_kernel(self, mock_get_accelerator):
        mock_device = MagicMock()
        mock_device.type = "npu"
        mock_get_accelerator.return_value = mock_device

        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
        original_swiglu_forward = model.model.layers[0].mlp.forward

        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
        from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope

        model = apply_kernel(model, npu_rope.NpuRoPEKernel)

        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
        assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward

        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
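Because the test patches `torch.accelerator.current_accelerator` to report an `npu` device, the device check in `apply_kernel` passes without real Ascend hardware; the lazy `import torch_npu` inside each kernel only runs when a patched forward is actually called. A sketch of running it locally:

    pytest tests_v1/plugins/model_plugins/test_kernel_plugin.py -q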