[v1] add accelerator (#9607)

Author: Yaowei Zheng
Date: 2025-12-12 19:22:06 +08:00
Committed by: GitHub
Parent: 4fd94141a4
Commit: 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -1,4 +1,4 @@
dpo_zh_demo:
-  hf_hub_url: HuggingFaceH4/orca_dpo_pairs
  path: HuggingFaceH4/orca_dpo_pairs
  split: train_prefs
  converter: pair

View File

@@ -1,8 +1,9 @@
identity:
-  file_name: data/identity.json
  path: data/identity.json
  source: local
  converter: alpaca
alpaca_en_demo:
-  file_name: alpaca_en_demo.json
-  dataset_dir: data
  path: data/alpaca_en_demo.json
  source: local
  converter: alpaca
-  num_samples: 500
  size: 500

View File

@@ -1,4 +1,7 @@
-# Copyright 2025 the LlamaFactory team.
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +15,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import Tensor, TensorLike
-def get_current_accelerator(check_available: bool = True):
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
@unique
class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"
META = "meta"
MPS = "mps"
NPU = "npu"
XPU = "xpu"
@unique
class ReduceOp(str, Enum):
SUM = "sum"
MEAN = "mean"
MAX = "max"
MIN = "min"
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
def get_rank() -> int:
"""Get rank."""
return int(os.getenv("RANK", "0"))
def get_local_rank() -> int:
"""Get local rank."""
return int(os.getenv("LOCAL_RANK", "0"))
def get_world_size() -> int:
"""Get world size."""
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_world_size() -> int:
"""Get local world size."""
return int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@lru_cache
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator. """Get current accelerator.
Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError Note: this api requires torch>=2.7.0, 2.6 or lower will get an AttributeError or RuntimeError
@@ -27,26 +86,78 @@ def get_current_accelerator(check_available: bool = True):
    accelerator = torch.accelerator.current_accelerator(check_available=check_available)
    if accelerator is None:
-        return torch.device("cpu")
        return torch.device(DeviceType.CPU.value)
    return accelerator
-@lru_cache
-def is_torch_npu_available():
-    return get_current_accelerator().type == "npu"
-@lru_cache
def is_torch_cuda_available():
-    return get_current_accelerator().type == "cuda"
    return get_current_accelerator().type == DeviceType.CUDA
-@lru_cache
-def is_torch_xpu_available():
-    return get_current_accelerator().type == "xpu"
-@lru_cache
def is_torch_mps_available():
-    return get_current_accelerator().type == "mps"
    return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
return get_current_accelerator().type == DeviceType.XPU
def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size()[1:])
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)
is_tensor = isinstance(data, torch.Tensor)
if is_ndarray:
data = torch.from_numpy(data)
elif not is_tensor:
data = torch.tensor(data, dtype=torch.float, device=device)
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(data, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
data /= dist.get_world_size(group=group)
if is_tensor:
return data
elif is_ndarray:
return data.numpy()
elif data.numel() == 1:
return data.item()
else:
return data.tolist()
@contextmanager
def main_process_first(local_only: bool = True) -> None:
"""A context manager for torch distributed environment to do something on the main process firstly."""
if get_world_size() > 1:
is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
try:
if not is_main_process:
dist.barrier()
yield
finally:
if is_main_process:
dist.barrier()
else:
yield
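The helpers above replace ad-hoc torch.distributed calls with rank discovery from environment variables. A minimal usage sketch, assuming a torchrun launch so a process group can be initialized (the import path follows the test files later in this commit):

import torch
import torch.distributed as dist
from llamafactory.v1.accelerator.helper import ReduceOp, all_reduce, get_rank, is_distributed, main_process_first

if is_distributed():
    dist.init_process_group()  # default env:// initialization

with main_process_first():
    pass  # e.g. download or preprocess the dataset once per node

local_loss = torch.tensor([0.5 + 0.1 * get_rank()])
mean_loss = all_reduce(local_loss, op=ReduceOp.MEAN)  # averaged across all ranks
print(f"rank={get_rank()} mean_loss={mean_loss}")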

View File

@@ -0,0 +1,123 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Optional
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import TensorLike
from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed
@dataclass
class DistributedStrategy:
"""Distributed strategy."""
dp_size: Optional[int] = None
tp_size: int = 1
def __post_init__(self) -> None:
if not is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = get_world_size() // self.tp_size
elif self.dp_size * self.tp_size != get_world_size():
raise ValueError(
f"dp_size * tp_size must equal to world_size, "
f"got {self.dp_size} * {self.tp_size} != {get_world_size()}."
)
@property
def mesh_shape(self) -> tuple[int, int]:
"""Mesh shape."""
return (self.dp_size, self.tp_size)
@property
def mesh_dim_names(self) -> tuple[str, str]:
"""Mesh dimension names."""
return ("dp", "tp")
class DistributedInterface:
"""Distributed interface."""
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
is_distributed = is_distributed()
"""Check if distributed environment is available."""
rank = get_rank()
"""Global rank."""
world_size = get_world_size()
"""Global world size."""
device_mesh: Optional[DeviceMesh] = None
"""Device mesh."""
current_accelerator = get_current_accelerator()
"""Current accelerator."""
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, strategy: DistributedStrategy) -> None:
if self._initialized:
return
self.strategy = strategy
if self.is_distributed:
self.device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=strategy.mesh_shape,
mesh_dim_names=strategy.mesh_dim_names,
)
else:
self.device_mesh = None
self._initialized = True
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, "
f"rank={self.rank}, world_size={self.world_size}, "
f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}"
)
def dp_rank(self) -> int:
"""Data parallel rank."""
if self.device_mesh is None:
return 0
return self.device_mesh["dp"].get_rank()
def dp_size(self) -> int:
"""Data parallel size."""
if self.device_mesh is None:
return 1
return self.device_mesh["dp"].size()
def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike:
"""All reduce tensor."""
if self.device_mesh is None:
return data
return all_reduce(data, op, self.device_mesh["dp"].get_group())
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))
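A sketch of how the singleton above might be inspected from a training script, e.g. under torchrun --nproc_per_node=2 (the sizes are illustrative; init_device_mesh requires a distributed launch):

from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy

interface = DistributedInterface(DistributedStrategy(tp_size=1))  # dp_size defaults to world_size // tp_size
print(interface)  # strategy, rank, world size, device mesh, current accelerator
print(interface.dp_rank(), interface.dp_size())
avg_loss = interface.all_reduce_over_dp([0.1 * interface.rank])  # mean over the "dp" mesh dimension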

View File

@@ -28,6 +28,21 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
training_args: TrainingArguments,
sample_args: SampleArguments,
):
"""Validate arguments."""
if (
model_args.quant_config is not None
and training_args.dist_config is not None
and training_args.dist_config.name == "deepspeed"
):
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(
    args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
@@ -56,6 +71,8 @@ def get_args(
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
validate_args(*parsed_args)
    return tuple(parsed_args)

View File

@@ -0,0 +1,105 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import Enum, unique
from typing import Any, Optional, Union
class PluginConfig(dict):
"""Dictionary that allows attribute access."""
@property
def name(self) -> str:
"""Plugin name."""
if "name" not in self:
raise ValueError("Plugin configuration must have a 'name' field.")
return self["name"]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(f"Attribute {key} not found.")
def __setattr__(self, key: str, value: Any):
self[key] = value
PluginArgument = Optional[Union[PluginConfig, dict, str]]
@unique
class AutoClass(str, Enum):
"""Auto class for model config."""
CAUSALLM = "llm"
CLASSIFICATION = "cls"
OTHER = "other"
@unique
class SampleBackend(str, Enum):
HF = "hf"
VLLM = "vllm"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.
Args:
data: The string or dictionary to convert.
Returns:
The converted dictionary.
"""
for key, value in data.items():
if isinstance(value, dict):
data[key] = _convert_str_dict(value)
elif isinstance(value, str):
if value.lower() in ("true", "false"):
data[key] = value.lower() == "true"
elif value.isdigit():
data[key] = int(value)
elif value.replace(".", "", 1).isdigit():
data[key] = float(value)
return data
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
"""Get the plugin configuration from the argument value.
Args:
config: The argument value to get the plugin configuration from.
Returns:
The plugin configuration.
"""
if config is None:
return None
if isinstance(config, str) and config.startswith("{"):
config = json.loads(config)
config = _convert_str_dict(config)
if "name" not in config:
raise ValueError("Plugin configuration must have a 'name' field.")
return PluginConfig(config)
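A small usage sketch of the parser above; the module path llamafactory.v1.config.arg_utils is assumed from the relative imports elsewhere in this commit:

from llamafactory.v1.config.arg_utils import get_plugin_config

config = get_plugin_config('{"name": "lora", "r": "16", "use_dora": "false"}')
assert config.name == "lora"
assert config.r == 16            # "16" converted to int by _convert_str_dict
assert config.use_dora is False  # "false" converted to bool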

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
@@ -29,7 +31,24 @@ class ModelArguments:
        default=True,
        metadata={"help": "Use fast processor from Hugging Face."},
    )
-    auto_model_class: str = field(
-        default="causallm",
    auto_class: AutoClass = field(
        default=AutoClass.CAUSALLM,
        metadata={"help": "Model class from Hugging Face."},
    )
peft_config: PluginConfig = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)
def __post_init__(self) -> None:
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -14,12 +14,8 @@
from dataclasses import dataclass, field
-from enum import Enum
from .arg_utils import SampleBackend
-class SampleBackend(Enum):
-    HF = "hf"
-    VLLM = "vllm"
@dataclass

View File

@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from .arg_utils import PluginArgument, get_plugin_config
@dataclass
class TrainingArguments:
@@ -38,3 +40,10 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
dist_config: PluginArgument = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -29,7 +29,7 @@ Train Phase:
""" """
from ..config.training_args import TrainingArguments from ..config.training_args import TrainingArguments
from ..extras.types import TorchDataset from ..utils.types import TorchDataset
from .model_worker import ModelWorker from .model_worker import ModelWorker
from .trainer_utils.data_collator import DataCollator from .trainer_utils.data_collator import DataCollator
@@ -49,13 +49,10 @@ class BaseTrainer:
        self.optimizer = None
        self.lr_scheduler = None
-    def init_device_mesh(self) -> None:
-        pass
    def init_model_and_optimizer(self) -> None:
-        self.model_config = self.model_worker.get_model_config()
        self.model_worker.init_model_config()
        # with self.dist_plugin.get_model_init_context():
-        #     self.model = self.model_worker.get_model(self.model_config)
        #     self.model = self.model_worker.init_model(self.model_config)
    def create_dataloader(self) -> None:
        pass

View File

@@ -34,7 +34,7 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
-from ..extras.types import DatasetInfo, HFDataset, Sample
from ..utils.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
@@ -64,9 +64,9 @@ class DataEngine(Dataset):
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            self.dataset_infos = OmegaConf.load(filepath)
        elif os.path.exists(self.args.dataset):  # local file(s)
-            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
            self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
-            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
            self.dataset_infos = {"default": {"path": self.args.dataset}}
    def load_dataset(self) -> None:
        """Load datasets according to dataset info."""
@@ -74,14 +74,14 @@ class DataEngine(Dataset):
split = value.get("split", "train") split = value.get("split", "train")
streaming = value.get("streaming", False) streaming = value.get("streaming", False)
self.streaming |= streaming self.streaming |= streaming
if "hf_hub_url" in value: if value.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset from datasets import load_dataset
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming) self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
else: # data loader plugin else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin().auto_load_data(value) self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
def build_data_index(self) -> None: def build_data_index(self) -> None:
"""Build dataset index.""" """Build dataset index."""
@@ -112,9 +112,9 @@ class DataEngine(Dataset):
""" """
converter = self.dataset_infos[dataset_name].get("converter") converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None: if converter is not None:
from ..plugins.data_plugins.converter import get_converter from ..plugins.data_plugins.converter import DataConverterPlugin
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)} return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
else: else:
return {"_dataset_name": dataset_name, **raw_sample} return {"_dataset_name": dataset_name, **raw_sample}
@@ -147,7 +147,7 @@ class DataEngine(Dataset):
        else:
            from ..plugins.data_plugins.loader import DataSelectorPlugin
-            selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
            selected_index = DataSelectorPlugin().select(self.data_index, index)
        if isinstance(selected_index, list):
            return [
                self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
@@ -187,7 +187,7 @@ class DataEngine(Dataset):
if __name__ == "__main__":
-    from ..config.parser import get_args
    from ..config.arg_parser import get_args
    data_args, *_ = get_args()
    data_engine = DataEngine(data_args=data_args)
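Under the new schema, routing in load_dataset() is driven by the source field rather than by which key is present. A brief illustration (dataset names are taken from the examples above):

hub_info = {"path": "llamafactory/v1-sft-demo", "split": "train"}
# source defaults to "hf_hub" -> datasets.load_dataset("llamafactory/v1-sft-demo", split="train")

local_info = {"path": "data/alpaca_en_demo.json", "source": "local", "converter": "alpaca", "size": 500}
# "local" -> DataLoaderPlugin("local").load(local_info), builder inferred from the ".json" suffix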

View File

@@ -24,10 +24,12 @@ Init Phase:
from typing import Optional
import torch
from transformers import AutoConfig, AutoProcessor
-from ..config.model_args import ModelArguments
-from ..extras.types import DistModel, HFConfig, HFModel, Processor
from ..accelerator.helper import DeviceType
from ..config.model_args import AutoClass, ModelArguments
from ..utils.types import HFConfig, HFModel, Processor
class ModelWorker:
@@ -38,16 +40,15 @@ class ModelWorker:
"""Tokenizer or multi-modal processor.""" """Tokenizer or multi-modal processor."""
self.model_config: Optional[HFConfig] = None self.model_config: Optional[HFConfig] = None
"""Model configuration.""" """Model configuration."""
self.unwrapped_model: Optional[HFModel] = None self.model: Optional[HFModel] = None
"""Unwrapped model.""" """HF model."""
self.model: Optional[DistModel] = None self.is_adapter = False
"""Distributed model.""" """Whether the model has adapter."""
self.init_processor()
self.init_model_config()
self.init_model()
self.init_adapter()
def init_processor(self) -> None: def init_processor(self) -> None:
if self.processor is not None:
return
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
self.args.model, self.args.model,
trust_remote_code=self.args.trust_remote_code, trust_remote_code=self.args.trust_remote_code,
@@ -55,38 +56,58 @@ class ModelWorker:
        )
    def init_model_config(self) -> None:
        if self.model_config is not None:
            return
        self.model_config = AutoConfig.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
        )
    def init_model(self) -> None:
-        if self.args.auto_model_class == "causallm":
        if self.model is not None:
            return
        self.init_model_config()
        if self.args.auto_class == AutoClass.CAUSALLM:
            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
-                AutoClass = AutoModelForImageTextToText
                ModelClass = AutoModelForImageTextToText
            else:
-                AutoClass = AutoModelForCausalLM
                ModelClass = AutoModelForCausalLM
-        elif self.args.auto_model_class == "classification":
        elif self.args.auto_class == AutoClass.CLASSIFICATION:
            from transformers import AutoModelForTokenClassification
-            AutoClass = AutoModelForTokenClassification
            ModelClass = AutoModelForTokenClassification
        else:
            from transformers import AutoModel
-            AutoClass = AutoModel
            ModelClass = AutoModel
-        self.unwrapped_model = AutoClass.from_pretrained(
-            self.args.model,
-            config=self.model_config,
-            dtype="auto",
-            device_map="cpu",
-            trust_remote_code=self.args.trust_remote_code,
-        )
        default_device_type = torch.get_default_device().type
        if default_device_type == DeviceType.META:
            self.model = ModelClass.from_config(self.model_config)
        else:
            self.model = ModelClass.from_pretrained(
                self.args.model,
                config=self.model_config,
                dtype="auto",
                device_map=default_device_type,
                trust_remote_code=self.args.trust_remote_code,
            )
    def init_adapter(self) -> None:
-        pass
        if self.is_adapter:
            return
        if self.args.peft_config is not None:
            from ..plugins.model_plugins.peft import PeftPlugin
            self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
            self.is_adapter = True
    def get_processor(self) -> Processor:
        return self.processor
@@ -95,4 +116,4 @@ class ModelWorker:
        return self.model_config
    def get_model(self) -> HFModel:
-        return self.unwrapped_model
        return self.model
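The rewritten init_model() branches on the default device, so an empty (meta) model can be built without downloading weights. A minimal sketch of that path; the model name is illustrative and the field names follow the ModelArguments diff above:

import torch
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_worker import ModelWorker

worker = ModelWorker(ModelArguments(model="Qwen/Qwen2.5-0.5B-Instruct"))
torch.set_default_device("meta")  # init_model() now takes the from_config() branch
worker.init_model()
torch.set_default_device("cpu")   # weights would be materialized later, e.g. by a dist plugin
print(any(p.is_meta for p in worker.get_model().parameters()))  # True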

View File

@@ -15,7 +15,7 @@
from typing import Any
-from ...extras.types import Processor, Tensor, TorchDataset
from ...utils.types import Processor, Tensor, TorchDataset
class DataCollator:

View File

@@ -44,7 +44,7 @@ WELCOME = (
def launch():
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if command == "sft":
    if command == "sft":  # train command will fallback to sft command
        from .trainers.sft_trainer import run_sft
        run_sft()

View File

@@ -13,12 +13,13 @@
# limitations under the License.
-from typing import Callable, TypedDict
from typing import Any, Literal, TypedDict
-from typing_extensions import NotRequired, Required
from typing_extensions import NotRequired
-from ....extras import logging
from ...utils import logging
-from ...extras.types import DPOSample, Sample, SFTSample
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
logger = logging.get_logger(__name__)
@@ -26,35 +27,48 @@ logger = logging.get_logger(__name__)
class AlpacaSample(TypedDict, total=False):
    system: NotRequired[str]
-    instruction: NotRequired[str]
    instruction: str
    input: NotRequired[str]
-    output: NotRequired[str]
    output: str
-ShareGPTMessage = TypedDict(
-    "ShareGPTMessage",
-    {
-        "from": Required[str],  # Role of the message sender (e.g., "human", "gpt", "system")
-        "value": Required[str],  # Content of the message
-    },
SharegptMessage = TypedDict(
    "SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
)
-class ShareGPTSample(TypedDict, total=False):
-    """Type definition for raw ShareGPT sample."""
-    conversations: Required[list[ShareGPTMessage]]
class SharegptSample(TypedDict, total=False):
    conversations: list[SharegptMessage]
    tools: NotRequired[str]
class OpenaiMessage(TypedDict, total=False):
    role: Literal["user", "assistant", "tool"]
    content: str
class OpenaiSample(TypedDict, total=False):
    messages: list[OpenaiMessage]
class PairSample(TypedDict, total=False):
-    prompt: NotRequired[str]
-    chosen: NotRequired[list[dict]]
-    rejected: NotRequired[list[dict]]
    chosen: list[OpenaiMessage]
    rejected: list[OpenaiMessage]
class DataConverterPlugin(BasePlugin):
    """Plugin for data converters."""
    def __call__(self, raw_sample: dict[str, Any]) -> Sample:
        return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
    """Convert Alpaca sample to SFT sample.
    See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en
    Args:
        raw_sample (AlpacaSample): Alpaca sample.
@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
{"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0} {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
) )
if "history" in raw_sample:
for idx, item in enumerate(raw_sample["history"]):
if len(item) != 2:
logger.warning_rank0(
f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping."
)
continue
old_prompt, old_response = item
messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0})
messages.append(
{"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0}
)
if "instruction" in raw_sample or "input" in raw_sample: if "instruction" in raw_sample or "input" in raw_sample:
messages.append( messages.append(
{ {
@@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages} return {"messages": messages}
def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample: @DataConverterPlugin("sharegpt").register
"""Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample. def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
Retains only SFT-relevant scenarios and removes parity checks. See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en
Args: Args:
raw_sample (ShareGPTSample): A raw sample in ShareGPT format. raw_sample (SharegptSample): ShareGPT sample.
Returns: Returns:
dict: A dictionary containing the formatted 'messages' list for SFT training. SFTSample: SFT sample.
Returns an empty list if the input data is invalid.
""" """
tag_mapping = { tag_mapping = {
"system": "system",
"human": "user", "human": "user",
"gpt": "assistant", "gpt": "assistant",
"observation": "observation", "observation": "tool",
"function_call": "function", "function_call": "assistant",
} }
messages = raw_sample.get("conversations", []) messages = []
aligned_messages = [] tools = raw_sample.get("tools", "")
system_content = ""
# Extract system message if present (typically the first message) for message in raw_sample.get("conversations", []):
if messages and messages[0]["from"] == "system": tag = message["from"]
system_content = messages[0]["value"] if tag not in tag_mapping:
messages = messages[1:] logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"loss_weight": 1.0,
}
)
else:
messages.append(
{
"role": tag_mapping[tag],
"content": [{"type": "text", "value": message["value"]}],
"loss_weight": 1.0 if tag == "gpt" else 0.0,
}
)
if system_content: if tools:
aligned_messages.append( if messages and messages[0]["role"] == "system":
{"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0} messages[0]["content"].append({"type": "tools", "value": tools})
) else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
has_invalid_role = False return {"messages": messages}
for message in messages:
sender = message["from"]
# validate sender is in supported tags
if sender not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}")
has_invalid_role = True
break
aligned_messages.append(
{
"role": tag_mapping[sender],
"content": [{"type": "text", "value": message["value"]}],
"loss_weight": 0.0 if sender in ("human", "observation") else 1.0,
}
)
if has_invalid_role:
logger.warning_rank0("Skipping invalid example due to unsupported role tags.")
return {"messages": []}
return {"messages": aligned_messages}
@DataConverterPlugin("pair").register
def pair_converter(raw_sample: PairSample) -> DPOSample: def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to standard DPO sample. """Convert Pair sample to DPO sample.
See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
Args: Args:
raw_sample (PairSample): pair sample with prompt, chosen, rejected fields. raw_sample (PairSample): pair sample with chosen, rejected fields.
see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
Returns: Returns:
DPOSample: DPO sample with chosen_messages and rejected_messages. DPOSample: DPO sample with chosen_messages and rejected_messages.
see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
""" """
chosen_messages = []
assert "chosen" in raw_sample, "chosen field is required in pair sample."
assert "rejected" in raw_sample, "rejected field is required in pair sample."
assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
"chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
)
if "chosen" in raw_sample: def process_message(raw_messages: list[OpenaiMessage]):
value = raw_sample.get("chosen", "") messages = []
for item in value: for message in raw_messages:
if item.get("role", "") == "system": messages.append(
chosen_messages.append( {
{ "role": message["role"],
"role": "system", "content": [{"type": "text", "value": message["content"]}],
"content": [{"type": "text", "value": item.get("content", "")}], "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
"loss_weight": 0.0, }
} )
)
if item.get("role", "") == "user":
chosen_messages.append(
{
"role": "user",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "assistant":
chosen_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 1.0,
}
)
rejected_messages = [] return messages
if "rejected" in raw_sample:
value = raw_sample.get("rejected", "") chosen_messages = process_message(raw_sample.get("chosen", []))
for item in value: rejected_messages = process_message(raw_sample.get("rejected", []))
if item.get("role", "") == "system":
rejected_messages.append(
{
"role": "system",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "user":
rejected_messages.append(
{
"role": "user",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 0.0,
}
)
if item.get("role", "") == "assistant":
rejected_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "value": item.get("content", "")}],
"loss_weight": 1.0,
}
)
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages} return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
CONVERTERS = {
"alpaca": alpaca_converter,
"pair": pair_converter,
"sharegpt": sharegpt_converter,
}
def get_converter(converter_name: str) -> Callable[[dict], Sample]:
if converter_name not in CONVERTERS:
raise ValueError(f"Converter {converter_name} not found.")
return CONVERTERS[converter_name]
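With the registry in place, converters are resolved by name at call time. A hypothetical round trip through the ShareGPT converter above:

from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin

raw = {"conversations": [{"from": "human", "value": "What is 2 + 2?"}, {"from": "gpt", "value": "4"}]}
sample = DataConverterPlugin("sharegpt")(raw)
# {"messages": [{"role": "user", ..., "loss_weight": 0.0}, {"role": "assistant", ..., "loss_weight": 1.0}]}
assert sample["messages"][1]["content"][0]["value"] == "4"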

View File

@@ -14,57 +14,59 @@
import os
-from dataclasses import dataclass
import random
from typing import Any, Literal, Optional, Union
from datasets import load_dataset
-from ...extras.types import DatasetInfo, HFDataset
from ...utils.plugin import BasePlugin
from ...utils.types import DatasetInfo, HFDataset
-@dataclass
-class DataLoaderPlugin:
class DataLoaderPlugin(BasePlugin):
    """Plugin for loading dataset."""
-    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
-        """Get dataset builder name.
-        Args:
-            path (str): Dataset path.
-        Returns:
-            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
-        """
-        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
-    def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
-        dataset_dir = dataset_info.get("dataset_dir", ".")
-        split = dataset_info.get("split", "train")
-        streaming = dataset_info.get("streaming", False)
-        if "file_name" in dataset_info:
-            filepath = os.path.join(dataset_dir, dataset_info["file_name"])
-            return self.load_data_from_file(filepath, split, streaming)
-        else:
-            raise NotImplementedError()
-    def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
-        if os.path.isdir(filepath):
-            filetype = self._get_builder_name(os.listdir(filepath)[0])
-            dataset = load_dataset(filetype, data_dir=filepath, split=split)
-        elif os.path.isfile(filepath):
-            filetype = self._get_builder_name(filepath)
-            dataset = load_dataset(filetype, data_files=filepath, split=split)
-        else:
-            raise ValueError(f"Can not load dataset from {filepath}.")
-        if streaming:
-            dataset = dataset.to_iterable_dataset()
-        return dataset
    def load(self, dataset_info: DatasetInfo) -> HFDataset:
        path = dataset_info["path"]
        split = dataset_info.get("split", "train")
        streaming = dataset_info.get("streaming", False)
        return super().__call__(path, split, streaming)
def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
    """Get dataset builder name.
    Args:
        path (str): Dataset path.
    Returns:
        Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
    """
    filetype = os.path.splitext(path)[-1][1:]
    if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
        return filetype.replace("jsonl", "json").replace("txt", "text")
    else:
        raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
    if os.path.isdir(filepath):
        filetype = _get_builder_name(os.listdir(filepath)[0])
        dataset = load_dataset(filetype, data_dir=filepath, split=split)
    elif os.path.isfile(filepath):
        filetype = _get_builder_name(filepath)
        dataset = load_dataset(filetype, data_files=filepath, split=split)
    else:
        raise ValueError(f"Can not load dataset from {filepath}.")
    if streaming:  # faster when data is streamed from local files
        dataset = dataset.to_iterable_dataset()
    return dataset
class DataIndexPlugin(BasePlugin):
    """Plugin for adjusting dataset index."""
    def adjust_data_index(
@@ -81,39 +83,32 @@ class DataIndexPlugin:
            list[tuple[str, int]]: Adjusted dataset index.
        """
        if size is not None:
-            data_index = self.adjust_by_size(data_index, size)
            data_index = random.choices(data_index, k=size)
        if weight is not None:
-            data_index = self.adjust_by_weight(data_index, weight)
            data_index = random.choices(data_index, k=int(len(data_index) * weight))
        return data_index
-    def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
-        raise NotImplementedError()
-    def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
-        raise NotImplementedError()
-@dataclass
-class DataSelectorPlugin:
class DataSelectorPlugin(BasePlugin):
    """Plugin for selecting dataset samples."""
-    data_index: list[tuple[str, int]]
-    """List of (dataset_name, sample_index)"""
-    def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
    def select(
        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
        """Select dataset samples.
        Args:
            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
            index (Union[slice, list[int], Any]): Index of dataset samples.
        Returns:
            Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
        """
        if isinstance(index, slice):
-            return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
            return [data_index[i] for i in range(*index.indices(len(data_index)))]
        elif isinstance(index, list):
-            return [self.data_index[i] for i in index]
            return [data_index[i] for i in index]
        else:
            raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -21,10 +21,3 @@ class KernelType(str, Enum):
    FLASH_ATTENTION = "flash_attention"
    ROPE = "rope"
    MOE = "moe"
-class DeviceType(str, Enum):
-    CPU = "cpu"
-    CUDA = "cuda"
-    NPU = "npu"
-    XPU = "xpu"

View File

@@ -18,10 +18,10 @@ import torch
import torch.nn.functional as F
import torch_npu
-from .....accelerator.helper import is_torch_npu_available
from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....extras.packages import is_transformers_version_greater_than
from .....utils.packages import is_transformers_version_greater_than
-from .....extras.types import HFModel
from .....utils.types import HFModel
-from ..constants import DeviceType, KernelType
from ..constants import KernelType
from ..registry import MetaMoEKernel

View File

@@ -17,9 +17,9 @@ import types
import torch
-from .....accelerator.helper import is_torch_npu_available
from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....extras.types import HFModel
from .....utils.types import HFModel
-from ..constants import DeviceType, KernelType
from ..constants import KernelType
from ..registry import MetaSwiGluKernel

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
-from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
-from ....accelerator.helper import get_current_accelerator
from ....accelerator.helper import DeviceType, get_current_accelerator
-from ....extras.types import HFModel
from ....utils.types import HFModel
-from .constants import DeviceType, KernelType
from .constants import KernelType
class KernelRegistry:
@@ -27,11 +27,13 @@ class KernelRegistry:
    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    def __init__(self) -> None:
        if self._initialized:
            return
        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True
@@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
        return discovered_kernels
    # Iterate through registry and collect all kernels for current device
-    for kernel_type, devices in KERNEL_REGISTRY._registry.items():
    for devices in KERNEL_REGISTRY._registry.values():
        kernel_cls = devices.get(device_type)
        if kernel_cls is not None:
            discovered_kernels.append(kernel_cls)
@@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
    return discovered_kernels
-def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.
    Corresponding replacement logic is maintained inside each kernel; the only
@@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo
        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
-    if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type:
-        return kernel.apply(model, **kwargs)
-    raise ValueError(
-        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead."
-    )
    if not issubclass(kernel, MetaKernel):
        raise ValueError(f"{kernel} must be a MetaKernel instance.")
    if kernel.device != get_current_accelerator().type:
        raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.")
    return kernel.apply(model, **kwargs)
def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel":
    """Apply all available kernels to the model."""
    for kernel in discover_kernels(model):
        model = apply_kernel(model, kernel, **kwargs)
    return model

View File

@@ -14,9 +14,9 @@
import re
import types
-from .....accelerator.helper import is_torch_npu_available
from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....extras.types import HFModel
from .....utils.types import HFModel
-from ..constants import DeviceType, KernelType
from ..constants import KernelType
from ..registry import MetaRMSNormKernel

View File

@@ -16,9 +16,9 @@ import sys
import torch
-from .....accelerator.helper import is_torch_npu_available
from .....accelerator.helper import DeviceType, is_torch_npu_available
-from .....extras.types import HFModel
from .....utils.types import HFModel
-from ..constants import DeviceType, KernelType
from ..constants import KernelType
from ..registry import MetaRoPEKernel

View File

@@ -0,0 +1,42 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
class LoraConfigDict(TypedDict, total=False):
name: Literal["lora"]
"""Plugin name."""
r: int
"""Lora rank."""
lora_alpha: int
"""Lora alpha."""
target_modules: list[str]
"""Target modules."""
class PeftPlugin(BasePlugin):
pass
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
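A hypothetical direct use of the LoRA plugin above (outside of ModelWorker.init_adapter()); the model id and target modules are illustrative, and the dict is expanded straight into LoraConfig, so only LoraConfig fields are passed:

from transformers import AutoModelForCausalLM
from llamafactory.v1.plugins.model_plugins.peft import PeftPlugin

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = PeftPlugin("lora")(model, {"r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"]})
model.print_trainable_parameters()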

View File

@@ -1,13 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -13,22 +13,21 @@
# limitations under the License.
-from ..config.parser import get_args
from ..accelerator.interface import DistributedInterface, DistributedStrategy
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
-from ..core.model_engine import ModelEngine
from ..core.model_worker import ModelWorker
class SFTTrainer(BaseTrainer):
    pass
-def run_sft():
-    model_args, data_args, training_args, _ = get_args()
-    model_engine = ModelEngine(model_args)
-    data_engine = DataEngine(data_args)
-    model = model_engine.get_model()
-    processor = model_engine.get_processor()
-    data_loader = data_engine.get_data_loader(processor)
-    trainer = SFTTrainer(training_args, model, processor, data_loader)
def run_sft(user_args):
    model_args, data_args, training_args, _ = get_args(user_args)
    DistributedInterface(DistributedStrategy())
    data_engine = DataEngine(data_args)
    model_worker = ModelWorker(model_args)
    trainer = SFTTrainer(training_args, model_worker, data_engine)
    trainer.fit()
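A sketch of driving the new entry point programmatically; the argument keys mirror the dataclass fields touched in this commit, though the exact CLI/YAML surface may differ:

from llamafactory.v1.trainers.sft_trainer import run_sft

run_sft(
    {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "dataset": "llamafactory/v1-sft-demo",
        "peft_config": '{"name": "lora", "r": "16"}',
    }
)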

View File

View File

@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,123 @@
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import threading
from functools import lru_cache
from typing import Optional
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class _Logger(logging.Logger):
r"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level":
r"""Return the default logging level."""
env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
r"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
if _default_handler: # already configured
return
formatter = logging.Formatter(
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
r"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_rank0_once = warning_rank0_once
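A short usage sketch of the logging utilities above (module path assumed from the `from ...utils import logging` imports in this commit):

from llamafactory.v1.utils import logging

logger = logging.get_logger(__name__)
logger.info_rank0("visible only when LOCAL_RANK is 0")
logger.warning_rank0_once("emitted at most once per process")
# verbosity can be controlled with the LLAMAFACTORY_VERBOSITY environment variable, e.g. "WARNING"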

View File

@@ -0,0 +1,86 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
from . import logging
logger = logging.get_logger(__name__)
class BasePlugin:
"""Base class for plugins.
A plugin is a callable object that can be registered and called by name.
"""
_registry: dict[str, Callable] = {}
def __init__(self, name: Optional[str] = None):
"""Initialize the plugin with a name.
Args:
name (str): The name of the plugin.
"""
self.name = name
@property
def register(self) -> Callable:
"""Decorator to register a function as a plugin.
Example usage:
```python
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
return func
return decorator
def __call__(self, *args, **kwargs) -> Callable:
"""Call the registered function with the given arguments.
Example usage:
```python
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs)
if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
@PrintPlugin("hello").register
def print_hello():
print("Hello world!")
PrintPlugin("hello")()

View File

@@ -19,23 +19,27 @@ from typing_extensions import NotRequired
if TYPE_CHECKING:
    import datasets
    import numpy as np
    import torch
    import torch.utils.data
    import transformers
    from torch.distributed.fsdp import FullyShardedDataParallel
    Tensor = torch.Tensor
    TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor]
    TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
    DataCollator = transformers.DataCollator
    DataLoader = torch.utils.data.DataLoader
    HFConfig = transformers.PretrainedConfig
    HFModel = transformers.PreTrainedModel
-    DistModel = torch.nn.parallel.DistributedDataParallel
    DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel]
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
    Optimizer = torch.optim.Optimizer
    Scheduler = torch.optim.lr_scheduler.LRScheduler
else:
    Tensor = None
    TensorLike = None
    TorchDataset = None
    HFDataset = None
    DataCollator = None
@@ -49,12 +53,10 @@ else:
class DatasetInfo(TypedDict, total=False):
-    hf_hub_url: NotRequired[str]
-    """HF hub dataset uri."""
-    file_name: NotRequired[str]
    path: str
    """Local file path."""
-    dataset_dir: NotRequired[str]
-    """Dataset directory, default to args.dataset_dir."""
    source: NotRequired[Literal["hf_hub", "ms_hub", "local"]]
    """Dataset source, default to "hf_hub"."""
    split: NotRequired[str]
    """Dataset split, default to "train"."""
    converter: NotRequired[str]
@@ -68,12 +70,12 @@ class DatasetInfo(TypedDict, total=False):
class Content(TypedDict): class Content(TypedDict):
type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"] type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
value: str value: str
class Message(TypedDict): class Message(TypedDict):
role: Literal["system", "user", "assistant"] role: Literal["system", "user", "assistant", "tool"]
content: list[Content] content: list[Content]
loss_weight: float loss_weight: float
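With this schema, a dataset entry is just a mapping whose keys follow `DatasetInfo`: `path` is required, while `source` falls back to `"hf_hub"` when omitted. A minimal sketch of two such entries; the hub entry reuses the dataset already referenced by the tests in this commit, and the local file name is hypothetical:

```python
# Plain dicts shaped like the DatasetInfo TypedDict above (illustrative only).
hub_dataset = {
    "path": "HuggingFaceH4/orca_dpo_pairs",  # "source" omitted, defaults to "hf_hub"
    "split": "train_prefs",
    "converter": "pair",
}

local_dataset = {
    "path": "data/my_dataset.json",  # hypothetical local file
    "source": "local",
    "converter": "sharegpt",
}
```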
View File
@@ -12,25 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
-
-from torch.distributed.device_mesh import DeviceMesh
-
-
-class DeviceMeshManager:
-    """Device mesh manager."""
-
-    _instance: Optional["DeviceMeshManager"] = None
-    _initialized: bool = False
-
-    def __new__(cls) -> "DeviceMeshManager":
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-
-        return cls._instance
-
-    def __init__(self) -> None:
-        if self._initialized:
-            return
-
-        self.device_mesh: Optional[DeviceMesh] = None
-        self._initialized = True
+import os
+
+from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy
+
+
+def test_distributed_interface():
+    DistributedInterface(DistributedStrategy())
+    assert DistributedInterface.rank == int(os.getenv("RANK", "0"))
+    assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1"))
View File
@@ -19,7 +19,7 @@ from datasets import load_dataset
 from llamafactory.v1.config.data_args import DataArguments
 from llamafactory.v1.core.data_engine import DataEngine
-from llamafactory.v1.plugins.data_plugins.converter import get_converter
+from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin


 @pytest.mark.parametrize("num_samples", [16])
@@ -49,99 +49,27 @@ def test_alpaca_converter(num_samples: int):
         assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}


-def test_sharegpt_converter_invalid():
+def test_sharegpt_converter():
     example = {
         "conversations": [
-            {
-                "from": "system",
-                "value": "Processes historical market data to generate trading signals "
-                "based on specified technical indicators.",
-            },
-            {
-                "from": "human",
-                "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
-                "Could you proceed with these function calls to assist me with the task?",
-            },
-            {
-                "from": "gpt",
-                "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
-                "'name': 'backtest_trading_signals'}```\n",
-            },
-            {
-                "from": "tool",
-                "value": '<tool id="D2">\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", '
-                '"symbol": "AAPL", "signal": "Buy"}]}}}\n</tool>\n',
-            },
+            {"from": "system", "value": "System"},
+            {"from": "human", "value": "User"},
+            {"from": "gpt", "value": "Assistant"},
         ]
     }
-    dataset_converter = get_converter("sharegpt")
-    assert dataset_converter(example) == {"messages": []}
-
-
-def test_sharegpt_converter_valid():
-    example = {
-        "conversations": [
-            {
-                "from": "system",
-                "value": "Processes historical market data to generate trading signals based on "
-                "specified technical indicators.",
-            },
-            {
-                "from": "human",
-                "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
-                "Could you proceed with these function calls to assist me with the task?",
-            },
-            {
-                "from": "gpt",
-                "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
-                "'name': 'backtest_trading_signals'}```\n",
-            },
-        ]
-    }
-    dataset_converter = get_converter("sharegpt")
     expected_data = {
         "messages": [
-            {
-                "content": [
-                    {
-                        "type": "text",
-                        "value": "Processes historical market data to generate trading signals based on "
-                        "specified technical indicators.",
-                    }
-                ],
-                "loss_weight": 0.0,
-                "role": "system",
-            },
-            {
-                "content": [
-                    {
-                        "type": "text",
-                        "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. "
-                        "Could you proceed with these function calls to assist me with the task?",
-                    }
-                ],
-                "loss_weight": 0.0,
-                "role": "user",
-            },
-            {
-                "content": [
-                    {
-                        "type": "text",
-                        "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', "
-                        "'name': 'backtest_trading_signals'}```\n",
-                    }
-                ],
-                "loss_weight": 1.0,
-                "role": "assistant",
-            },
+            {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
+            {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
+            {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
        ]
    }
-    assert dataset_converter(example) == expected_data
+    assert DataConverterPlugin("sharegpt")(example) == expected_data


 @pytest.mark.parametrize("num_samples", [16])
 def test_pair_converter(num_samples: int):
-    data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
+    data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml")
     data_engine = DataEngine(data_args)
     original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
     indexes = random.choices(range(len(data_engine)), k=num_samples)
@@ -189,6 +117,5 @@ def test_pair_converter(num_samples: int):
 if __name__ == "__main__":
     test_alpaca_converter(1)
-    test_sharegpt_converter_invalid()
-    test_sharegpt_converter_valid()
+    test_sharegpt_converter()
     test_pair_converter(1)
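The test now dispatches converters by name through `DataConverterPlugin("sharegpt")(example)`. Assuming `DataConverterPlugin` follows the `BasePlugin` register/`__call__` pattern shown earlier in this commit, a custom converter could plausibly be wired in the same way; the `my_format` name and the converter body below are hypothetical:

```python
# Hypothetical sketch: registering and invoking a custom converter by name,
# assuming DataConverterPlugin exposes the BasePlugin-style register decorator.
from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin


@DataConverterPlugin("my_format").register
def my_format_converter(example: dict) -> dict:
    # Map a raw {"prompt": ..., "answer": ...} record into the message schema
    # asserted by the tests above (role / content / loss_weight).
    return {
        "messages": [
            {"content": [{"type": "text", "value": example["prompt"]}], "loss_weight": 0.0, "role": "user"},
            {"content": [{"type": "text", "value": example["answer"]}], "loss_weight": 1.0, "role": "assistant"},
        ]
    }


converted = DataConverterPlugin("my_format")({"prompt": "Hi", "answer": "Hello!"})
```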
View File
@@ -17,10 +17,13 @@ from unittest.mock import MagicMock, patch
 from transformers import AutoModelForCausalLM

+from llamafactory.v1.accelerator.helper import get_current_accelerator
+

 class TestKernelPlugin(unittest.TestCase):
     @patch("torch.accelerator.current_accelerator")
     def test_apply_kernel(self, mock_get_accelerator):
+        get_current_accelerator.cache_clear()
         mock_device = MagicMock()
         mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device
@@ -47,6 +50,7 @@ class TestKernelPlugin(unittest.TestCase):
 class Test_Use_V1_Kernels(unittest.TestCase):
     @patch("torch.accelerator.current_accelerator")
     def test_use_v1_kernels(self, mock_get_accelerator):
+        get_current_accelerator.cache_clear()
         mock_device = MagicMock()
         mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device