Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-14 19:06:26 +08:00)
[v1] add models & accelerator (#9579)
.gitignore (vendored, +3)

@@ -165,6 +165,9 @@ cython_debug/
 # uv
 uv.lock

 # macOS
 .DS_Store

+# custom .gitignore
+hf_cache/
+ms_cache/
@@ -1,8 +1,8 @@
 identity:
-  file_name: identity.json
+  file_name: data/identity.json
   converter: alpaca
 alpaca_en_demo:
   file_name: alpaca_en_demo.json
-  dataset_dir: ~/data
+  dataset_dir: data
   converter: alpaca
   num_samples: 500
@@ -15,8 +15,8 @@
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union
-from packaging import version

+from packaging import version
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
@@ -465,6 +465,7 @@ class BasePlugin(MMPluginMixin):
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class ErnieVLPlugin(BasePlugin):
     @override
src/llamafactory/v1/accelerator/__init__.py (new, empty file)
@@ -12,16 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ..config.model_args import ModelArguments
-from ..extras.types import Model, Processor
+from typing import Optional
+
+from torch.distributed.device_mesh import DeviceMesh


-class ModelEngine:
-    def __init__(self, model_args: ModelArguments) -> None:
-        self.args = model_args
+class DeviceMeshManager:
+    """Device mesh manager."""

-    def get_model(self) -> Model:
-        pass
+    _instance: Optional["DeviceMeshManager"] = None
+    _initialized: bool = False

-    def get_processor(self) -> Processor:
-        pass
+    def __new__(cls) -> "DeviceMeshManager":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self) -> None:
+        if self._initialized:
+            return
+
+        self.device_mesh: Optional[DeviceMesh] = None
+        self._initialized = True
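Usage sketch (not part of this diff): `__new__` caches the instance and `__init__` returns early once initialized, so every caller shares one mesh handle. The import path below is an assumption; the class's final module is not shown in this hunk.

from llamafactory.v1.accelerator import DeviceMeshManager  # assumed import path

manager_a = DeviceMeshManager()
manager_b = DeviceMeshManager()
assert manager_a is manager_b            # one shared singleton instance
assert manager_a.device_mesh is None     # mesh is attached later, e.g. during trainer init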
src/llamafactory/v1/accelerator/helper.py (new file, +52)
@@ -0,0 +1,52 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import lru_cache
+
+import torch
+
+
+def get_current_accelerator(check_available: bool = True):
+    """Get current accelerator.
+
+    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
+    """
+    if not hasattr(torch, "accelerator"):
+        raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
+
+    accelerator = torch.accelerator.current_accelerator(check_available=check_available)
+    if accelerator is None:
+        return torch.device("cpu")
+
+    return accelerator
+
+
+@lru_cache
+def is_torch_npu_available():
+    return get_current_accelerator().type == "npu"
+
+
+@lru_cache
+def is_torch_cuda_available():
+    return get_current_accelerator().type == "cuda"
+
+
+@lru_cache
+def is_torch_xpu_available():
+    return get_current_accelerator().type == "xpu"
+
+
+@lru_cache
+def is_torch_mps_available():
+    return get_current_accelerator().type == "mps"
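Usage sketch (not part of this diff), based on the signatures above; note the hard requirement on torch>=2.7.0:

import torch

from llamafactory.v1.accelerator.helper import get_current_accelerator, is_torch_cuda_available

device = get_current_accelerator()       # e.g. device(type='cuda'), or device(type='cpu') as fallback
if is_torch_cuda_available():            # lru_cache'd, so the accelerator probe runs only once
    x = torch.ones(2, 2, device=device)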
src/llamafactory/v1/accelerator/profiler.py (new, empty file)
@@ -23,10 +23,6 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to the dataset."},
     )
-    dataset_dir: str = field(
-        default="data",
-        metadata={"help": "Path to the folder containing the datasets."},
-    )
     cutoff_len: int = field(
         default=2048,
         metadata={"help": "Cutoff length for the dataset."},
@@ -25,3 +25,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Trust remote code from Hugging Face."},
     )
+    use_fast_processor: bool = field(
+        default=True,
+        metadata={"help": "Use fast processor from Hugging Face."},
+    )
+    auto_model_class: str = field(
+        default="causallm",
+        metadata={"help": "Model class from Hugging Face."},
+    )
@@ -14,10 +14,20 @@


 from dataclasses import dataclass, field
+from enum import Enum
+
+
+class SampleBackend(Enum):
+    HF = "hf"
+    VLLM = "vllm"


 @dataclass
 class SampleArguments:
+    sample_backend: SampleBackend = field(
+        default=SampleBackend.HF,
+        metadata={"help": "Sampling backend, default to 'hf'."},
+    )
     max_new_tokens: int = field(
         default=128,
         metadata={"help": "Maximum number of new tokens to generate."},
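With `sample_backend` typed as an enum rather than a bare string, an invalid backend fails at construction time. A sketch, assuming the module path implied by the `..config.sample_args` imports later in this diff:

from llamafactory.v1.config.sample_args import SampleArguments, SampleBackend

args = SampleArguments()                       # defaults: SampleBackend.HF, 128 new tokens
assert args.sample_backend is SampleBackend.HF
vllm_args = SampleArguments(sample_backend=SampleBackend.VLLM, max_new_tokens=256)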
@@ -12,44 +12,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any
+"""The definition of trainer.
+
+Init Phase:
+
+1. Init dataloader.
+2. Init model worker.
+3. Init optimizer (deepspeed).
+4. Shard model.
+5. Init optimizer (fsdp).
+6. Init scheduler.
+
+Train Phase:
+1. Train Loop
+
+"""

 from ..config.training_args import TrainingArguments
-from ..extras.types import Model, Processor, Tensor, TorchDataset
-
-
-class DataCollator:
-    """Default Data collator."""
-
-    def __init__(self, processor: Processor) -> None:
-        self.processor = processor
-
-    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
-        """Collate features into a batch."""
-        for feature in features:
-            pass
-
-        # sft: messages
-        # dpo: chosen_messages, rejected_messages
+from ..extras.types import TorchDataset
+from .model_worker import ModelWorker
+from .trainer_utils.data_collator import DataCollator


 class BaseTrainer:
     def __init__(
         self,
         args: TrainingArguments,
-        model: Model,
-        processor: Processor,
         dataset: TorchDataset,
         data_collator: DataCollator,
+        model_worker: ModelWorker,
     ) -> None:
         self.args = args
-        self.model = model
-        self.processor = processor
         self.dataset = dataset
         self.data_collator = data_collator
+        self.model_worker = model_worker
         self.optimizer = None
         self.lr_scheduler = None

     def init_device_mesh(self) -> None:
         pass

     def init_model_and_optimizer(self) -> None:
-        pass
+        self.model_config = self.model_worker.get_model_config()
+        # with self.dist_plugin.get_model_init_context():
+        #     self.model = self.model_worker.get_model(self.model_config)

     def create_dataloader(self) -> None:
         pass
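The docstring encodes a real ordering constraint: DeepSpeed builds its optimizer before the model is sharded (steps 3-4), while FSDP builds it after sharding (steps 4-5). A hypothetical driver, assuming the constructor reconstructed above:

worker = ModelWorker(model_args)
trainer = BaseTrainer(
    args=training_args,
    dataset=data_engine,
    data_collator=DataCollator(worker.get_processor()),
    model_worker=worker,
)
trainer.init_device_mesh()
trainer.init_model_and_optimizer()   # currently only fetches the model config
trainer.create_dataloader()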
@@ -12,9 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ..config.sample_args import SampleArguments
+from abc import ABC, abstractmethod
+
+from ..config.sample_args import SampleArguments, SampleBackend
+from .model_worker import ModelWorker
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
+
+    @abstractmethod
+    async def generate(self):
+        pass
+
+    @abstractmethod
+    async def batch_infer(self):
+        pass
+
+
+class HuggingFaceEngine(BaseEngine):
+    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+        self.model = model_worker.get_model()
+        self.processor = model_worker.get_processor()
+        self.args = sample_args


 class ChatSampler:
-    def __init__(self, sample_args: SampleArguments) -> None:
-        self.args = sample_args
+    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+        if sample_args.sample_backend == SampleBackend.HF:
+            self.engine = HuggingFaceEngine(model_worker, sample_args)
+        else:
+            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
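Usage sketch (not part of this diff). The module names for `ChatSampler` and `ModelWorker` are assumed from the relative imports above, `ModelArguments` is assumed to expose the `model` field that model_worker.py below reads, and the model id echoes the `apply_kernel` docstring later in this diff:

from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.config.sample_args import SampleArguments
from llamafactory.v1.core.model_worker import ModelWorker
from llamafactory.v1.core.sampler import ChatSampler  # assumed module name

worker = ModelWorker(ModelArguments(model="qwen/qwen2.5-0.5B"))
sampler = ChatSampler(worker, SampleArguments())  # SampleBackend.HF selects HuggingFaceEngine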
@@ -12,11 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""The definition of data engine.
+
+Init Data engine:
+1. Parse dataset info from arguments.
+2. Load datasets according to dataset info.
+3. Build data index (and reweight samples if necessary).
+
+Get Data Sample:
+1. Get sample from data index.
+2. Convert sample to standard format.
+3. Return sample.
+"""
+
 import os
 from collections.abc import AsyncIterable, Iterable
 from typing import Any, Union

-from datasets import load_dataset
 from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
 from torch.utils.data import Dataset
@@ -45,15 +57,13 @@ class DataEngine(Dataset):

     def get_dataset_info(self) -> None:
         """Get dataset info from data arguments."""
-        if self.args.dataset.endswith(".yaml") and os.path.isfile(
-            os.path.join(self.args.dataset_dir, self.args.dataset)
-        ):  # local file
-            self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
+        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
+            self.dataset_infos = OmegaConf.load(self.args.dataset)
         elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
             repo_id, filename = os.path.split(self.args.dataset)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_infos = OmegaConf.load(filepath)
-        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
+        elif os.path.exists(self.args.dataset):  # local file(s)
             self.dataset_infos = {"default": {"file_name": self.args.dataset}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
             self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
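For illustration, the four branches above as a standalone paraphrase (not the project's code); example values in the comments come from this diff:

import os

def resolve_dataset_info(dataset: str) -> dict:
    """Standalone paraphrase of get_dataset_info above, for illustration only."""
    if dataset.endswith(".yaml") and os.path.isfile(dataset):  # e.g. "data/dataset_info.yaml"
        return {"source": "local yaml", "path": dataset}
    elif dataset.endswith(".yaml"):  # e.g. "llamafactory/v1-sft-demo/dataset_info.yaml"
        repo_id, filename = os.path.split(dataset)
        return {"source": "hub yaml", "repo_id": repo_id, "filename": filename}
    elif os.path.exists(dataset):  # local file(s)
        return {"default": {"file_name": dataset}}
    else:  # hub dataset, e.g. "llamafactory/v1-sft-demo"
        return {"default": {"hf_hub_url": dataset}}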
@@ -65,11 +75,13 @@ class DataEngine(Dataset):
             streaming = value.get("streaming", False)
             self.streaming |= streaming
             if "hf_hub_url" in value:
+                from datasets import load_dataset
+
                 self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
             else:  # data loader plugin
                 from ..plugins.data_plugins.loader import DataLoaderPlugin

-                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
+                self.datasets[key] = DataLoaderPlugin().auto_load_data(value)

     def build_data_index(self) -> None:
         """Build dataset index."""
@@ -145,11 +157,11 @@ class DataEngine(Dataset):
         dataset_name, sample_index = selected_index
         return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

-    def __iter__(self) -> Iterable:
+    def __iter__(self) -> Iterable[Sample]:
         """Get dataset iterator.

         Returns:
-            Iterable: Dataset iterator.
+            Iterable[Sample]: Dataset iterator.
         """
         if self.streaming:
             pass
@@ -159,11 +171,11 @@ class DataEngine(Dataset):

         raise NotImplementedError()

-    async def __aiter__(self) -> AsyncIterable:
+    async def __aiter__(self) -> AsyncIterable[Sample]:
         """Get dataset async iterator.

         Returns:
-            AsyncIterable: Dataset async iterator.
+            AsyncIterable[Sample]: Dataset async iterator.
         """
         if self.streaming:
             pass
src/llamafactory/v1/core/model_worker.py (new file, +98)
@@ -0,0 +1,98 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of model worker.
+
+Init Phase:
+1. Init processor.
+2. Init model config.
+3. Init model.
+4. Init adapter.
+
+"""
+
+from typing import Optional
+
+from transformers import AutoConfig, AutoProcessor
+
+from ..config.model_args import ModelArguments
+from ..extras.types import DistModel, HFConfig, HFModel, Processor
+
+
+class ModelWorker:
+    def __init__(self, model_args: ModelArguments) -> None:
+        self.args = model_args
+        """Model arguments."""
+        self.processor: Optional[Processor] = None
+        """Tokenizer or multi-modal processor."""
+        self.model_config: Optional[HFConfig] = None
+        """Model configuration."""
+        self.unwrapped_model: Optional[HFModel] = None
+        """Unwrapped model."""
+        self.model: Optional[DistModel] = None
+        """Distributed model."""
+        self.init_processor()
+        self.init_model_config()
+        self.init_model()
+        self.init_adapter()
+
+    def init_processor(self) -> None:
+        self.processor = AutoProcessor.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+            use_fast=self.args.use_fast_processor,
+        )
+
+    def init_model_config(self) -> None:
+        self.model_config = AutoConfig.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def init_model(self) -> None:
+        if self.args.auto_model_class == "causallm":
+            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
+
+            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
+                AutoClass = AutoModelForImageTextToText
+            else:
+                AutoClass = AutoModelForCausalLM
+        elif self.args.auto_model_class == "classification":
+            from transformers import AutoModelForTokenClassification
+
+            AutoClass = AutoModelForTokenClassification
+        else:
+            from transformers import AutoModel
+
+            AutoClass = AutoModel
+
+        self.unwrapped_model = AutoClass.from_pretrained(
+            self.args.model,
+            config=self.model_config,
+            dtype="auto",
+            device_map="cpu",
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def init_adapter(self) -> None:
+        pass
+
+    def get_processor(self) -> Processor:
+        return self.processor
+
+    def get_model_config(self) -> HFConfig:
+        return self.model_config
+
+    def get_model(self) -> HFModel:
+        return self.unwrapped_model
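Usage sketch (not part of this diff), assuming `ModelArguments` exposes the `model` field referenced above; the model id echoes the `apply_kernel` docstring below:

from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_worker import ModelWorker

args = ModelArguments(model="qwen/qwen2.5-0.5B")
worker = ModelWorker(args)            # runs init_processor/config/model/adapter in order
model = worker.get_model()            # unwrapped HF model, on CPU with dtype="auto"
processor = worker.get_processor()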
src/llamafactory/v1/core/trainer_utils/__init__.py (new, empty file)
src/llamafactory/v1/core/trainer_utils/callback.py (new, empty file)
src/llamafactory/v1/core/trainer_utils/data_collator.py (new file, +47)
@@ -0,0 +1,47 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any
+
+from ...extras.types import Processor, Tensor, TorchDataset
+
+
+class DataCollator:
+    """Default Data collator."""
+
+    def __init__(self, processor: Processor) -> None:
+        self.processor = processor
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
+        """Collate features into a batch."""
+        for feature in features:
+            pass
+
+        # sft: messages
+        # dpo: chosen_messages, rejected_messages
+
+
+class DataLoader:
+    """Default DataLoader."""
+
+    def __init__(self, dataset: TorchDataset) -> None:
+        self.dataset = dataset
+        # 1. Init stateful dataloader (tokenize)
+        # 2. Add to buffer (2 * max seq len per device)
+        # 3. Yield batch indexes (micro batch * grad acc)
+        # a ) non pack + non dynamic
+        # b ) non pack + dynamic
+        # c ) pack + non dynamic
+        # d ) pack + dynamic
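`__call__` is still a stub. One plausible shape for the sft branch, offered purely as an assumption about where this is headed, not as this commit's code:

import torch

def pad_input_ids(features: list[dict], pad_token_id: int = 0) -> dict:
    # Hypothetical helper: right-pad tokenized features to the longest sequence in the batch.
    max_len = max(len(f["input_ids"]) for f in features)
    padded = [f["input_ids"] + [pad_token_id] * (max_len - len(f["input_ids"])) for f in features]
    return {"input_ids": torch.tensor(padded, dtype=torch.long)}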
@@ -28,18 +28,24 @@ if TYPE_CHECKING:
     HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
     DataCollator = transformers.DataCollator
     DataLoader = torch.utils.data.DataLoader
+    HFConfig = transformers.PretrainedConfig
+    HFModel = transformers.PreTrainedModel
+    DistModel = torch.nn.parallel.DistributedDataParallel
     Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
     Optimizer = torch.optim.Optimizer
     Scheduler = torch.optim.lr_scheduler.LRScheduler
 else:
     Tensor = None
     TorchDataset = None
     HFDataset = None
     DataCollator = None
     DataLoader = None
+    HFConfig = None
+    HFModel = None
+    DistModel = None
     Processor = None
     Optimizer = None
     Scheduler = None


 class DatasetInfo(TypedDict, total=False):
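These aliases exist only for type checkers; at runtime they are None, which keeps the torch/transformers imports out of module import time. String annotations make that safe (sketch, not part of this diff):

from llamafactory.v1.extras.types import HFModel  # None at runtime, a real type for checkers

def count_params(model: "HFModel") -> int:  # string annotation is never evaluated at runtime
    return sum(p.numel() for p in model.parameters())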
@@ -86,10 +92,3 @@ class DPOSample(TypedDict):


 Sample = Union[SFTSample, DPOSample]
-
-
-class Model(TypedDict):
-    hf_model: HFModel
-    """HF model."""
-    dist_model: DistModel
-    """Distributed model."""
@@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union

 from datasets import load_dataset

-from ...config.data_args import DataArguments
 from ...extras.types import DatasetInfo, HFDataset

@@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset
 class DataLoaderPlugin:
     """Plugin for loading dataset."""

-    args: DataArguments
-    """Data arguments."""
-
     def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
         """Get dataset builder name.

@@ -42,7 +38,7 @@ class DataLoaderPlugin:
         return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")

     def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
+        dataset_dir = dataset_info.get("dataset_dir", ".")
         split = dataset_info.get("split", "train")
         streaming = dataset_info.get("streaming", False)
         if "file_name" in dataset_info:
@@ -18,9 +18,9 @@ import torch
 import torch.nn.functional as F
 import torch_npu

+from .....accelerator.helper import is_torch_npu_available
 from .....extras.packages import is_transformers_version_greater_than
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaMoEKernel
@@ -17,8 +17,8 @@ import types

 import torch

+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaSwiGluKernel
@@ -15,8 +15,8 @@
 from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional

+from ....accelerator.helper import get_current_accelerator
 from ....extras.types import HFModel
-from ...trainer_plugins.distributed.accelerate import get_available_accelerator
 from .constants import DeviceType, KernelType
@@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     discovered_kernels: list[type[MetaKernel]] = []

     # Detect current device type
-    accelerator = get_available_accelerator()
+    accelerator = get_current_accelerator()
     try:
         device_type = DeviceType(accelerator.type)
     except ValueError:
@@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
         model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
         model = apply_kernel(model, NpuRMSNormKernel)
     """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
+    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
         return kernel.apply(model, **kwargs)

     raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
     )
@@ -14,8 +14,8 @@
 import re
 import types

+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRMSNormKernel
||||
@@ -16,8 +16,8 @@ import sys
|
||||
|
||||
import torch
|
||||
|
||||
from .....accelerator.helper import is_torch_npu_available
|
||||
from .....extras.types import HFModel
|
||||
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
|
||||
from ..constants import DeviceType, KernelType
|
||||
from ..registry import MetaRoPEKernel
|
||||
|
||||
|
||||
@@ -11,37 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import lru_cache
-
-import torch
-
-
-def get_available_accelerator():
-    """Get available accelerator in current environment.
-
-    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
-    """
-    accelerator = torch.accelerator.current_accelerator()
-    if accelerator is None:
-        return torch.device("cpu")
-    return accelerator
-
-
-@lru_cache
-def is_torch_npu_available():
-    return get_available_accelerator().type == "npu"
-
-
-@lru_cache
-def is_torch_cuda_available():
-    return get_available_accelerator().type == "cuda"
-
-
-@lru_cache
-def is_torch_xpu_available():
-    return get_available_accelerator().type == "xpu"
-
-
-@lru_cache
-def is_torch_mps_available():
-    return get_available_accelerator().type == "mps"