mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-07 12:15:59 +08:00
[v1] add accelerator (#9607)
This commit is contained in:
@@ -21,10 +21,3 @@ class KernelType(str, Enum):
|
||||
FLASH_ATTENTION = "flash_attention"
|
||||
ROPE = "rope"
|
||||
MOE = "moe"
|
||||
|
||||
|
||||
class DeviceType(str, Enum):
    """String-valued enum of the accelerator device types the kernel registry dispatches on.

    Inherits ``str`` so members compare equal to their literal values
    (e.g. ``DeviceType.CUDA == "cuda"``), which lets them be matched
    directly against ``torch.device(...).type`` strings.
    """

    CPU = "cpu"  # host CPU fallback
    CUDA = "cuda"  # NVIDIA GPUs
    NPU = "npu"  # Ascend NPUs (file imports torch_npu elsewhere)
    XPU = "xpu"  # NOTE(review): presumably Intel XPU backend — confirm
|
||||
|
||||
@@ -18,10 +18,10 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.packages import is_transformers_version_greater_than
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.packages import is_transformers_version_greater_than
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaMoEKernel
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import types
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaSwiGluKernel
|
||||
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from ....extras.types import HFModel
|
||||
from .constants import DeviceType, KernelType
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.types import HFModel
|
||||
from .constants import KernelType
|
||||
|
||||
|
||||
class KernelRegistry:
|
||||
@@ -27,11 +27,13 @@ class KernelRegistry:
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
|
||||
self._initialized = True
|
||||
|
||||
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
return discovered_kernels
|
||||
|
||||
# Iterate through registry and collect all kernels for current device
|
||||
for kernel_type, devices in KERNEL_REGISTRY._registry.items():
|
||||
for devices in KERNEL_REGISTRY._registry.values():
|
||||
kernel_cls = devices.get(device_type)
|
||||
if kernel_cls is not None:
|
||||
discovered_kernels.append(kernel_cls)
|
||||
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
return discovered_kernels
|
||||
|
||||
|
||||
def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
|
||||
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
|
||||
"""Call the MetaKernel's `apply` to perform the replacement.
|
||||
|
||||
Corresponding replacement logic is maintained inside each kernel; the only
|
||||
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
|
||||
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
|
||||
model = apply_kernel(model, NpuRMSNormKernel)
|
||||
"""
|
||||
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
|
||||
return kernel.apply(model, **kwargs)
|
||||
if not issubclass(kernel, MetaKernel):
|
||||
raise ValueError(f"{kernel} must be a MetaKernel instance.")
|
||||
|
||||
raise ValueError(
|
||||
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
|
||||
)
|
||||
if kernel.device != get_current_accelerator().type:
|
||||
raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
|
||||
|
||||
return kernel.apply(model, **kwargs)
|
||||
|
||||
|
||||
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
    """Apply every kernel discovered for the current device to *model*."""
    discovered = discover_kernels(model)
    for candidate in discovered:
        model = apply_kernel(model, candidate, **kwargs)
    return model
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
import re
|
||||
import types
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaRMSNormKernel
|
||||
|
||||
|
||||
|
||||
@@ -16,9 +16,9 @@ import sys
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ..constants import DeviceType, KernelType
|
||||
from .....accelerator.helper import DeviceType, is_torch_npu_available
|
||||
from .....utils.types import HFModel
|
||||
from ..constants import KernelType
|
||||
from ..registry import MetaRoPEKernel
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import HFModel
|
||||
|
||||
|
||||
class LoraConfigDict(TypedDict, total=False):
    """Keyword options accepted by the ``lora`` PEFT plugin (all optional)."""

    # Plugin discriminator; always "lora" for this plugin.
    name: Literal["lora"]
    # LoRA decomposition rank.
    r: int
    # LoRA scaling factor (alpha).
    lora_alpha: int
    # Names of the modules to inject adapters into.
    target_modules: list[str]
|
||||
|
||||
|
||||
class PeftPlugin(BasePlugin):
    """Plugin hook for PEFT adapter factories.

    Instantiated with a plugin name and used as a registration decorator,
    e.g. ``@PeftPlugin("lora").register`` attaches a factory function.
    NOTE(review): the registration/dispatch semantics live in ``BasePlugin``
    (not visible here) — confirm against ``...utils.plugin``.
    """

    pass
|
||||
|
||||
|
||||
@PeftPlugin("lora").register
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
|
||||
peft_config = LoraConfig(**config)
|
||||
model = get_peft_model(model, peft_config)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user