diff --git a/.gitignore b/.gitignore index f425bd59..40bc040a 100644 --- a/.gitignore +++ b/.gitignore @@ -165,6 +165,9 @@ cython_debug/ # uv uv.lock +# macOS +.DS_Store + # custom .gitignore hf_cache/ ms_cache/ diff --git a/data/v1_sft_demo.yaml b/data/v1_sft_demo.yaml index dcee9134..9687431d 100644 --- a/data/v1_sft_demo.yaml +++ b/data/v1_sft_demo.yaml @@ -1,8 +1,8 @@ identity: - file_name: identity.json + file_name: data/identity.json converter: alpaca alpaca_en_demo: file_name: alpaca_en_demo.json - dataset_dir: ~/data + dataset_dir: data converter: alpaca num_samples: 500 diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 709afe4a..075924a2 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -15,8 +15,8 @@ import uuid from collections.abc import AsyncGenerator, AsyncIterator from typing import TYPE_CHECKING, Any, Optional, Union -from packaging import version +from packaging import version from typing_extensions import override from ..data import get_template_and_fix_tokenizer diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 20ea1417..30ffa8e9 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -465,6 +465,7 @@ class BasePlugin(MMPluginMixin): self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) + @dataclass class ErnieVLPlugin(BasePlugin): @override diff --git a/src/llamafactory/v1/accelerator/__init__.py b/src/llamafactory/v1/accelerator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/accelerator/distributed.py similarity index 51% rename from src/llamafactory/v1/core/model_engine.py rename to src/llamafactory/v1/accelerator/distributed.py index 24d2d4b7..0094bf5a 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/accelerator/distributed.py @@ -12,16 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..config.model_args import ModelArguments -from ..extras.types import Model, Processor +from typing import Optional + +from torch.distributed.device_mesh import DeviceMesh -class ModelEngine: - def __init__(self, model_args: ModelArguments) -> None: - self.args = model_args +class DeviceMeshManager: + """Device mesh manager.""" - def get_model(self) -> Model: - pass + _instance: Optional["DeviceMeshManager"] = None + _initialized: bool = False - def get_processor(self) -> Processor: - pass + def __new__(cls) -> "DeviceMeshManager": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + if self._initialized: + return + + self.device_mesh: Optional[DeviceMesh] = None + self._initialized = True diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py new file mode 100644 index 00000000..a04202a2 --- /dev/null +++ b/src/llamafactory/v1/accelerator/helper.py @@ -0,0 +1,52 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import lru_cache + +import torch + + +def get_current_accelerator(check_available: bool = True): + """Get current accelerator. + + Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError + """ + if not hasattr(torch, "accelerator"): + raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.") + + accelerator = torch.accelerator.current_accelerator(check_available=check_available) + if accelerator is None: + return torch.device("cpu") + + return accelerator + + +@lru_cache +def is_torch_npu_available(): + return get_current_accelerator().type == "npu" + + +@lru_cache +def is_torch_cuda_available(): + return get_current_accelerator().type == "cuda" + + +@lru_cache +def is_torch_xpu_available(): + return get_current_accelerator().type == "xpu" + + +@lru_cache +def is_torch_mps_available(): + return get_current_accelerator().type == "mps" diff --git a/src/llamafactory/v1/accelerator/profiler.py b/src/llamafactory/v1/accelerator/profiler.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/config/data_args.py b/src/llamafactory/v1/config/data_args.py index 28829642..845bd1a8 100644 --- a/src/llamafactory/v1/config/data_args.py +++ b/src/llamafactory/v1/config/data_args.py @@ -23,10 +23,6 @@ class DataArguments: default=None, metadata={"help": "Path to the dataset."}, ) - dataset_dir: str = field( - default="data", - metadata={"help": "Path to the folder containing the datasets."}, - ) cutoff_len: int = field( default=2048, metadata={"help": "Cutoff length for the dataset."}, diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py index b9a98660..6bb8dd58 100644 --- a/src/llamafactory/v1/config/model_args.py +++ b/src/llamafactory/v1/config/model_args.py @@ -25,3 +25,11 @@ class ModelArguments: default=False, metadata={"help": "Trust remote code from Hugging Face."}, ) + use_fast_processor: bool = field( + default=True, + metadata={"help": "Use fast processor from Hugging Face."}, + ) + auto_model_class: str = field( + default="causallm", + metadata={"help": "Model class from Hugging Face."}, + ) diff --git a/src/llamafactory/v1/config/sample_args.py b/src/llamafactory/v1/config/sample_args.py index 666efb01..52f24ae5 100644 --- a/src/llamafactory/v1/config/sample_args.py +++ b/src/llamafactory/v1/config/sample_args.py @@ -14,10 +14,20 @@ from dataclasses import dataclass, field +from enum import Enum + + +class SampleBackend(Enum): + HF = "hf" + VLLM = "vllm" @dataclass class SampleArguments: + sample_backend: SampleBackend = field( + default=SampleBackend.HF, + metadata={"help": "Sampling backend, default to 'hf'."}, + ) max_new_tokens: int = field( default=128, metadata={"help": "Maximum number of new tokens to generate."}, diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index 041e5f7d..cacca81c 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -12,44 +12,51 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from typing import Any +"""The definition of trainer. + +Init Phase: + +1. Init dataloader. +2. Init model worker. +3. Init optimizer (deepspeed). +4. Shard model. +5. Init optimizer (fsdp). +6. Init scheduler. + +Train Phase: +1. Train Loop + +""" from ..config.training_args import TrainingArguments -from ..extras.types import Model, Processor, Tensor, TorchDataset - - -class DataCollator: - """Default Data collator.""" - - def __init__(self, processor: Processor) -> None: - self.processor = processor - - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]: - """Collate features into a batch.""" - for feature in features: - pass - - # sft: messages - # dpo: chosen_messages, rejected_messages +from ..extras.types import TorchDataset +from .model_worker import ModelWorker +from .trainer_utils.data_collator import DataCollator class BaseTrainer: def __init__( self, args: TrainingArguments, - model: Model, - processor: Processor, dataset: TorchDataset, data_collator: DataCollator, + model_worker: ModelWorker, ) -> None: self.args = args - self.model = model - self.processor = processor self.dataset = dataset self.data_collator = data_collator + self.model_worker = model_worker self.optimizer = None self.lr_scheduler = None + def init_device_mesh(self) -> None: + pass + + def init_model_and_optimizer(self) -> None: + self.model_config = self.model_worker.get_model_config() + # with self.dist_plugin.get_model_init_context(): + # self.model = self.model_worker.get_model(self.model_config) + def create_dataloader(self) -> None: pass diff --git a/src/llamafactory/v1/core/chat_sampler.py b/src/llamafactory/v1/core/chat_sampler.py index 3f213914..412185b5 100644 --- a/src/llamafactory/v1/core/chat_sampler.py +++ b/src/llamafactory/v1/core/chat_sampler.py @@ -12,9 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..config.sample_args import SampleArguments +from abc import ABC, abstractmethod + +from ..config.sample_args import SampleArguments, SampleBackend +from .model_worker import ModelWorker + + +class BaseEngine(ABC): + @abstractmethod + def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ... + + @abstractmethod + async def generate(self): + pass + + @abstractmethod + async def batch_infer(self): + pass + + +class HuggingFaceEngine(BaseEngine): + def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None: + self.model = model_worker.get_model() + self.processor = model_worker.get_processor() + self.args = sample_args class ChatSampler: - def __init__(self, sample_args: SampleArguments) -> None: - self.args = sample_args + def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None: + if sample_args.sample_backend == SampleBackend.HF: + self.engine = HuggingFaceEngine(model_worker, sample_args) + else: + raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}") diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index 8abe3bb7..523037b1 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""The definition of data engine. + +Init Data engine: +1. Parse dataset info from arguments. +2. Load datasets according to dataset info. +3. 
Build data index (and reweight samples if necessary). + +Get Data Sample: +1. Get sample from data index. +2. Convert sample to standard format. +3. Return sample. +""" + import os from collections.abc import AsyncIterable, Iterable from typing import Any, Union -from datasets import load_dataset from huggingface_hub import hf_hub_download from omegaconf import OmegaConf from torch.utils.data import Dataset @@ -45,15 +57,13 @@ class DataEngine(Dataset): def get_dataset_info(self) -> None: """Get dataset info from data arguments.""" - if self.args.dataset.endswith(".yaml") and os.path.isfile( - os.path.join(self.args.dataset_dir, self.args.dataset) - ): # local file - self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset)) + if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file + self.dataset_infos = OmegaConf.load(self.args.dataset) elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml repo_id, filename = os.path.split(self.args.dataset) filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset") self.dataset_infos = OmegaConf.load(filepath) - elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)): # local file(s) + elif os.path.exists(self.args.dataset): # local file(s) self.dataset_infos = {"default": {"file_name": self.args.dataset}} else: # hf hub dataset, e.g. llamafactory/v1-sft-demo self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}} @@ -65,11 +75,13 @@ class DataEngine(Dataset): streaming = value.get("streaming", False) self.streaming |= streaming if "hf_hub_url" in value: + from datasets import load_dataset + self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming) else: # data loader plugin from ..plugins.data_plugins.loader import DataLoaderPlugin - self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value) + self.datasets[key] = DataLoaderPlugin().auto_load_data(value) def build_data_index(self) -> None: """Build dataset index.""" @@ -145,11 +157,11 @@ class DataEngine(Dataset): dataset_name, sample_index = selected_index return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) - def __iter__(self) -> Iterable: + def __iter__(self) -> Iterable[Sample]: """Get dataset iterator. Returns: - Iterable: Dataset iterator. + Iterable[Sample]: Dataset iterator. """ if self.streaming: pass @@ -159,11 +171,11 @@ class DataEngine(Dataset): raise NotImplementedError() - async def __aiter__(self) -> AsyncIterable: + async def __aiter__(self) -> AsyncIterable[Sample]: """Get dataset async iterator. Returns: - AsyncIterable: Dataset async iterator. + AsyncIterable[Sample]: Dataset async iterator. """ if self.streaming: pass diff --git a/src/llamafactory/v1/core/model_worker.py b/src/llamafactory/v1/core/model_worker.py new file mode 100644 index 00000000..6d38189c --- /dev/null +++ b/src/llamafactory/v1/core/model_worker.py @@ -0,0 +1,98 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""The definition of model worker. + +Init Phase: +1. Init processor. +2. Init model config. +3. Init model. +4. Init adapter. + +""" + +from typing import Optional + +from transformers import AutoConfig, AutoProcessor + +from ..config.model_args import ModelArguments +from ..extras.types import DistModel, HFConfig, HFModel, Processor + + +class ModelWorker: + def __init__(self, model_args: ModelArguments) -> None: + self.args = model_args + """Model arguments.""" + self.processor: Optional[Processor] = None + """Tokenizer or multi-modal processor.""" + self.model_config: Optional[HFConfig] = None + """Model configuration.""" + self.unwrapped_model: Optional[HFModel] = None + """Unwrapped model.""" + self.model: Optional[DistModel] = None + """Distributed model.""" + self.init_processor() + self.init_model_config() + self.init_model() + self.init_adapter() + + def init_processor(self) -> None: + self.processor = AutoProcessor.from_pretrained( + self.args.model, + trust_remote_code=self.args.trust_remote_code, + use_fast=self.args.use_fast_processor, + ) + + def init_model_config(self) -> None: + self.model_config = AutoConfig.from_pretrained( + self.args.model, + trust_remote_code=self.args.trust_remote_code, + ) + + def init_model(self) -> None: + if self.args.auto_model_class == "causallm": + from transformers import AutoModelForCausalLM, AutoModelForImageTextToText + + if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys(): + AutoClass = AutoModelForImageTextToText + else: + AutoClass = AutoModelForCausalLM + elif self.args.auto_model_class == "classification": + from transformers import AutoModelForTokenClassification + + AutoClass = AutoModelForTokenClassification + else: + from transformers import AutoModel + + AutoClass = AutoModel + + self.unwrapped_model = AutoClass.from_pretrained( + self.args.model, + config=self.model_config, + dtype="auto", + device_map="cpu", + trust_remote_code=self.args.trust_remote_code, + ) + + def init_adapter(self) -> None: + pass + + def get_processor(self) -> Processor: + return self.processor + + def get_model_config(self) -> HFConfig: + return self.model_config + + def get_model(self) -> HFModel: + return self.unwrapped_model diff --git a/src/llamafactory/v1/core/trainer_utils/__init__.py b/src/llamafactory/v1/core/trainer_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/core/trainer_utils/callback.py b/src/llamafactory/v1/core/trainer_utils/callback.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/core/trainer_utils/data_collator.py b/src/llamafactory/v1/core/trainer_utils/data_collator.py new file mode 100644 index 00000000..88f49672 --- /dev/null +++ b/src/llamafactory/v1/core/trainer_utils/data_collator.py @@ -0,0 +1,47 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Any + +from ...extras.types import Processor, Tensor, TorchDataset + + +class DataCollator: + """Default Data collator.""" + + def __init__(self, processor: Processor) -> None: + self.processor = processor + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]: + """Collate features into a batch.""" + for feature in features: + pass + + # sft: messages + # dpo: chosen_messages, rejected_messages + + +class DataLoader: + """Default DataLoader.""" + + def __init__(self, dataset: TorchDataset) -> None: + self.dataset = dataset + # 1. Init stateful dataloader (tokenize) + # 2. Add to buffer (2 * max seq len per device) + # 3. Yield batch indexes (micro batch * grad acc) + # a ) non pack + non dynamic + # b ) non pack + dynamic + # c ) pack + non dynamic + # d ) pack + dynamic diff --git a/src/llamafactory/v1/core/trainer_utils/lr_scheduler.py b/src/llamafactory/v1/core/trainer_utils/lr_scheduler.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/extras/types.py b/src/llamafactory/v1/extras/types.py index 9539931a..16658c3d 100644 --- a/src/llamafactory/v1/extras/types.py +++ b/src/llamafactory/v1/extras/types.py @@ -28,18 +28,24 @@ if TYPE_CHECKING: HFDataset = Union[datasets.Dataset, datasets.IterableDataset] DataCollator = transformers.DataCollator DataLoader = torch.utils.data.DataLoader + HFConfig = transformers.PretrainedConfig HFModel = transformers.PreTrainedModel DistModel = torch.nn.parallel.DistributedDataParallel Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin] + Optimizer = torch.optim.Optimizer + Scheduler = torch.optim.lr_scheduler.LRScheduler else: Tensor = None TorchDataset = None HFDataset = None DataCollator = None DataLoader = None + HFConfig = None HFModel = None DistModel = None Processor = None + Optimizer = None + Scheduler = None class DatasetInfo(TypedDict, total=False): @@ -86,10 +92,3 @@ class DPOSample(TypedDict): Sample = Union[SFTSample, DPOSample] - - -class Model(TypedDict): - hf_model: HFModel - """HF model.""" - dist_model: DistModel - """Distributed model.""" diff --git a/src/llamafactory/v1/plugins/data_plugins/loader.py b/src/llamafactory/v1/plugins/data_plugins/loader.py index 543db438..92ec6add 100644 --- a/src/llamafactory/v1/plugins/data_plugins/loader.py +++ b/src/llamafactory/v1/plugins/data_plugins/loader.py @@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union from datasets import load_dataset -from ...config.data_args import DataArguments from ...extras.types import DatasetInfo, HFDataset @@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset class DataLoaderPlugin: """Plugin for loading dataset.""" - args: DataArguments - """Data arguments.""" - def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]: """Get dataset builder name. 
@@ -42,7 +38,7 @@ class DataLoaderPlugin: return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text") def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset: - dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir) + dataset_dir = dataset_info.get("dataset_dir", ".") split = dataset_info.get("split", "train") streaming = dataset_info.get("streaming", False) if "file_name" in dataset_info: diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py index 21055760..7f1e8c0b 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py @@ -18,9 +18,9 @@ import torch import torch.nn.functional as F import torch_npu +from .....accelerator.helper import is_torch_npu_available from .....extras.packages import is_transformers_version_greater_than from .....extras.types import HFModel -from ....trainer_plugins.distributed.accelerate import is_torch_npu_available from ..constants import DeviceType, KernelType from ..registry import MetaMoEKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py index e396a496..011b53b3 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py @@ -17,8 +17,8 @@ import types import torch +from .....accelerator.helper import is_torch_npu_available from .....extras.types import HFModel -from ....trainer_plugins.distributed.accelerate import is_torch_npu_available from ..constants import DeviceType, KernelType from ..registry import MetaSwiGluKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index b677b8a6..06595ba1 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -15,8 +15,8 @@ from abc import ABC, ABCMeta, abstractmethod from typing import Any, Callable, Optional +from ....accelerator.helper import get_current_accelerator from ....extras.types import HFModel -from ...trainer_plugins.distributed.accelerate import get_available_accelerator from .constants import DeviceType, KernelType @@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]: discovered_kernels: list[type[MetaKernel]] = [] # Detect current device type - accelerator = get_available_accelerator() + accelerator = get_current_accelerator() try: device_type = DeviceType(accelerator.type) except ValueError: @@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B") model = apply_kernel(model, NpuRMSNormKernel) """ - if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type: + if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type: return kernel.apply(model, **kwargs) raise ValueError( - f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead." + f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. 
got {kernel.device} and {get_current_accelerator().type} instead." ) diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py index ba51f332..01550719 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py @@ -14,8 +14,8 @@ import re import types +from .....accelerator.helper import is_torch_npu_available from .....extras.types import HFModel -from ....trainer_plugins.distributed.accelerate import is_torch_npu_available from ..constants import DeviceType, KernelType from ..registry import MetaRMSNormKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py index 5e877f0a..a0608e7d 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py @@ -16,8 +16,8 @@ import sys import torch +from .....accelerator.helper import is_torch_npu_available from .....extras.types import HFModel -from ....trainer_plugins.distributed.accelerate import is_torch_npu_available from ..constants import DeviceType, KernelType from ..registry import MetaRoPEKernel diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py index e7a9bf30..ec0d6255 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py @@ -11,37 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import lru_cache - -import torch - - -def get_available_accelerator(): - """Get available accelerator in current environment. - - Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError - """ - accelerator = torch.accelerator.current_accelerator() - if accelerator is None: - return torch.device("cpu") - return accelerator - - -@lru_cache -def is_torch_npu_available(): - return get_available_accelerator().type == "npu" - - -@lru_cache -def is_torch_cuda_available(): - return get_available_accelerator().type == "cuda" - - -@lru_cache -def is_torch_xpu_available(): - return get_available_accelerator().type == "xpu" - - -@lru_cache -def is_torch_mps_available(): - return get_available_accelerator().type == "mps"
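For reference, a minimal usage sketch of the relocated accelerator helpers (moved from plugins/trainer_plugins/distributed/accelerate.py to v1/accelerator/helper.py by this patch). It assumes torch >= 2.7.0, as the helper's docstring requires, and the absolute import path llamafactory.v1.accelerator.helper is inferred from the file's location under src/; this is an illustration, not part of the patch.

    # Sketch: query the current accelerator via the new v1 helper module.
    from llamafactory.v1.accelerator.helper import (
        get_current_accelerator,
        is_torch_cuda_available,
        is_torch_npu_available,
    )

    # Falls back to torch.device("cpu") when no accelerator is detected.
    device = get_current_accelerator()
    print(f"current accelerator: {device.type}")

    if is_torch_npu_available():
        # NPU-specific kernels (npu_fused_moe, npu_swiglu, npu_rms_norm, npu_rope)
        # gate on this check after the import-path change above.
        print("NPU backend available")
    elif is_torch_cuda_available():
        print("CUDA backend available")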