[v1] add models & accelerator (#9579)

Yaowei Zheng
2025-12-08 02:30:25 +08:00
committed by GitHub
parent 739954910a
commit 5744f1ea94
27 changed files with 335 additions and 105 deletions

.gitignore vendored (+3 lines)
View File

@@ -165,6 +165,9 @@ cython_debug/
 # uv
 uv.lock
+
+# macOS
+.DS_Store
 
 # custom .gitignore
 hf_cache/
 ms_cache/

View File

@@ -1,8 +1,8 @@
 identity:
-  file_name: identity.json
+  file_name: data/identity.json
   converter: alpaca
 alpaca_en_demo:
   file_name: alpaca_en_demo.json
-  dataset_dir: ~/data
+  dataset_dir: data
   converter: alpaca
   num_samples: 500

View File

@@ -15,8 +15,8 @@
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union
 
-from packaging import version
+from packaging import version
 from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer

View File

@@ -465,6 +465,7 @@ class BasePlugin(MMPluginMixin):
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
 
+@dataclass
 class ErnieVLPlugin(BasePlugin):
     @override

View File

@@ -12,16 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..config.model_args import ModelArguments
-from ..extras.types import Model, Processor
-
-
-class ModelEngine:
-    def __init__(self, model_args: ModelArguments) -> None:
-        self.args = model_args
-
-    def get_model(self) -> Model:
-        pass
-
-    def get_processor(self) -> Processor:
-        pass
+from typing import Optional
+
+from torch.distributed.device_mesh import DeviceMesh
+
+
+class DeviceMeshManager:
+    """Device mesh manager."""
+
+    _instance: Optional["DeviceMeshManager"] = None
+    _initialized: bool = False
+
+    def __new__(cls) -> "DeviceMeshManager":
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+
+        return cls._instance
+
+    def __init__(self) -> None:
+        if self._initialized:
+            return
+
+        self.device_mesh: Optional[DeviceMesh] = None
+        self._initialized = True
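A quick sketch of the singleton contract, using the class exactly as added above: both constructions return the same object, and the `_initialized` guard keeps `__init__` from resetting state that was already set.

```python
# Sketch: exercising the singleton contract of DeviceMeshManager as defined above.
manager_a = DeviceMeshManager()
manager_b = DeviceMeshManager()
assert manager_a is manager_b  # __new__ always hands back the cached instance

# __init__ runs its body only once, so state set through one reference
# is visible through every other reference.
manager_a.device_mesh = None  # would hold a DeviceMesh once initialized
assert manager_b.device_mesh is manager_a.device_mesh
```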

View File

@@ -0,0 +1,52 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import lru_cache
+
+import torch
+
+
+def get_current_accelerator(check_available: bool = True):
+    """Get current accelerator.
+
+    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
+    """
+    if not hasattr(torch, "accelerator"):
+        raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
+
+    accelerator = torch.accelerator.current_accelerator(check_available=check_available)
+    if accelerator is None:
+        return torch.device("cpu")
+
+    return accelerator
+
+
+@lru_cache
+def is_torch_npu_available():
+    return get_current_accelerator().type == "npu"
+
+
+@lru_cache
+def is_torch_cuda_available():
+    return get_current_accelerator().type == "cuda"
+
+
+@lru_cache
+def is_torch_xpu_available():
+    return get_current_accelerator().type == "xpu"
+
+
+@lru_cache
+def is_torch_mps_available():
+    return get_current_accelerator().type == "mps"
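A usage sketch of the new helpers. Unlike the old `get_available_accelerator`, this version checks for `torch.accelerator` explicitly and raises a clear error on torch < 2.7.0; note also that the `is_torch_*_available` helpers are `lru_cache`d, so the device type is resolved once per process.

```python
import torch

device = get_current_accelerator()  # e.g. device(type="cuda"), or device(type="cpu") as fallback

if is_torch_cuda_available():
    x = torch.ones(2, 2, device=device)
else:
    x = torch.ones(2, 2)  # CPU path

print(device.type, x.device)
```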

View File

@@ -23,10 +23,6 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to the dataset."},
     )
-    dataset_dir: str = field(
-        default="data",
-        metadata={"help": "Path to the folder containing the datasets."},
-    )
     cutoff_len: int = field(
         default=2048,
         metadata={"help": "Cutoff length for the dataset."},

View File

@@ -25,3 +25,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Trust remote code from Hugging Face."},
     )
+    use_fast_processor: bool = field(
+        default=True,
+        metadata={"help": "Use fast processor from Hugging Face."},
+    )
+    auto_model_class: str = field(
+        default="causallm",
+        metadata={"help": "Model class from Hugging Face."},
+    )
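Both new fields are consumed by the `ModelWorker` added in this commit: `use_fast_processor` is forwarded as `use_fast=` to `AutoProcessor.from_pretrained`, and `auto_model_class` picks the auto class in `init_model`. A minimal sketch (the model id is illustrative, and it assumes the dataclass's existing `model` field):

```python
# Sketch: "causallm" and "classification" are special-cased in
# ModelWorker.init_model(); any other string falls back to AutoModel.
args = ModelArguments(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # example model id
    trust_remote_code=False,
    use_fast_processor=True,             # forwarded as use_fast=...
    auto_model_class="causallm",
)
```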

View File

@@ -14,10 +14,20 @@
 from dataclasses import dataclass, field
+from enum import Enum
+
+
+class SampleBackend(Enum):
+    HF = "hf"
+    VLLM = "vllm"
 
 
 @dataclass
 class SampleArguments:
+    sample_backend: SampleBackend = field(
+        default=SampleBackend.HF,
+        metadata={"help": "Sampling backend, default to 'hf'."},
+    )
     max_new_tokens: int = field(
         default=128,
         metadata={"help": "Maximum number of new tokens to generate."},

View File

@@ -12,44 +12,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+"""The definition of trainer.
+
+Init Phase:
+    1. Init dataloader.
+    2. Init model worker.
+    3. Init optimizer (deepspeed).
+    4. Shard model.
+    5. Init optimizer (fsdp).
+    6. Init scheduler.
+
+Train Phase:
+    1. Train Loop
+"""
 
 from ..config.training_args import TrainingArguments
-from ..extras.types import Model, Processor, Tensor, TorchDataset
-
-
-class DataCollator:
-    """Default Data collator."""
-
-    def __init__(self, processor: Processor) -> None:
-        self.processor = processor
-
-    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
-        """Collate features into a batch."""
-        for feature in features:
-            pass
-
-        # sft: messages
-        # dpo: chosen_messages, rejected_messages
+from ..extras.types import TorchDataset
+from .model_worker import ModelWorker
+from .trainer_utils.data_collator import DataCollator
 
 
 class BaseTrainer:
     def __init__(
         self,
         args: TrainingArguments,
-        model: Model,
-        processor: Processor,
         dataset: TorchDataset,
         data_collator: DataCollator,
+        model_worker: ModelWorker,
     ) -> None:
         self.args = args
-        self.model = model
-        self.processor = processor
         self.dataset = dataset
         self.data_collator = data_collator
+        self.model_worker = model_worker
+        self.optimizer = None
+        self.lr_scheduler = None
+
+    def init_device_mesh(self) -> None:
+        pass
+
+    def init_model_and_optimizer(self) -> None:
+        self.model_config = self.model_worker.get_model_config()
+        # with self.dist_plugin.get_model_init_context():
+        #     self.model = self.model_worker.get_model(self.model_config)
 
     def create_dataloader(self) -> None:
         pass
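A sketch of how a concrete trainer might follow the phases in the module docstring; the `fit` method name is hypothetical, and only the `init_*`/`create_*` hooks come from this diff:

```python
# Hypothetical subclass illustrating the intended call order; fit() is not
# part of this commit.
class SFTTrainer(BaseTrainer):
    def fit(self) -> None:
        self.init_device_mesh()          # set up the distributed topology
        self.create_dataloader()         # build the dataloader over self.dataset
        self.init_model_and_optimizer()  # pull config (and later the model) from model_worker
        # ... train loop would go here
```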

View File

@@ -12,9 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..config.sample_args import SampleArguments
+from abc import ABC, abstractmethod
+
+from ..config.sample_args import SampleArguments, SampleBackend
+from .model_worker import ModelWorker
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
+
+    @abstractmethod
+    async def generate(self):
+        pass
+
+    @abstractmethod
+    async def batch_infer(self):
+        pass
+
+
+class HuggingFaceEngine(BaseEngine):
+    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+        self.model = model_worker.get_model()
+        self.processor = model_worker.get_processor()
+        self.args = sample_args
 
 
 class ChatSampler:
-    def __init__(self, sample_args: SampleArguments) -> None:
-        self.args = sample_args
+    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+        if sample_args.sample_backend == SampleBackend.HF:
+            self.engine = HuggingFaceEngine(model_worker, sample_args)
+        else:
+            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -12,11 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""The definition of data engine.
+
+Init Data engine:
+    1. Parse dataset info from arguments.
+    2. Load datasets according to dataset info.
+    3. Build data index (and reweight samples if necessary).
+
+Get Data Sample:
+    1. Get sample from data index.
+    2. Convert sample to standard format.
+    3. Return sample.
+"""
+
 import os
 from collections.abc import AsyncIterable, Iterable
 from typing import Any, Union
 
-from datasets import load_dataset
 from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
 from torch.utils.data import Dataset
@@ -45,15 +57,13 @@ class DataEngine(Dataset):
     def get_dataset_info(self) -> None:
         """Get dataset info from data arguments."""
-        if self.args.dataset.endswith(".yaml") and os.path.isfile(
-            os.path.join(self.args.dataset_dir, self.args.dataset)
-        ):  # local file
-            self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
+        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
+            self.dataset_infos = OmegaConf.load(self.args.dataset)
         elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
             repo_id, filename = os.path.split(self.args.dataset)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_infos = OmegaConf.load(filepath)
-        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
+        elif os.path.exists(self.args.dataset):  # local file(s)
             self.dataset_infos = {"default": {"file_name": self.args.dataset}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
             self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
@@ -65,11 +75,13 @@ class DataEngine(Dataset):
             streaming = value.get("streaming", False)
             self.streaming |= streaming
             if "hf_hub_url" in value:
+                from datasets import load_dataset
+
                 self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
             else:  # data loader plugin
                 from ..plugins.data_plugins.loader import DataLoaderPlugin
 
-                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
+                self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
def build_data_index(self) -> None:
"""Build dataset index."""
@@ -145,11 +157,11 @@ class DataEngine(Dataset):
         dataset_name, sample_index = selected_index
         return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
 
-    def __iter__(self) -> Iterable:
+    def __iter__(self) -> Iterable[Sample]:
         """Get dataset iterator.
 
         Returns:
-            Iterable: Dataset iterator.
+            Iterable[Sample]: Dataset iterator.
         """
         if self.streaming:
             pass
@@ -159,11 +171,11 @@ class DataEngine(Dataset):
             raise NotImplementedError()
 
-    async def __aiter__(self) -> AsyncIterable:
+    async def __aiter__(self) -> AsyncIterable[Sample]:
         """Get dataset async iterator.
 
         Returns:
-            AsyncIterable: Dataset async iterator.
+            AsyncIterable[Sample]: Dataset async iterator.
         """
         if self.streaming:
             pass
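With `dataset_dir` gone, `--dataset` is now interpreted as a self-contained path or hub URI. The four branches of `get_dataset_info`, sketched with example values:

```python
# How get_dataset_info() now resolves args.dataset (example values, not fixtures):
#
# "data/dataset_info.yaml"                     -> existing local YAML, loaded via OmegaConf
# "llamafactory/v1-sft-demo/dataset_info.yaml" -> split into repo_id/filename and
#                                                 fetched with hf_hub_download
# "data/alpaca_en_demo.json"                   -> {"default": {"file_name": "data/alpaca_en_demo.json"}}
# "llamafactory/v1-sft-demo"                   -> {"default": {"hf_hub_url": "llamafactory/v1-sft-demo"}}
```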

View File

@@ -0,0 +1,98 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of model worker.
+
+Init Phase:
+    1. Init processor.
+    2. Init model config.
+    3. Init model.
+    4. Init adapter.
+"""
+
+from typing import Optional
+
+from transformers import AutoConfig, AutoProcessor
+
+from ..config.model_args import ModelArguments
+from ..extras.types import DistModel, HFConfig, HFModel, Processor
+
+
+class ModelWorker:
+    def __init__(self, model_args: ModelArguments) -> None:
+        self.args = model_args
+        """Model arguments."""
+        self.processor: Optional[Processor] = None
+        """Tokenizer or multi-modal processor."""
+        self.model_config: Optional[HFConfig] = None
+        """Model configuration."""
+        self.unwrapped_model: Optional[HFModel] = None
+        """Unwrapped model."""
+        self.model: Optional[DistModel] = None
+        """Distributed model."""
+        self.init_processor()
+        self.init_model_config()
+        self.init_model()
+        self.init_adapter()
+
+    def init_processor(self) -> None:
+        self.processor = AutoProcessor.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+            use_fast=self.args.use_fast_processor,
+        )
+
+    def init_model_config(self) -> None:
+        self.model_config = AutoConfig.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def init_model(self) -> None:
+        if self.args.auto_model_class == "causallm":
+            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
+
+            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
+                AutoClass = AutoModelForImageTextToText
+            else:
+                AutoClass = AutoModelForCausalLM
+        elif self.args.auto_model_class == "classification":
+            from transformers import AutoModelForTokenClassification
+
+            AutoClass = AutoModelForTokenClassification
+        else:
+            from transformers import AutoModel
+
+            AutoClass = AutoModel
+
+        self.unwrapped_model = AutoClass.from_pretrained(
+            self.args.model,
+            config=self.model_config,
+            dtype="auto",
+            device_map="cpu",
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def init_adapter(self) -> None:
+        pass
+
+    def get_processor(self) -> Processor:
+        return self.processor
+
+    def get_model_config(self) -> HFConfig:
+        return self.model_config
+
+    def get_model(self) -> HFModel:
+        return self.unwrapped_model
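Constructing a worker runs all four init phases eagerly, and the getters hand the pieces to the trainer or sampling engine. Loading with `device_map="cpu"` and `dtype="auto"` leaves device placement to the (upcoming) device-mesh logic rather than `from_pretrained`. A usage sketch (model id illustrative):

```python
worker = ModelWorker(ModelArguments(model="Qwen/Qwen2.5-0.5B-Instruct"))  # example id

processor = worker.get_processor()  # tokenizer or multimodal processor
config = worker.get_model_config()  # transformers config object
model = worker.get_model()          # unwrapped model on CPU; sharding happens later
```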

View File

@@ -0,0 +1,47 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from ...extras.types import Processor, Tensor, TorchDataset
+
+
+class DataCollator:
+    """Default Data collator."""
+
+    def __init__(self, processor: Processor) -> None:
+        self.processor = processor
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
+        """Collate features into a batch."""
+        for feature in features:
+            pass
+
+        # sft: messages
+        # dpo: chosen_messages, rejected_messages
+
+
+class DataLoader:
+    """Default DataLoader."""
+
+    def __init__(self, dataset: TorchDataset) -> None:
+        self.dataset = dataset
+
+    # 1. Init stateful dataloader (tokenize)
+    # 2. Add to buffer (2 * max seq len per device)
+    # 3. Yield batch indexes (micro batch * grad acc)
+    #     a) non pack + non dynamic
+    #     b) non pack + dynamic
+    #     c) pack + non dynamic
+    #     d) pack + dynamic
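The collator body is still a stub (the loop only walks the features), but the call shape is fixed by the type hints. A hedged sketch of what a filled-in SFT path could look like, assuming a chat-capable tokenizer; this is not the committed behavior:

```python
# Hypothetical SFT collation, not part of this commit.
def collate_sft(processor, features):
    texts = [processor.apply_chat_template(f["messages"], tokenize=False) for f in features]
    return processor(texts, padding=True, return_tensors="pt")  # dict[str, Tensor]
```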

View File

@@ -28,18 +28,24 @@ if TYPE_CHECKING:
     HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
     DataCollator = transformers.DataCollator
     DataLoader = torch.utils.data.DataLoader
+    HFConfig = transformers.PretrainedConfig
+    HFModel = transformers.PreTrainedModel
+    DistModel = torch.nn.parallel.DistributedDataParallel
     Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
     Optimizer = torch.optim.Optimizer
     Scheduler = torch.optim.lr_scheduler.LRScheduler
 else:
     Tensor = None
     TorchDataset = None
     HFDataset = None
     DataCollator = None
     DataLoader = None
+    HFConfig = None
+    HFModel = None
+    DistModel = None
     Processor = None
     Optimizer = None
     Scheduler = None
 
 
 class DatasetInfo(TypedDict, total=False):
@@ -86,10 +92,3 @@ class DPOSample(TypedDict):
 Sample = Union[SFTSample, DPOSample]
 
 
-class Model(TypedDict):
-    hf_model: HFModel
-    """HF model."""
-    dist_model: DistModel
-    """Distributed model."""

View File

@@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union
 from datasets import load_dataset
 
-from ...config.data_args import DataArguments
 from ...extras.types import DatasetInfo, HFDataset
@@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset
 class DataLoaderPlugin:
     """Plugin for loading dataset."""
 
-    args: DataArguments
-    """Data arguments."""
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
"""Get dataset builder name.
@@ -42,7 +38,7 @@ class DataLoaderPlugin:
         return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
 
     def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
+        dataset_dir = dataset_info.get("dataset_dir", ".")
         split = dataset_info.get("split", "train")
         streaming = dataset_info.get("streaming", False)
         if "file_name" in dataset_info:

View File

@@ -18,9 +18,9 @@ import torch
 import torch.nn.functional as F
 import torch_npu
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.packages import is_transformers_version_greater_than
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaMoEKernel

View File

@@ -17,8 +17,8 @@ import types
 import torch
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaSwiGluKernel

View File

@@ -15,8 +15,8 @@
 from abc import ABC, ABCMeta, abstractmethod
 from typing import Any, Callable, Optional
 
+from ....accelerator.helper import get_current_accelerator
 from ....extras.types import HFModel
-from ...trainer_plugins.distributed.accelerate import get_available_accelerator
 from .constants import DeviceType, KernelType
@@ -206,7 +206,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     discovered_kernels: list[type[MetaKernel]] = []
 
     # Detect current device type
-    accelerator = get_available_accelerator()
+    accelerator = get_current_accelerator()
     try:
         device_type = DeviceType(accelerator.type)
     except ValueError:
@@ -238,11 +238,11 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
         model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
         model = apply_kernel(model, NpuRMSNormKernel)
     """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
+    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
         return kernel.apply(model, **kwargs)
 
     raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
     )
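Together the two functions give the discover-then-apply pattern: `discover_kernels` filters the registry by the running device's `DeviceType`, and `apply_kernel` re-checks the device before patching. A sketch reusing the model id from the docstring above:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")  # id from the docstring

# Only kernels whose .device matches get_current_accelerator().type are returned,
# so apply_kernel's device check passes for each of them.
for kernel in discover_kernels(model):
    model = apply_kernel(model, kernel)
```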

View File

@@ -14,8 +14,8 @@
 import re
 import types
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRMSNormKernel

View File

@@ -16,8 +16,8 @@ import sys
 import torch
 
+from .....accelerator.helper import is_torch_npu_available
 from .....extras.types import HFModel
-from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
 from ..constants import DeviceType, KernelType
 from ..registry import MetaRoPEKernel

View File

@@ -11,37 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from functools import lru_cache
-
-import torch
-
-
-def get_available_accelerator():
-    """Get available accelerator in current environment.
-
-    Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
-    """
-    accelerator = torch.accelerator.current_accelerator()
-    if accelerator is None:
-        return torch.device("cpu")
-
-    return accelerator
-
-
-@lru_cache
-def is_torch_npu_available():
-    return get_available_accelerator().type == "npu"
-
-
-@lru_cache
-def is_torch_cuda_available():
-    return get_available_accelerator().type == "cuda"
-
-
-@lru_cache
-def is_torch_xpu_available():
-    return get_available_accelerator().type == "xpu"
-
-
-@lru_cache
-def is_torch_mps_available():
-    return get_available_accelerator().type == "mps"