[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -21,10 +21,3 @@ class KernelType(str, Enum):
FLASH_ATTENTION = "flash_attention"
ROPE = "rope"
MOE = "moe"
class DeviceType(str, Enum):
    """String-valued enumeration of the device type identifiers."""

    CPU = "cpu"
    CUDA = "cuda"
    NPU = "npu"
    XPU = "xpu"

View File

@@ -18,10 +18,10 @@ import torch
import torch.nn.functional as F
import torch_npu
from .....accelerator.helper import is_torch_npu_available
from .....extras.packages import is_transformers_version_greater_than
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.packages import is_transformers_version_greater_than
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaMoEKernel

View File

@@ -17,9 +17,9 @@ import types
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaSwiGluKernel

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from ....accelerator.helper import get_current_accelerator
from ....extras.types import HFModel
from .constants import DeviceType, KernelType
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
from .constants import KernelType
class KernelRegistry:
@@ -27,11 +27,13 @@ class KernelRegistry:
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
    """Return the instance cached on the class, creating it on first use."""
    existing = cls._instance
    if existing is not None:
        return existing

    # First construction: cache the new object on the class so later calls reuse it.
    cls._instance = super().__new__(cls)
    return cls._instance
def __init__(self) -> None:
    """Create the empty kernel registry unless this instance is already initialized."""
    if not self._initialized:
        # Maps each kernel type to its per-device implementations.
        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
# Iterate through registry and collect all kernels for current device
for kernel_type, devices in KERNEL_REGISTRY._registry.items():
for devices in KERNEL_REGISTRY._registry.values():
kernel_cls = devices.get(device_type)
if kernel_cls is not None:
discovered_kernels.append(kernel_cls)
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
model = apply_kernel(model, NpuRMSNormKernel)
"""
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
return kernel.apply(model, **kwargs)
if not issubclass(kernel, MetaKernel):
raise ValueError(f"{kernel} must be a MetaKernel instance.")
raise ValueError(
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
)
if kernel.device != get_current_accelerator().type:
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
    """Apply every kernel returned by ``discover_kernels`` to the model.

    Args:
        model: Model to patch; also the argument passed to ``discover_kernels``.
        **kwargs: Extra keyword arguments forwarded to each ``apply_kernel`` call.

    Returns:
        The result of threading the model through ``apply_kernel`` once per
        discovered kernel, in discovery order.
    """
    patched = model
    kernels = discover_kernels(model)
    for kernel_cls in kernels:
        patched = apply_kernel(patched, kernel_cls, **kwargs)
    return patched

View File

@@ -14,9 +14,9 @@
import re
import types
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRMSNormKernel

View File

@@ -16,9 +16,9 @@ import sys
import torch
from .....accelerator.helper import is_torch_npu_available
from .....extras.types import HFModel
from ..constants import DeviceType, KernelType
from .....accelerator.helper import DeviceType, is_torch_npu_available
from .....utils.types import HFModel
from ..constants import KernelType
from ..registry import MetaRoPEKernel

View File

@@ -0,0 +1,42 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
class LoraConfigDict(TypedDict, total=False):
    """Typed configuration mapping for the "lora" PEFT plugin.

    Declared with ``total=False``, so every key is optional.
    """

    name: Literal["lora"]
    """Plugin name."""
    r: int
    """Lora rank."""
    lora_alpha: int
    """Lora alpha."""
    target_modules: list[str]
    """Target modules."""
class PeftPlugin(BasePlugin):
    """Plugin type under which PEFT model builders (e.g. "lora") are registered."""
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
    """Wrap ``model`` with LoRA adapters built from ``config``.

    Args:
        model: Base model to wrap.
        config: LoRA settings. The ``name`` key identifies the plugin and is
            dropped here; all remaining keys are forwarded to ``LoraConfig``.

    Returns:
        The model returned by ``get_peft_model``.
    """
    # ``name`` is the plugin selector, not a ``LoraConfig`` field; forwarding it
    # would make ``LoraConfig(**...)`` raise ``TypeError``. Build a filtered copy
    # so the caller's mapping is left untouched.
    lora_kwargs = {key: value for key, value in config.items() if key != "name"}
    peft_config = LoraConfig(**lora_kwargs)
    return get_peft_model(model, peft_config)