[v1] add models & accelerator (#9579)

commit 5744f1ea94 (parent 739954910a)
Author: Yaowei Zheng
Date: 2025-12-08 02:30:25 +08:00 (committed by GitHub)

27 changed files with 335 additions and 105 deletions

----------------------------------------
.gitignore (vendored)

@@ -165,6 +165,9 @@ cython_debug/
 # uv
 uv.lock
 
+# macOS
+.DS_Store
+
 # custom .gitignore
 hf_cache/
 ms_cache/

----------------------------------------
(dataset_info YAML)

@@ -1,8 +1,8 @@
 identity:
-  file_name: identity.json
+  file_name: data/identity.json
   converter: alpaca
 alpaca_en_demo:
   file_name: alpaca_en_demo.json
-  dataset_dir: ~/data
+  dataset_dir: data
   converter: alpaca
   num_samples: 500

----------------------------------------
(chat model: import block)

@@ -15,8 +15,8 @@
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union
+
 from packaging import version
-
 from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer

----------------------------------------
(mm_plugin: BasePlugin / ErnieVLPlugin)

@@ -465,6 +465,7 @@ class BasePlugin(MMPluginMixin):
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
+
 @dataclass
 class ErnieVLPlugin(BasePlugin):
     @override

----------------------------------------
(DeviceMeshManager, replacing ModelEngine)

@@ -12,16 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..config.model_args import ModelArguments
-from ..extras.types import Model, Processor
+from typing import Optional
+
+from torch.distributed.device_mesh import DeviceMesh
 
 
-class ModelEngine:
-    def __init__(self, model_args: ModelArguments) -> None:
-        self.args = model_args
+class DeviceMeshManager:
+    """Device mesh manager."""
 
-    def get_model(self) -> Model:
-        pass
+    _instance: Optional["DeviceMeshManager"] = None
+    _initialized: bool = False
 
-    def get_processor(self) -> Processor:
-        pass
+    def __new__(cls) -> "DeviceMeshManager":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+
+        return cls._instance
+
+    def __init__(self) -> None:
+        if self._initialized:
+            return
+
+        self.device_mesh: Optional[DeviceMesh] = None
+        self._initialized = True
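
Note: __new__ always hands back the one shared instance and __init__ short-circuits after the first call, so every call site sees the same mesh. A minimal sketch of the resulting behavior:

    manager_a = DeviceMeshManager()
    manager_b = DeviceMeshManager()
    assert manager_a is manager_b          # singleton: both names bind the same object
    assert manager_a.device_mesh is None   # populated once later, e.g. during trainer setup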

----------------------------------------
(new file: accelerator helper)

@@ -0,0 +1,52 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import lru_cache
+
+import torch
+
+
+def get_current_accelerator(check_available: bool = True):
+    """Get the current accelerator.
+
+    Note: this API requires torch>=2.7.0; on 2.6 or lower it raises an AttributeError or RuntimeError.
+    """
+    if not hasattr(torch, "accelerator"):
+        raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
+
+    accelerator = torch.accelerator.current_accelerator(check_available=check_available)
+    if accelerator is None:
+        return torch.device("cpu")
+
+    return accelerator
+
+
+@lru_cache
+def is_torch_npu_available():
+    return get_current_accelerator().type == "npu"
+
+
+@lru_cache
+def is_torch_cuda_available():
+    return get_current_accelerator().type == "cuda"
+
+
+@lru_cache
+def is_torch_xpu_available():
+    return get_current_accelerator().type == "xpu"
+
+
+@lru_cache
+def is_torch_mps_available():
+    return get_current_accelerator().type == "mps"
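
Note: the is_torch_*_available checks are lru_cache'd, so each answer is computed once per process and will not track devices that appear later. A hedged usage sketch, assuming torch>=2.7:

    import torch

    device = get_current_accelerator()   # e.g. device(type="cuda", index=0); falls back to CPU
    x = torch.ones(2, 2, device=device)  # place tensors without per-backend branching

    if is_torch_npu_available():
        pass  # NPU-only path, e.g. torch_npu kernels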

----------------------------------------
(DataArguments)

@@ -23,10 +23,6 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to the dataset."},
     )
-    dataset_dir: str = field(
-        default="data",
-        metadata={"help": "Path to the folder containing the datasets."},
-    )
     cutoff_len: int = field(
         default=2048,
         metadata={"help": "Cutoff length for the dataset."},

----------------------------------------
(ModelArguments)

@@ -25,3 +25,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Trust remote code from Hugging Face."},
     )
+    use_fast_processor: bool = field(
+        default=True,
+        metadata={"help": "Use fast processor from Hugging Face."},
+    )
+    auto_model_class: str = field(
+        default="causallm",
+        metadata={"help": "Model class from Hugging Face."},
+    )

----------------------------------------
(SampleArguments)

@@ -14,10 +14,20 @@
 
 from dataclasses import dataclass, field
+from enum import Enum
+
+
+class SampleBackend(Enum):
+    HF = "hf"
+    VLLM = "vllm"
 
 
 @dataclass
 class SampleArguments:
+    sample_backend: SampleBackend = field(
+        default=SampleBackend.HF,
+        metadata={"help": "Sampling backend, default to 'hf'."},
+    )
     max_new_tokens: int = field(
         default=128,
         metadata={"help": "Maximum number of new tokens to generate."},
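
Note: SampleBackend is a value-backed Enum, so a backend string coming from a YAML/CLI config maps directly onto a member:

    assert SampleBackend("vllm") is SampleBackend.VLLM  # config string -> enum member
    assert SampleBackend.HF.value == "hf"               # enum member -> serialized value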

----------------------------------------
(BaseTrainer)

@@ -12,44 +12,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+"""The definition of trainer.
+
+Init Phase:
+1. Init dataloader.
+2. Init model worker.
+3. Init optimizer (deepspeed).
+4. Shard model.
+5. Init optimizer (fsdp).
+6. Init scheduler.
+
+Train Phase:
+1. Train loop.
+"""
 
 from ..config.training_args import TrainingArguments
-from ..extras.types import Model, Processor, Tensor, TorchDataset
-
-
-class DataCollator:
-    """Default Data collator."""
-
-    def __init__(self, processor: Processor) -> None:
-        self.processor = processor
-
-    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
-        """Collate features into a batch."""
-        for feature in features:
-            pass
-
-        # sft: messages
-        # dpo: chosen_messages, rejected_messages
+from ..extras.types import TorchDataset
+from .model_worker import ModelWorker
+from .trainer_utils.data_collator import DataCollator
 
 
 class BaseTrainer:
     def __init__(
         self,
         args: TrainingArguments,
-        model: Model,
-        processor: Processor,
         dataset: TorchDataset,
         data_collator: DataCollator,
+        model_worker: ModelWorker,
     ) -> None:
         self.args = args
-        self.model = model
-        self.processor = processor
         self.dataset = dataset
         self.data_collator = data_collator
+        self.model_worker = model_worker
         self.optimizer = None
         self.lr_scheduler = None
 
+    def init_device_mesh(self) -> None:
+        pass
+
+    def init_model_and_optimizer(self) -> None:
+        self.model_config = self.model_worker.get_model_config()
+        # with self.dist_plugin.get_model_init_context():
+        #     self.model = self.model_worker.get_model(self.model_config)
+
     def create_dataloader(self) -> None:
         pass
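
Note: construction now follows the docstring's init phases, with the model owned by ModelWorker instead of being passed in directly. A hedged wiring sketch (the *_args objects and dataset are placeholders coming from user config):

    model_worker = ModelWorker(model_args)                     # runs its init steps up front
    data_collator = DataCollator(model_worker.get_processor())
    trainer = BaseTrainer(
        args=training_args,
        dataset=dataset,                                       # any TorchDataset, e.g. a DataEngine
        data_collator=data_collator,
        model_worker=model_worker,
    )
    trainer.init_device_mesh()
    trainer.init_model_and_optimizer()  # so far only fetches the model config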

----------------------------------------
(ChatSampler / engines)

@@ -12,9 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..config.sample_args import SampleArguments
+from abc import ABC, abstractmethod
+
+from ..config.sample_args import SampleArguments, SampleBackend
+from .model_worker import ModelWorker
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
+
+    @abstractmethod
+    async def generate(self):
+        pass
+
+    @abstractmethod
+    async def batch_infer(self):
+        pass
+
+
+class HuggingFaceEngine(BaseEngine):
+    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+        self.model = model_worker.get_model()
+        self.processor = model_worker.get_processor()
+        self.args = sample_args
 
 
 class ChatSampler:
-    def __init__(self, sample_args: SampleArguments) -> None:
-        self.args = sample_args
+    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+        if sample_args.sample_backend == SampleBackend.HF:
+            self.engine = HuggingFaceEngine(model_worker, sample_args)
+        else:
+            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
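
Note: BaseEngine leaves generate/batch_infer abstract, so a new backend is one subclass plus one dispatch branch. A sketch of the vLLM slot (VLLMEngine is illustrative, not part of this commit):

    class VLLMEngine(BaseEngine):  # hypothetical follow-up
        def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
            self.args = sample_args  # a vLLM engine would be built here instead of reusing the HF model

        async def generate(self):
            raise NotImplementedError

        async def batch_infer(self):
            raise NotImplementedError

    # ChatSampler.__init__ would then grow one branch:
    # elif sample_args.sample_backend == SampleBackend.VLLM:
    #     self.engine = VLLMEngine(model_worker, sample_args)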

----------------------------------------
(DataEngine)

@@ -12,11 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""The definition of data engine.
+
+Init Data Engine:
+1. Parse dataset info from arguments.
+2. Load datasets according to dataset info.
+3. Build data index (and reweight samples if necessary).
+
+Get Data Sample:
+1. Get sample from data index.
+2. Convert sample to standard format.
+3. Return sample.
+"""
+
 import os
 from collections.abc import AsyncIterable, Iterable
 from typing import Any, Union
 
-from datasets import load_dataset
 from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
 from torch.utils.data import Dataset

@@ -45,15 +57,13 @@ class DataEngine(Dataset):
     def get_dataset_info(self) -> None:
         """Get dataset info from data arguments."""
-        if self.args.dataset.endswith(".yaml") and os.path.isfile(
-            os.path.join(self.args.dataset_dir, self.args.dataset)
-        ):  # local file
-            self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
+        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
+            self.dataset_infos = OmegaConf.load(self.args.dataset)
         elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
             repo_id, filename = os.path.split(self.args.dataset)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_infos = OmegaConf.load(filepath)
-        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
+        elif os.path.exists(self.args.dataset):  # local file(s)
            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
             self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
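
Note: with dataset_dir gone, the dataset string is now resolved on its own. The four branches, with illustrative values:

    # dataset="data/dataset_info.yaml"                      -> local YAML, loaded via OmegaConf
    # dataset="llamafactory/v1-sft-demo/dataset_info.yaml"  -> YAML on the HF Hub (hf_hub_download)
    # dataset="data/alpaca_en_demo.json"                    -> local file(s): {"default": {"file_name": ...}}
    # dataset="llamafactory/v1-sft-demo"                    -> Hub dataset: {"default": {"hf_hub_url": ...}}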
@@ -65,11 +75,13 @@ class DataEngine(Dataset):
             streaming = value.get("streaming", False)
             self.streaming |= streaming
             if "hf_hub_url" in value:
+                from datasets import load_dataset
+
                 self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
             else:  # data loader plugin
                 from ..plugins.data_plugins.loader import DataLoaderPlugin
 
-                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
+                self.datasets[key] = DataLoaderPlugin().auto_load_data(value)

@@ -145,11 +157,11 @@ class DataEngine(Dataset):
         dataset_name, sample_index = selected_index
         return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
 
-    def __iter__(self) -> Iterable:
+    def __iter__(self) -> Iterable[Sample]:
         """Get dataset iterator.
 
         Returns:
-            Iterable: Dataset iterator.
+            Iterable[Sample]: Dataset iterator.
         """
         if self.streaming:
             pass

@@ -159,11 +171,11 @@ class DataEngine(Dataset):
 
         raise NotImplementedError()
 
-    async def __aiter__(self) -> AsyncIterable:
+    async def __aiter__(self) -> AsyncIterable[Sample]:
         """Get dataset async iterator.
 
         Returns:
-            AsyncIterable: Dataset async iterator.
+            AsyncIterable[Sample]: Dataset async iterator.
         """
         if self.streaming:
             pass

----------------------------------------
(new file: ModelWorker)

@@ -0,0 +1,98 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of model worker.
+
+Init Phase:
+1. Init processor.
+2. Init model config.
+3. Init model.
+4. Init adapter.
+"""
+
+from typing import Optional
+
+from transformers import AutoConfig, AutoProcessor
+
+from ..config.model_args import ModelArguments
+from ..extras.types import DistModel, HFConfig, HFModel, Processor
+
+
+class ModelWorker:
+    def __init__(self, model_args: ModelArguments) -> None:
+        self.args = model_args
+        """Model arguments."""
+        self.processor: Optional[Processor] = None
+        """Tokenizer or multi-modal processor."""
+        self.model_config: Optional[HFConfig] = None
+        """Model configuration."""
+        self.unwrapped_model: Optional[HFModel] = None
+        """Unwrapped model."""
+        self.model: Optional[DistModel] = None
+        """Distributed model."""
+        self.init_processor()
+        self.init_model_config()
+        self.init_model()
+        self.init_adapter()
+
+    def init_processor(self) -> None:
+        self.processor = AutoProcessor.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+            use_fast=self.args.use_fast_processor,
+        )
+
+    def init_model_config(self) -> None:
+        self.model_config = AutoConfig.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def init_model(self) -> None:
+        if self.args.auto_model_class == "causallm":
+            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
+
+            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
+                AutoClass = AutoModelForImageTextToText
+            else:
+                AutoClass = AutoModelForCausalLM
+        elif self.args.auto_model_class == "classification":
+            from transformers import AutoModelForTokenClassification
+
+            AutoClass = AutoModelForTokenClassification
+        else:
+            from transformers import AutoModel
+
+            AutoClass = AutoModel
+
+        self.unwrapped_model = AutoClass.from_pretrained(
+            self.args.model,
+            config=self.model_config,
+            dtype="auto",
+            device_map="cpu",
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def init_adapter(self) -> None:
+        pass
+
+    def get_processor(self) -> Processor:
+        return self.processor
+
+    def get_model_config(self) -> HFConfig:
+        return self.model_config
+
+    def get_model(self) -> HFModel:
+        return self.unwrapped_model
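
Note: a hedged usage sketch of the worker (the repo id is a placeholder; field names follow the ModelArguments diff above):

    args = ModelArguments(model="Qwen/Qwen2.5-0.5B-Instruct", auto_model_class="causallm")
    worker = ModelWorker(args)          # init order: processor -> config -> model -> adapter
    processor = worker.get_processor()  # tokenizer or multi-modal processor
    model = worker.get_model()          # unwrapped HF model, loaded on CPU with dtype="auto"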

----------------------------------------
(new file: DataCollator / DataLoader)

@@ -0,0 +1,47 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from ...extras.types import Processor, Tensor, TorchDataset
+
+
+class DataCollator:
+    """Default data collator."""
+
+    def __init__(self, processor: Processor) -> None:
+        self.processor = processor
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
+        """Collate features into a batch."""
+        for feature in features:
+            pass
+
+        # sft: messages
+        # dpo: chosen_messages, rejected_messages
+
+
+class DataLoader:
+    """Default DataLoader."""
+
+    def __init__(self, dataset: TorchDataset) -> None:
+        self.dataset = dataset
+
+    # 1. Init stateful dataloader (tokenize)
+    # 2. Add to buffer (2 * max seq len per device)
+    # 3. Yield batch indexes (micro batch * grad acc)
+    #    a) non pack + non dynamic
+    #    b) non pack + dynamic
+    #    c) pack + non dynamic
+    #    d) pack + dynamic
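
Note: __call__ is still a stub. A minimal sketch of the padding it will need for the sft case, assuming already-tokenized input_ids features (the helper name is hypothetical):

    import torch

    def pad_batch(features: list[dict], pad_token_id: int = 0) -> dict:
        # Right-pad variable-length input_ids and build the matching attention mask.
        max_len = max(len(f["input_ids"]) for f in features)
        input_ids, attention_mask = [], []
        for f in features:
            ids = list(f["input_ids"])
            n_pad = max_len - len(ids)
            input_ids.append(ids + [pad_token_id] * n_pad)
            attention_mask.append([1] * len(ids) + [0] * n_pad)

        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
        }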

----------------------------------------
(extras.types)

@@ -28,18 +28,24 @@ if TYPE_CHECKING:
     HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
     DataCollator = transformers.DataCollator
     DataLoader = torch.utils.data.DataLoader
+    HFConfig = transformers.PretrainedConfig
     HFModel = transformers.PreTrainedModel
     DistModel = torch.nn.parallel.DistributedDataParallel
     Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
+    Optimizer = torch.optim.Optimizer
+    Scheduler = torch.optim.lr_scheduler.LRScheduler
 else:
     Tensor = None
     TorchDataset = None
     HFDataset = None
     DataCollator = None
     DataLoader = None
+    HFConfig = None
     HFModel = None
     DistModel = None
     Processor = None
+    Optimizer = None
+    Scheduler = None
 
 
 class DatasetInfo(TypedDict, total=False):

@@ -86,10 +92,3 @@ class DPOSample(TypedDict):
 
 Sample = Union[SFTSample, DPOSample]
-
-
-class Model(TypedDict):
-    hf_model: HFModel
-    """HF model."""
-
-    dist_model: DistModel
-    """Distributed model."""

----------------------------------------
(DataLoaderPlugin)

@@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union
 
 from datasets import load_dataset
 
-from ...config.data_args import DataArguments
 from ...extras.types import DatasetInfo, HFDataset

@@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset
 class DataLoaderPlugin:
     """Plugin for loading dataset."""
 
-    args: DataArguments
-    """Data arguments."""
-
     def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
         """Get dataset builder name.

@@ -42,7 +38,7 @@ class DataLoaderPlugin:
         return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
 
     def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
+        dataset_dir = dataset_info.get("dataset_dir", ".")
         split = dataset_info.get("split", "train")
         streaming = dataset_info.get("streaming", False)
         if "file_name" in dataset_info:

----------------------------------------
(NPU MoE kernel: imports)

@@ -18,9 +18,9 @@ import torch
 import torch.nn.functional as F
 import torch_npu
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.packages import is_transformers_version_greater_than
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaMoEKernel

----------------------------------------
(SwiGLU kernel: imports)

@@ -17,8 +17,8 @@ import types
 
 import torch
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaSwiGluKernel

----------------------------------------
(kernel registry)

@@ -15,8 +15,8 @@
 from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional
 
+from ....accelerator.helper import get_current_accelerator
 from ....extras.types import HFModel
-from ...trainer_plugins.distributed.accelerate import get_available_accelerator
 from .constants import DeviceType, KernelType

@@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     discovered_kernels: list[type[MetaKernel]] = []
 
     # Detect current device type
-    accelerator = get_available_accelerator()
+    accelerator = get_current_accelerator()
     try:
         device_type = DeviceType(accelerator.type)
     except ValueError:

@@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
         model = apply_kernel(model, NpuRMSNormKernel)
     """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
+    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
         return kernel.apply(model, **kwargs)
 
     raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+        f"{kernel} must be a MetaKernel subclass, or the kernel doesn't match the device type; got {kernel.device} and {get_current_accelerator().type} instead."
     )
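
Note: discover_kernels keys dispatch off accelerator.type and tolerates unknown devices. The same pattern in isolation (this DeviceType is a simplified stand-in for the one in ..constants):

    from enum import Enum

    class DeviceType(Enum):
        CUDA = "cuda"
        NPU = "npu"

    accelerator_type = "cuda"  # e.g. get_current_accelerator().type
    try:
        device_type = DeviceType(accelerator_type)
    except ValueError:
        device_type = None  # unknown accelerator: skip device-specific kernels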

----------------------------------------
(RMSNorm kernel: imports)

@@ -14,8 +14,8 @@
 import re
 import types
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRMSNormKernel

----------------------------------------
(RoPE kernel: imports)

@@ -16,8 +16,8 @@ import sys
 
 import torch
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRoPEKernel

----------------------------------------
(removed: distributed accelerate helpers, superseded by the new accelerator helper)

@@ -11,37 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from functools import lru_cache
-
-import torch
-
-
-def get_available_accelerator():
-    """Get available accelerator in current environment.
-
-    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
-    """
-    accelerator = torch.accelerator.current_accelerator()
-    if accelerator is None:
-        return torch.device("cpu")
-
-    return accelerator
-
-
-@lru_cache
-def is_torch_npu_available():
-    return get_available_accelerator().type == "npu"
-
-
-@lru_cache
-def is_torch_cuda_available():
-    return get_available_accelerator().type == "cuda"
-
-
-@lru_cache
-def is_torch_xpu_available():
-    return get_available_accelerator().type == "xpu"
-
-
-@lru_cache
-def is_torch_mps_available():
-    return get_available_accelerator().type == "mps"