mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-07 12:15:59 +08:00
[v1] add models & accelerator (#9579)
This commit is contained in:
@@ -18,9 +18,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.packages import is_transformers_version_greater_than
|
||||
from .....extras.types import HFModel
|
||||
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
|
||||
from ..constants import DeviceType, KernelType
|
||||
from ..registry import MetaMoEKernel
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ import types
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
|
||||
from ..constants import DeviceType, KernelType
|
||||
from ..registry import MetaSwiGluKernel
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from ....extras.types import HFModel
|
||||
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
|
||||
from .constants import DeviceType, KernelType
|
||||
|
||||
|
||||
@@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
discovered_kernels: list[type[MetaKernel]] = []
|
||||
|
||||
# Detect current device type
|
||||
accelerator = get_available_accelerator()
|
||||
accelerator = get_current_accelerator()
|
||||
try:
|
||||
device_type = DeviceType(accelerator.type)
|
||||
except ValueError:
|
||||
@@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
|
||||
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
|
||||
model = apply_kernel(model, NpuRMSNormKernel)
|
||||
"""
|
||||
if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
|
||||
if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
|
||||
return kernel.apply(model, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
|
||||
f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
import re
|
||||
import types
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
|
||||
from ..constants import DeviceType, KernelType
|
||||
from ..registry import MetaRMSNormKernel
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ import sys
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
|
||||
from ..constants import DeviceType, KernelType
|
||||
from ..registry import MetaRoPEKernel
|
||||
|
||||
|
||||
Reference in New Issue
Block a user