From 203069e11c66f38767321a5d00d7e16e8e636609 Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Fri, 12 Dec 2025 19:22:06 +0800 Subject: [PATCH] [v1] add accelerator (#9607) --- data/v1_dpo_demo.yaml | 2 +- data/v1_sft_demo.yaml | 9 +- src/llamafactory/v1/accelerator/helper.py | 145 +++++++++-- src/llamafactory/v1/accelerator/interface.py | 123 +++++++++ .../v1/config/{parser.py => arg_parser.py} | 17 ++ src/llamafactory/v1/config/arg_utils.py | 105 ++++++++ src/llamafactory/v1/config/model_args.py | 23 +- src/llamafactory/v1/config/sample_args.py | 6 +- src/llamafactory/v1/config/training_args.py | 9 + src/llamafactory/v1/core/base_trainer.py | 9 +- src/llamafactory/v1/core/data_engine.py | 20 +- src/llamafactory/v1/core/model_worker.py | 71 ++++-- .../v1/core/trainer_utils/data_collator.py | 2 +- src/llamafactory/v1/launcher.py | 2 +- .../v1/plugins/data_plugins/converter.py | 240 +++++++----------- .../v1/plugins/data_plugins/loader.py | 105 ++++---- .../model_plugins/kernels/attn}/__init__.py | 0 .../model_plugins/kernels/constants.py | 7 - .../kernels/mlp/npu_fused_moe.py | 8 +- .../model_plugins/kernels/mlp/npu_swiglu.py | 6 +- .../plugins/model_plugins/kernels/registry.py | 26 +- .../kernels/rms_norm/npu_rms_norm.py | 6 +- .../model_plugins/kernels/rope/npu_rope.py | 6 +- .../v1/plugins/model_plugins/peft.py | 42 +++ .../trainer_plugins/distributed/accelerate.py | 13 - .../distributed/deepspeed.py} | 0 src/llamafactory/v1/trainers/sft_trainer.py | 17 +- src/llamafactory/v1/utils/__init__.py | 0 src/llamafactory/v1/utils/constants.py | 13 + src/llamafactory/v1/utils/logging.py | 123 +++++++++ .../v1/{extras => utils}/packages.py | 0 src/llamafactory/v1/utils/plugin.py | 86 +++++++ .../v1/{extras => utils}/types.py | 18 +- .../accelerator/test_interface.py | 26 +- .../plugins/data_plugins/test_converter.py | 95 +------ .../model_plugins/test_kernel_plugin.py | 4 + 36 files changed, 941 insertions(+), 443 deletions(-) create mode 100644 src/llamafactory/v1/accelerator/interface.py rename src/llamafactory/v1/config/{parser.py => arg_parser.py} (84%) create mode 100644 src/llamafactory/v1/config/arg_utils.py rename src/llamafactory/v1/{extras => plugins/model_plugins/kernels/attn}/__init__.py (100%) rename src/llamafactory/v1/plugins/{model_plugins/kernels/fa/__init__.py => trainer_plugins/distributed/deepspeed.py} (100%) create mode 100644 src/llamafactory/v1/utils/__init__.py create mode 100644 src/llamafactory/v1/utils/constants.py create mode 100644 src/llamafactory/v1/utils/logging.py rename src/llamafactory/v1/{extras => utils}/packages.py (100%) create mode 100644 src/llamafactory/v1/utils/plugin.py rename src/llamafactory/v1/{extras => utils}/types.py (82%) rename src/llamafactory/v1/accelerator/distributed.py => tests_v1/accelerator/test_interface.py (51%) diff --git a/data/v1_dpo_demo.yaml b/data/v1_dpo_demo.yaml index 56c6d565..2edca22f 100644 --- a/data/v1_dpo_demo.yaml +++ b/data/v1_dpo_demo.yaml @@ -1,4 +1,4 @@ dpo_zh_demo: - hf_hub_url: HuggingFaceH4/orca_dpo_pairs + path: HuggingFaceH4/orca_dpo_pairs split: train_prefs converter: pair diff --git a/data/v1_sft_demo.yaml b/data/v1_sft_demo.yaml index 9687431d..8fed3b1d 100644 --- a/data/v1_sft_demo.yaml +++ b/data/v1_sft_demo.yaml @@ -1,8 +1,9 @@ identity: - file_name: data/identity.json + path: data/identity.json + source: local converter: alpaca alpaca_en_demo: - file_name: alpaca_en_demo.json - dataset_dir: data + path: data/alpaca_en_demo.json + source: local converter: alpaca - num_samples: 500 + size: 500 diff --git 
a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py index a04202a2..a3108954 100644 --- a/src/llamafactory/v1/accelerator/helper.py +++ b/src/llamafactory/v1/accelerator/helper.py @@ -1,4 +1,7 @@ -# Copyright 2025 the LlamaFactory team. +# Copyright 2025 Bytedance Ltd. and the LlamaFactory team. +# +# This code is inspired by the Bytedance's VeOmni library. +# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +15,68 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from contextlib import contextmanager +from enum import Enum, unique from functools import lru_cache +from typing import TYPE_CHECKING, Optional +import numpy as np import torch +import torch.distributed as dist + +from ..utils.types import Tensor, TensorLike -def get_current_accelerator(check_available: bool = True): +if TYPE_CHECKING: + from torch.distributed import ProcessGroup + + +@unique +class DeviceType(str, Enum): + CPU = "cpu" + CUDA = "cuda" + META = "meta" + MPS = "mps" + NPU = "npu" + XPU = "xpu" + + +@unique +class ReduceOp(str, Enum): + SUM = "sum" + MEAN = "mean" + MAX = "max" + MIN = "min" + + +def is_distributed() -> bool: + """Check if distributed environment is available.""" + return os.getenv("RANK") is not None + + +def get_rank() -> int: + """Get rank.""" + return int(os.getenv("RANK", "0")) + + +def get_local_rank() -> int: + """Get local rank.""" + return int(os.getenv("LOCAL_RANK", "0")) + + +def get_world_size() -> int: + """Get world size.""" + return int(os.getenv("WORLD_SIZE", "1")) + + +def get_local_world_size() -> int: + """Get local world size.""" + return int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + +@lru_cache +def get_current_accelerator(check_available: bool = True) -> torch.device: """Get current accelerator. 
Note: this API requires torch>=2.7.0; torch 2.6 or lower will raise an AttributeError or RuntimeError @@ -27,26 +86,78 @@ def get_current_accelerator(check_available: bool = True): accelerator = torch.accelerator.current_accelerator(check_available=check_available) if accelerator is None: - return torch.device("cpu") + return torch.device(DeviceType.CPU.value) return accelerator -@lru_cache -def is_torch_npu_available(): - return get_current_accelerator().type == "npu" - - -@lru_cache def is_torch_cuda_available(): - return get_current_accelerator().type == "cuda" + return get_current_accelerator().type == DeviceType.CUDA -@lru_cache -def is_torch_xpu_available(): - return get_current_accelerator().type == "xpu" - - -@lru_cache def is_torch_mps_available(): - return get_current_accelerator().type == "mps" + return get_current_accelerator().type == DeviceType.MPS + + +def is_torch_npu_available(): + return get_current_accelerator().type == DeviceType.NPU + + +def is_torch_xpu_available(): + return get_current_accelerator().type == DeviceType.XPU + + +def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor: + """Gathers the tensor from all ranks and concatenates them along the first dim.""" + world_size = get_world_size() + device = get_current_accelerator() + output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device) + dist.all_gather_into_tensor(output_tensor, tensor, group=group) + return output_tensor.view(-1, *tensor.size()[1:]) + + +def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike: + """Performs all-reduce in the given process group.""" + device = get_current_accelerator() + is_ndarray = isinstance(data, np.ndarray) + is_tensor = isinstance(data, torch.Tensor) + + if is_ndarray: + data = torch.from_numpy(data) + elif not is_tensor: + data = torch.tensor(data, dtype=torch.float, device=device) + + reduce_ops = { + ReduceOp.MEAN: dist.ReduceOp.SUM, + ReduceOp.SUM: dist.ReduceOp.SUM, + ReduceOp.MAX: dist.ReduceOp.MAX, + ReduceOp.MIN: dist.ReduceOp.MIN, + } + dist.all_reduce(data, op=reduce_ops[op], group=group) + if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend + data /= dist.get_world_size(group=group) + + if is_tensor: + return data + elif is_ndarray: + return data.numpy() + elif data.numel() == 1: + return data.item() + else: + return data.tolist() + + +@contextmanager +def main_process_first(local_only: bool = True) -> None: + """A context manager for the torch distributed environment that runs the wrapped block on the main process first.""" + if get_world_size() > 1: + is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0 + try: + if not is_main_process: + dist.barrier() + yield + finally: + if is_main_process: + dist.barrier() + else: + yield diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py new file mode 100644 index 00000000..de4306f0 --- /dev/null +++ b/src/llamafactory/v1/accelerator/interface.py @@ -0,0 +1,123 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Optional + +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh + +from ..utils.types import TensorLike +from .helper import ReduceOp, all_reduce, get_current_accelerator, get_rank, get_world_size, is_distributed + + +@dataclass +class DistributedStrategy: + """Distributed strategy.""" + + dp_size: Optional[int] = None + tp_size: int = 1 + + def __post_init__(self) -> None: + if not is_distributed(): + self.dp_size = 1 + elif self.dp_size is None: + self.dp_size = get_world_size() // self.tp_size + elif self.dp_size * self.tp_size != get_world_size(): + raise ValueError( + f"dp_size * tp_size must equal to world_size, " + f"got {self.dp_size} * {self.tp_size} != {get_world_size()}." + ) + + @property + def mesh_shape(self) -> tuple[int, int]: + """Mesh shape.""" + return (self.dp_size, self.tp_size) + + @property + def mesh_dim_names(self) -> tuple[str, str]: + """Mesh dimension names.""" + return ("dp", "tp") + + +class DistributedInterface: + """Distributed interface.""" + + _instance: Optional["DistributedInterface"] = None + _initialized: bool = False + + is_distributed = is_distributed() + """Check if distributed environment is available.""" + rank = get_rank() + """Global rank.""" + world_size = get_world_size() + """Global world size.""" + device_mesh: Optional[DeviceMesh] = None + """Device mesh.""" + current_accelerator = get_current_accelerator() + """Current accelerator.""" + + def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface": + """Singleton pattern.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + + return cls._instance + + def __init__(self, strategy: DistributedStrategy) -> None: + if self._initialized: + return + + self.strategy = strategy + if self.is_distributed: + self.device_mesh = init_device_mesh( + device_type=self.current_accelerator.type, + mesh_shape=strategy.mesh_shape, + mesh_dim_names=strategy.mesh_dim_names, + ) + else: + self.device_mesh = None + + self._initialized = True + + def __str__(self) -> str: + return ( + f"DistributedInterface(strategy={self.strategy}), is_distributed={self.is_distributed}, " + f"rank={self.rank}, world_size={self.world_size}, " + f"device_mesh={self.device_mesh}, current_accelerator={self.current_accelerator}" + ) + + def dp_rank(self) -> int: + """Data parallel rank.""" + if self.device_mesh is None: + return 0 + + return self.device_mesh["dp"].get_rank() + + def dp_size(self) -> int: + """Data parallel size.""" + if self.device_mesh is None: + return 1 + + return self.device_mesh["dp"].size() + + def all_reduce_over_dp(self, data: TensorLike, op: ReduceOp = ReduceOp.MEAN) -> TensorLike: + """All reduce tensor.""" + if self.device_mesh is None: + return data + + return all_reduce(data, op, self.device_mesh["dp"].get_group()) + + +if __name__ == "__main__": + print(DistributedInterface(DistributedStrategy())) diff --git a/src/llamafactory/v1/config/parser.py b/src/llamafactory/v1/config/arg_parser.py similarity index 84% rename from src/llamafactory/v1/config/parser.py rename to 
src/llamafactory/v1/config/arg_parser.py index eca1749f..ca6103a9 100644 --- a/src/llamafactory/v1/config/parser.py +++ b/src/llamafactory/v1/config/arg_parser.py @@ -28,6 +28,21 @@ from .sample_args import SampleArguments from .training_args import TrainingArguments +def validate_args( + data_args: DataArguments, + model_args: ModelArguments, + training_args: TrainingArguments, + sample_args: SampleArguments, +): + """Validate arguments.""" + if ( + model_args.quant_config is not None + and training_args.dist_config is not None + and training_args.dist_config.name == "deepspeed" + ): + raise ValueError("Quantization is not supported with deepspeed backend.") + + def get_args( args: Optional[Union[dict[str, Any], list[str]]] = None, ) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]: @@ -56,6 +71,8 @@ def get_args( print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + validate_args(*parsed_args) + return tuple(parsed_args) diff --git a/src/llamafactory/v1/config/arg_utils.py b/src/llamafactory/v1/config/arg_utils.py new file mode 100644 index 00000000..f2c0183d --- /dev/null +++ b/src/llamafactory/v1/config/arg_utils.py @@ -0,0 +1,105 @@ +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from enum import Enum, unique +from typing import Any, Optional, Union + + +class PluginConfig(dict): + """Dictionary that allows attribute access.""" + + @property + def name(self) -> str: + """Plugin name.""" + if "name" not in self: + raise ValueError("Plugin configuration must have a 'name' field.") + + return self["name"] + + def __getattr__(self, key: str) -> Any: + try: + return self[key] + except KeyError: + raise AttributeError(f"Attribute {key} not found.") + + def __setattr__(self, key: str, value: Any): + self[key] = value + + +PluginArgument = Optional[Union[PluginConfig, dict, str]] + + +@unique +class AutoClass(str, Enum): + """Auto class for model config.""" + + CAUSALLM = "llm" + CLASSIFICATION = "cls" + OTHER = "other" + + +@unique +class SampleBackend(str, Enum): + HF = "hf" + VLLM = "vllm" + + +def _convert_str_dict(data: dict) -> dict: + """Parse string representation inside the dictionary. + + Args: + data: The string or dictionary to convert. + + Returns: + The converted dictionary. 
+ """ + for key, value in data.items(): + if isinstance(value, dict): + data[key] = _convert_str_dict(value) + elif isinstance(value, str): + if value.lower() in ("true", "false"): + data[key] = value.lower() == "true" + elif value.isdigit(): + data[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + data[key] = float(value) + + return data + + +def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]: + """Get the plugin configuration from the argument value. + + Args: + config: The argument value to get the plugin configuration from. + + Returns: + The plugin configuration. + """ + if config is None: + return None + + if isinstance(config, str) and config.startswith("{"): + config = json.loads(config) + + config = _convert_str_dict(config) + if "name" not in config: + raise ValueError("Plugin configuration must have a 'name' field.") + + return PluginConfig(config) diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py index 6bb8dd58..1dfe962b 100644 --- a/src/llamafactory/v1/config/model_args.py +++ b/src/llamafactory/v1/config/model_args.py @@ -15,6 +15,8 @@ from dataclasses import dataclass, field +from .arg_utils import AutoClass, PluginConfig, get_plugin_config + @dataclass class ModelArguments: @@ -29,7 +31,24 @@ class ModelArguments: default=True, metadata={"help": "Use fast processor from Hugging Face."}, ) - auto_model_class: str = field( - default="causallm", + auto_class: AutoClass = field( + default=AutoClass.CAUSALLM, metadata={"help": "Model class from Hugging Face."}, ) + peft_config: PluginConfig = field( + default=None, + metadata={"help": "PEFT configuration for the model."}, + ) + kernel_config: PluginConfig = field( + default=None, + metadata={"help": "Kernel configuration for the model."}, + ) + quant_config: PluginConfig = field( + default=None, + metadata={"help": "Quantization configuration for the model."}, + ) + + def __post_init__(self) -> None: + self.peft_config = get_plugin_config(self.peft_config) + self.kernel_config = get_plugin_config(self.kernel_config) + self.quant_config = get_plugin_config(self.quant_config) diff --git a/src/llamafactory/v1/config/sample_args.py b/src/llamafactory/v1/config/sample_args.py index 52f24ae5..3971ee71 100644 --- a/src/llamafactory/v1/config/sample_args.py +++ b/src/llamafactory/v1/config/sample_args.py @@ -14,12 +14,8 @@ from dataclasses import dataclass, field -from enum import Enum - -class SampleBackend(Enum): - HF = "hf" - VLLM = "vllm" +from .arg_utils import SampleBackend @dataclass diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 38d62ecf..5a14cab6 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -15,6 +15,8 @@ from dataclasses import dataclass, field +from .arg_utils import PluginArgument, get_plugin_config + @dataclass class TrainingArguments: @@ -38,3 +40,10 @@ class TrainingArguments: default=False, metadata={"help": "Use bf16 for training."}, ) + dist_config: PluginArgument = field( + default=None, + metadata={"help": "Distribution configuration for training."}, + ) + + def __post_init__(self) -> None: + self.dist_config = get_plugin_config(self.dist_config) diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index cacca81c..331c1417 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -29,7 +29,7 @@ Train Phase: """ from 
..config.training_args import TrainingArguments -from ..extras.types import TorchDataset +from ..utils.types import TorchDataset from .model_worker import ModelWorker from .trainer_utils.data_collator import DataCollator @@ -49,13 +49,10 @@ class BaseTrainer: self.optimizer = None self.lr_scheduler = None - def init_device_mesh(self) -> None: - pass - def init_model_and_optimizer(self) -> None: - self.model_config = self.model_worker.get_model_config() + self.model_worker.init_model_config() # with self.dist_plugin.get_model_init_context(): - # self.model = self.model_worker.get_model(self.model_config) + # self.model = self.model_worker.init_model(self.model_config) def create_dataloader(self) -> None: pass diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index 523037b1..d7d6cbc3 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -34,7 +34,7 @@ from omegaconf import OmegaConf from torch.utils.data import Dataset from ..config.data_args import DataArguments -from ..extras.types import DatasetInfo, HFDataset, Sample +from ..utils.types import DatasetInfo, HFDataset, Sample class DataEngine(Dataset): @@ -64,9 +64,9 @@ class DataEngine(Dataset): filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset") self.dataset_infos = OmegaConf.load(filepath) elif os.path.exists(self.args.dataset): # local file(s) - self.dataset_infos = {"default": {"file_name": self.args.dataset}} + self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}} else: # hf hub dataset, e.g. llamafactory/v1-sft-demo - self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}} + self.dataset_infos = {"default": {"path": self.args.dataset}} def load_dataset(self) -> None: """Load datasets according to dataset info.""" @@ -74,14 +74,14 @@ class DataEngine(Dataset): split = value.get("split", "train") streaming = value.get("streaming", False) self.streaming |= streaming - if "hf_hub_url" in value: + if value.get("source", "hf_hub") == "hf_hub": from datasets import load_dataset - self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming) + self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming) else: # data loader plugin from ..plugins.data_plugins.loader import DataLoaderPlugin - self.datasets[key] = DataLoaderPlugin().auto_load_data(value) + self.datasets[key] = DataLoaderPlugin(value["source"]).load(value) def build_data_index(self) -> None: """Build dataset index.""" @@ -112,9 +112,9 @@ class DataEngine(Dataset): """ converter = self.dataset_infos[dataset_name].get("converter") if converter is not None: - from ..plugins.data_plugins.converter import get_converter + from ..plugins.data_plugins.converter import DataConverterPlugin - return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)} + return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)} else: return {"_dataset_name": dataset_name, **raw_sample} @@ -147,7 +147,7 @@ class DataEngine(Dataset): else: from ..plugins.data_plugins.loader import DataSelectorPlugin - selected_index = DataSelectorPlugin(data_index=self.data_index).select(index) + selected_index = DataSelectorPlugin().select(self.data_index, index) if isinstance(selected_index, list): return [ self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) @@ -187,7 +187,7 @@ class DataEngine(Dataset): if __name__ == "__main__": - from 
..config.parser import get_args + from ..config.arg_parser import get_args data_args, *_ = get_args() data_engine = DataEngine(data_args=data_args) diff --git a/src/llamafactory/v1/core/model_worker.py b/src/llamafactory/v1/core/model_worker.py index 6d38189c..183e7680 100644 --- a/src/llamafactory/v1/core/model_worker.py +++ b/src/llamafactory/v1/core/model_worker.py @@ -24,10 +24,12 @@ Init Phase: from typing import Optional +import torch from transformers import AutoConfig, AutoProcessor -from ..config.model_args import ModelArguments -from ..extras.types import DistModel, HFConfig, HFModel, Processor +from ..accelerator.helper import DeviceType +from ..config.model_args import AutoClass, ModelArguments +from ..utils.types import HFConfig, HFModel, Processor class ModelWorker: @@ -38,16 +40,15 @@ class ModelWorker: """Tokenizer or multi-modal processor.""" self.model_config: Optional[HFConfig] = None """Model configuration.""" - self.unwrapped_model: Optional[HFModel] = None - """Unwrapped model.""" - self.model: Optional[DistModel] = None - """Distributed model.""" - self.init_processor() - self.init_model_config() - self.init_model() - self.init_adapter() + self.model: Optional[HFModel] = None + """HF model.""" + self.is_adapter = False + """Whether the model has adapter.""" def init_processor(self) -> None: + if self.processor is not None: + return + self.processor = AutoProcessor.from_pretrained( self.args.model, trust_remote_code=self.args.trust_remote_code, @@ -55,38 +56,58 @@ class ModelWorker: ) def init_model_config(self) -> None: + if self.model_config is not None: + return + self.model_config = AutoConfig.from_pretrained( self.args.model, trust_remote_code=self.args.trust_remote_code, ) def init_model(self) -> None: - if self.args.auto_model_class == "causallm": + if self.model is not None: + return + + self.init_model_config() + + if self.args.auto_class == AutoClass.CAUSALLM: from transformers import AutoModelForCausalLM, AutoModelForImageTextToText if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys(): - AutoClass = AutoModelForImageTextToText + ModelClass = AutoModelForImageTextToText else: - AutoClass = AutoModelForCausalLM - elif self.args.auto_model_class == "classification": + ModelClass = AutoModelForCausalLM + elif self.args.auto_class == AutoClass.CLASSIFICATION: from transformers import AutoModelForTokenClassification - AutoClass = AutoModelForTokenClassification + ModelClass = AutoModelForTokenClassification else: from transformers import AutoModel - AutoClass = AutoModel + ModelClass = AutoModel - self.unwrapped_model = AutoClass.from_pretrained( - self.args.model, - config=self.model_config, - dtype="auto", - device_map="cpu", - trust_remote_code=self.args.trust_remote_code, - ) + default_device_type = torch.get_default_device().type + if default_device_type == DeviceType.META: + self.model = ModelClass.from_config(self.model_config) + else: + self.model = ModelClass.from_pretrained( + self.args.model, + config=self.model_config, + dtype="auto", + device_map=default_device_type, + trust_remote_code=self.args.trust_remote_code, + ) def init_adapter(self) -> None: - pass + if self.is_adapter: + return + + if self.args.peft_config is not None: + from ..plugins.model_plugins.peft import PeftPlugin + + self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config) + + self.is_adapter = True def get_processor(self) -> Processor: return self.processor @@ -95,4 +116,4 @@ class ModelWorker: return self.model_config def 
get_model(self) -> HFModel: - return self.unwrapped_model + return self.model diff --git a/src/llamafactory/v1/core/trainer_utils/data_collator.py b/src/llamafactory/v1/core/trainer_utils/data_collator.py index 88f49672..f91a07d7 100644 --- a/src/llamafactory/v1/core/trainer_utils/data_collator.py +++ b/src/llamafactory/v1/core/trainer_utils/data_collator.py @@ -15,7 +15,7 @@ from typing import Any -from ...extras.types import Processor, Tensor, TorchDataset +from ...utils.types import Processor, Tensor, TorchDataset class DataCollator: diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py index 6160835b..3d18286b 100644 --- a/src/llamafactory/v1/launcher.py +++ b/src/llamafactory/v1/launcher.py @@ -44,7 +44,7 @@ WELCOME = ( def launch(): command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" - if command == "sft": + if command == "sft": # train command will fallback to sft command from .trainers.sft_trainer import run_sft run_sft() diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py index 777b710e..b5778970 100644 --- a/src/llamafactory/v1/plugins/data_plugins/converter.py +++ b/src/llamafactory/v1/plugins/data_plugins/converter.py @@ -13,12 +13,13 @@ # limitations under the License. -from typing import Callable, TypedDict +from typing import Any, Literal, TypedDict -from typing_extensions import NotRequired, Required +from typing_extensions import NotRequired -from ....extras import logging -from ...extras.types import DPOSample, Sample, SFTSample +from ...utils import logging +from ...utils.plugin import BasePlugin +from ...utils.types import DPOSample, Sample, SFTSample logger = logging.get_logger(__name__) @@ -26,35 +27,48 @@ logger = logging.get_logger(__name__) class AlpacaSample(TypedDict, total=False): system: NotRequired[str] - instruction: NotRequired[str] + instruction: str input: NotRequired[str] - output: NotRequired[str] + output: str -ShareGPTMessage = TypedDict( - "ShareGPTMessage", - { - "from": Required[str], # Role of the message sender (e.g., "human", "gpt", "system") - "value": Required[str], # Content of the message - }, +SharegptMessage = TypedDict( + "SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str} ) -class ShareGPTSample(TypedDict, total=False): - """Type definition for raw ShareGPT sample.""" +class SharegptSample(TypedDict, total=False): + conversations: list[SharegptMessage] + tools: NotRequired[str] - conversations: Required[list[ShareGPTMessage]] + +class OpenaiMessage(TypedDict, total=False): + role: Literal["user", "assistant", "tool"] + content: str + + +class OpenaiSample(TypedDict, total=False): + messages: list[OpenaiMessage] class PairSample(TypedDict, total=False): - prompt: NotRequired[str] - chosen: NotRequired[list[dict]] - rejected: NotRequired[list[dict]] + chosen: list[OpenaiMessage] + rejected: list[OpenaiMessage] +class DataConverterPlugin(BasePlugin): + """Plugin for data converters.""" + + def __call__(self, raw_sample: dict[str, Any]) -> Sample: + return super().__call__(raw_sample) + + +@DataConverterPlugin("alpaca").register def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: """Convert Alpaca sample to SFT sample. + See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en + Args: raw_sample (AlpacaSample): Alpaca sample. 
@@ -67,20 +81,6 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0} ) - if "history" in raw_sample: - for idx, item in enumerate(raw_sample["history"]): - if len(item) != 2: - logger.warning_rank0( - f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping." - ) - continue - - old_prompt, old_response = item - messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0}) - messages.append( - {"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0} - ) - if "instruction" in raw_sample or "input" in raw_sample: messages.append( { @@ -100,149 +100,85 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: return {"messages": messages} -def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample: - """Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample. +@DataConverterPlugin("sharegpt").register +def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: + """Convert ShareGPT sample to SFT sample. - Retains only SFT-relevant scenarios and removes parity checks. + See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en Args: - raw_sample (ShareGPTSample): A raw sample in ShareGPT format. + raw_sample (SharegptSample): ShareGPT sample. Returns: - dict: A dictionary containing the formatted 'messages' list for SFT training. - Returns an empty list if the input data is invalid. + SFTSample: SFT sample. """ tag_mapping = { + "system": "system", "human": "user", "gpt": "assistant", - "observation": "observation", - "function_call": "function", + "observation": "tool", + "function_call": "assistant", } - messages = raw_sample.get("conversations", []) - aligned_messages = [] - system_content = "" + messages = [] + tools = raw_sample.get("tools", "") - # Extract system message if present (typically the first message) - if messages and messages[0]["from"] == "system": - system_content = messages[0]["value"] - messages = messages[1:] + for message in raw_sample.get("conversations", []): + tag = message["from"] + if tag not in tag_mapping: + logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}") + elif tag == "function_call": + messages.append( + { + "role": "assistant", + "content": [{"type": "tool_calls", "value": message["value"]}], + "loss_weight": 1.0, + } + ) + else: + messages.append( + { + "role": tag_mapping[tag], + "content": [{"type": "text", "value": message["value"]}], + "loss_weight": 1.0 if tag == "gpt" else 0.0, + } + ) - if system_content: - aligned_messages.append( - {"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0} - ) + if tools: + if messages and messages[0]["role"] == "system": + messages[0]["content"].append({"type": "tools", "value": tools}) + else: + messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0}) - has_invalid_role = False - for message in messages: - sender = message["from"] - # validate sender is in supported tags - if sender not in tag_mapping: - logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}") - has_invalid_role = True - break - - aligned_messages.append( - { - "role": tag_mapping[sender], - "content": [{"type": "text", "value": message["value"]}], - "loss_weight": 0.0 if sender in ("human", "observation") else 1.0, 
- } - ) - - if has_invalid_role: - logger.warning_rank0("Skipping invalid example due to unsupported role tags.") - return {"messages": []} - - return {"messages": aligned_messages} + return {"messages": messages} +@DataConverterPlugin("pair").register def pair_converter(raw_sample: PairSample) -> DPOSample: - """Convert Pair sample to standard DPO sample. + """Convert Pair sample to DPO sample. + + See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs Args: - raw_sample (PairSample): pair sample with prompt, chosen, rejected fields. - see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs + raw_sample (PairSample): pair sample with chosen, rejected fields. Returns: DPOSample: DPO sample with chosen_messages and rejected_messages. - see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl """ - chosen_messages = [] - assert "chosen" in raw_sample, "chosen field is required in pair sample." - assert "rejected" in raw_sample, "rejected field is required in pair sample." - assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), ( - "chosen and rejected field should be a list[dict], or you may need to implement your custom converter." - ) - if "chosen" in raw_sample: - value = raw_sample.get("chosen", "") - for item in value: - if item.get("role", "") == "system": - chosen_messages.append( - { - "role": "system", - "content": [{"type": "text", "value": item.get("content", "")}], - "loss_weight": 0.0, - } - ) - if item.get("role", "") == "user": - chosen_messages.append( - { - "role": "user", - "content": [{"type": "text", "value": item.get("content", "")}], - "loss_weight": 0.0, - } - ) - if item.get("role", "") == "assistant": - chosen_messages.append( - { - "role": "assistant", - "content": [{"type": "text", "value": item.get("content", "")}], - "loss_weight": 1.0, - } - ) + def process_message(raw_messages: list[OpenaiMessage]): + messages = [] + for message in raw_messages: + messages.append( + { + "role": message["role"], + "content": [{"type": "text", "value": message["content"]}], + "loss_weight": 1.0 if message["role"] == "assistant" else 0.0, + } + ) - rejected_messages = [] - if "rejected" in raw_sample: - value = raw_sample.get("rejected", "") - for item in value: - if item.get("role", "") == "system": - rejected_messages.append( - { - "role": "system", - "content": [{"type": "text", "value": item.get("content", "")}], - "loss_weight": 0.0, - } - ) - if item.get("role", "") == "user": - rejected_messages.append( - { - "role": "user", - "content": [{"type": "text", "value": item.get("content", "")}], - "loss_weight": 0.0, - } - ) - if item.get("role", "") == "assistant": - rejected_messages.append( - { - "role": "assistant", - "content": [{"type": "text", "value": item.get("content", "")}], - "loss_weight": 1.0, - } - ) + return messages + + chosen_messages = process_message(raw_sample.get("chosen", [])) + rejected_messages = process_message(raw_sample.get("rejected", [])) return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages} - - -CONVERTERS = { - "alpaca": alpaca_converter, - "pair": pair_converter, - "sharegpt": sharegpt_converter, -} - - -def get_converter(converter_name: str) -> Callable[[dict], Sample]: - if converter_name not in CONVERTERS: - raise ValueError(f"Converter {converter_name} not found.") - - return CONVERTERS[converter_name] diff --git a/src/llamafactory/v1/plugins/data_plugins/loader.py 
b/src/llamafactory/v1/plugins/data_plugins/loader.py index 92ec6add..6329e3c3 100644 --- a/src/llamafactory/v1/plugins/data_plugins/loader.py +++ b/src/llamafactory/v1/plugins/data_plugins/loader.py @@ -14,57 +14,59 @@ import os -from dataclasses import dataclass +import random from typing import Any, Literal, Optional, Union from datasets import load_dataset -from ...extras.types import DatasetInfo, HFDataset +from ...utils.plugin import BasePlugin +from ...utils.types import DatasetInfo, HFDataset -@dataclass -class DataLoaderPlugin: +class DataLoaderPlugin(BasePlugin): """Plugin for loading dataset.""" - def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]: - """Get dataset builder name. - - Args: - path (str): Dataset path. - - Returns: - Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name. - """ - return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text") - - def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset: - dataset_dir = dataset_info.get("dataset_dir", ".") + def load(self, dataset_info: DatasetInfo) -> HFDataset: + path = dataset_info["path"] split = dataset_info.get("split", "train") streaming = dataset_info.get("streaming", False) - if "file_name" in dataset_info: - filepath = os.path.join(dataset_dir, dataset_info["file_name"]) - return self.load_data_from_file(filepath, split, streaming) - else: - raise NotImplementedError() - - def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset: - if os.path.isdir(filepath): - filetype = self._get_builder_name(os.listdir(filepath)[0]) - dataset = load_dataset(filetype, data_dir=filepath, split=split) - elif os.path.isfile(filepath): - filetype = self._get_builder_name(filepath) - dataset = load_dataset(filetype, data_files=filepath, split=split) - else: - raise ValueError(f"Can not load dataset from {filepath}.") - - if streaming: - dataset = dataset.to_iterable_dataset() - - return dataset + return super().__call__(path, split, streaming) -@dataclass -class DataIndexPlugin: +def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]: + """Get dataset builder name. + + Args: + path (str): Dataset path. + + Returns: + Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name. + """ + filetype = os.path.splitext(path)[-1][1:] + if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]: + return filetype.replace("jsonl", "json").replace("txt", "text") + else: + raise ValueError(f"Unknown dataset filetype: {filetype}.") + + +@DataLoaderPlugin("local").register +def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset: + if os.path.isdir(filepath): + filetype = _get_builder_name(os.listdir(filepath)[0]) + dataset = load_dataset(filetype, data_dir=filepath, split=split) + elif os.path.isfile(filepath): + filetype = _get_builder_name(filepath) + dataset = load_dataset(filetype, data_files=filepath, split=split) + else: + raise ValueError(f"Can not load dataset from {filepath}.") + + if streaming: # faster when data is streamed from local files + dataset = dataset.to_iterable_dataset() + + return dataset + + +class DataIndexPlugin(BasePlugin): """Plugin for adjusting dataset index.""" def adjust_data_index( @@ -81,39 +83,32 @@ class DataIndexPlugin: list[tuple[str, int]]: Adjusted dataset index. 
""" if size is not None: - data_index = self.adjust_by_size(data_index, size) + data_index = random.choices(data_index, k=size) if weight is not None: - data_index = self.adjust_by_weight(data_index, weight) + data_index = random.choices(data_index, k=int(len(data_index) * weight)) return data_index - def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]: - raise NotImplementedError() - def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]: - raise NotImplementedError() - - -@dataclass -class DataSelectorPlugin: +class DataSelectorPlugin(BasePlugin): """Plugin for selecting dataset samples.""" - data_index: list[tuple[str, int]] - """List of (dataset_name, sample_index)""" - - def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]: + def select( + self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any] + ) -> Union[tuple[str, int], list[tuple[str, int]]]: """Select dataset samples. Args: + data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). index (Union[slice, list[int], Any]): Index of dataset samples. Returns: Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples. """ if isinstance(index, slice): - return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))] + return [data_index[i] for i in range(*index.indices(len(data_index)))] elif isinstance(index, list): - return [self.data_index[i] for i in index] + return [data_index[i] for i in index] else: raise ValueError(f"Invalid index type {type(index)}.") diff --git a/src/llamafactory/v1/extras/__init__.py b/src/llamafactory/v1/plugins/model_plugins/kernels/attn/__init__.py similarity index 100% rename from src/llamafactory/v1/extras/__init__.py rename to src/llamafactory/v1/plugins/model_plugins/kernels/attn/__init__.py diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py index 1ad988bf..55a05e94 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py @@ -21,10 +21,3 @@ class KernelType(str, Enum): FLASH_ATTENTION = "flash_attention" ROPE = "rope" MOE = "moe" - - -class DeviceType(str, Enum): - CPU = "cpu" - CUDA = "cuda" - NPU = "npu" - XPU = "xpu" diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py index 7f1e8c0b..1fa9ef47 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_fused_moe.py @@ -18,10 +18,10 @@ import torch import torch.nn.functional as F import torch_npu -from .....accelerator.helper import is_torch_npu_available -from .....extras.packages import is_transformers_version_greater_than -from .....extras.types import HFModel -from ..constants import DeviceType, KernelType +from .....accelerator.helper import DeviceType, is_torch_npu_available +from .....utils.packages import is_transformers_version_greater_than +from .....utils.types import HFModel +from ..constants import KernelType from ..registry import MetaMoEKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py index 011b53b3..ff115a43 100644 --- 
a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py @@ -17,9 +17,9 @@ import types import torch -from .....accelerator.helper import is_torch_npu_available -from .....extras.types import HFModel -from ..constants import DeviceType, KernelType +from .....accelerator.helper import DeviceType, is_torch_npu_available +from .....utils.types import HFModel +from ..constants import KernelType from ..registry import MetaSwiGluKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index 06595ba1..08dc5b9d 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -13,11 +13,11 @@ # limitations under the License. from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union -from ....accelerator.helper import get_current_accelerator -from ....extras.types import HFModel -from .constants import DeviceType, KernelType +from ....accelerator.helper import DeviceType, get_current_accelerator +from ....utils.types import HFModel +from .constants import KernelType class KernelRegistry: @@ -27,11 +27,13 @@ class KernelRegistry: def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry": if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self) -> None: if self._initialized: return + self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {} self._initialized = True @@ -218,7 +220,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]: return discovered_kernels # Iterate through registry and collect all kernels for current device - for kernel_type, devices in KERNEL_REGISTRY._registry.items(): + for devices in KERNEL_REGISTRY._registry.values(): kernel_cls = devices.get(device_type) if kernel_cls is not None: discovered_kernels.append(kernel_cls) @@ -226,7 +228,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]: return discovered_kernels -def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel": +def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel": """Call the MetaKernel's `apply` to perform the replacement. Corresponding replacement logic is maintained inside each kernel; the only @@ -238,16 +240,18 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFMo model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B") model = apply_kernel(model, NpuRMSNormKernel) """ - if issubclass(kernel, MetaKernel) and kernel.device == get_current_accelerator().type: - return kernel.apply(model, **kwargs) + if not issubclass(kernel, MetaKernel): + raise ValueError(f"{kernel} must be a MetaKernel instance.") - raise ValueError( - f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_current_accelerator().type} instead." 
- ) + if kernel.device != get_current_accelerator().type: + raise ValueError(f"{kernel} must be applied to {kernel.device} device, got {get_current_accelerator().type}.") + + return kernel.apply(model, **kwargs) def apply_available_kernels(model: HFModel, **kwargs) -> "HFModel": """Apply all available kernels to the model.""" for kernel in discover_kernels(model): model = apply_kernel(model, kernel, **kwargs) + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py index 01550719..7ff00898 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py @@ -14,9 +14,9 @@ import re import types -from .....accelerator.helper import is_torch_npu_available -from .....extras.types import HFModel -from ..constants import DeviceType, KernelType +from .....accelerator.helper import DeviceType, is_torch_npu_available +from .....utils.types import HFModel +from ..constants import KernelType from ..registry import MetaRMSNormKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py index a0608e7d..82fccce7 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py @@ -16,9 +16,9 @@ import sys import torch -from .....accelerator.helper import is_torch_npu_available -from .....extras.types import HFModel -from ..constants import DeviceType, KernelType +from .....accelerator.helper import DeviceType, is_torch_npu_available +from .....utils.types import HFModel +from ..constants import KernelType from ..registry import MetaRoPEKernel diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py index e69de29b..279c5928 100644 --- a/src/llamafactory/v1/plugins/model_plugins/peft.py +++ b/src/llamafactory/v1/plugins/model_plugins/peft.py @@ -0,0 +1,42 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Literal, TypedDict + +from peft import LoraConfig, PeftModel, get_peft_model + +from ...utils.plugin import BasePlugin +from ...utils.types import HFModel + + +class LoraConfigDict(TypedDict, total=False): + name: Literal["lora"] + """Plugin name.""" + r: int + """Lora rank.""" + lora_alpha: int + """Lora alpha.""" + target_modules: list[str] + """Target modules.""" + + +class PeftPlugin(BasePlugin): + pass + + +@PeftPlugin("lora").register +def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel: + peft_config = LoraConfig(**config) + model = get_peft_model(model, peft_config) + return model diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py index ec0d6255..e69de29b 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py @@ -1,13 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py similarity index 100% rename from src/llamafactory/v1/plugins/model_plugins/kernels/fa/__init__.py rename to src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py index 5d254e4c..781e0a7b 100644 --- a/src/llamafactory/v1/trainers/sft_trainer.py +++ b/src/llamafactory/v1/trainers/sft_trainer.py @@ -13,22 +13,21 @@ # limitations under the License. 
-from ..config.parser import get_args +from ..accelerator.interface import DistributedInterface, DistributedStrategy +from ..config.arg_parser import get_args from ..core.base_trainer import BaseTrainer from ..core.data_engine import DataEngine -from ..core.model_engine import ModelEngine +from ..core.model_worker import ModelWorker class SFTTrainer(BaseTrainer): pass -def run_sft(): - model_args, data_args, training_args, _ = get_args() - model_engine = ModelEngine(model_args) +def run_sft(user_args): + model_args, data_args, training_args, _ = get_args(user_args) + DistributedInterface(DistributedStrategy()) data_engine = DataEngine(data_args) - model = model_engine.get_model() - processor = model_engine.get_processor() - data_loader = data_engine.get_data_loader(processor) - trainer = SFTTrainer(training_args, model, processor, data_loader) + model_worker = ModelWorker(model_args) + trainer = SFTTrainer(training_args, model_worker, data_engine) trainer.fit() diff --git a/src/llamafactory/v1/utils/__init__.py b/src/llamafactory/v1/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/v1/utils/constants.py b/src/llamafactory/v1/utils/constants.py new file mode 100644 index 00000000..ec0d6255 --- /dev/null +++ b/src/llamafactory/v1/utils/constants.py @@ -0,0 +1,13 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/llamafactory/v1/utils/logging.py b/src/llamafactory/v1/utils/logging.py new file mode 100644 index 00000000..b2f9362f --- /dev/null +++ b/src/llamafactory/v1/utils/logging.py @@ -0,0 +1,123 @@ +# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import sys +import threading +from functools import lru_cache +from typing import Optional + + +_thread_lock = threading.RLock() +_default_handler: Optional["logging.Handler"] = None +_default_log_level: "logging._Level" = logging.INFO + + +class _Logger(logging.Logger): + r"""A logger that supports rank0 logging.""" + + def info_rank0(self, *args, **kwargs) -> None: + self.info(*args, **kwargs) + + def warning_rank0(self, *args, **kwargs) -> None: + self.warning(*args, **kwargs) + + def warning_rank0_once(self, *args, **kwargs) -> None: + self.warning(*args, **kwargs) + + +def _get_default_logging_level() -> "logging._Level": + r"""Return the default logging level.""" + env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None) + if env_level_str: + if env_level_str.upper() in logging._nameToLevel: + return logging._nameToLevel[env_level_str.upper()] + else: + raise ValueError(f"Unknown logging level: {env_level_str}.") + + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> "_Logger": + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + r"""Configure root logger using a stdout stream handler with an explicit format.""" + global _default_handler + + with _thread_lock: + if _default_handler: # already configured + return + + formatter = logging.Formatter( + fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.setFormatter(formatter) + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def get_logger(name: Optional[str] = None) -> "_Logger": + r"""Return a logger with the specified name. 
It is not supposed to be accessed externally.""" + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def add_handler(handler: "logging.Handler") -> None: + r"""Add a handler to the root logger.""" + _configure_library_root_logger() + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + r"""Remove a handler from the root logger.""" + _configure_library_root_logger() + _get_library_root_logger().removeHandler(handler) + + +def info_rank0(self: "logging.Logger", *args, **kwargs) -> None: + if int(os.getenv("LOCAL_RANK", "0")) == 0: + self.info(*args, **kwargs) + + +def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None: + if int(os.getenv("LOCAL_RANK", "0")) == 0: + self.warning(*args, **kwargs) + + +@lru_cache(None) +def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None: + if int(os.getenv("LOCAL_RANK", "0")) == 0: + self.warning(*args, **kwargs) + + +logging.Logger.info_rank0 = info_rank0 +logging.Logger.warning_rank0 = warning_rank0 +logging.Logger.warning_rank0_once = warning_rank0_once diff --git a/src/llamafactory/v1/extras/packages.py b/src/llamafactory/v1/utils/packages.py similarity index 100% rename from src/llamafactory/v1/extras/packages.py rename to src/llamafactory/v1/utils/packages.py diff --git a/src/llamafactory/v1/utils/plugin.py b/src/llamafactory/v1/utils/plugin.py new file mode 100644 index 00000000..0cbf2576 --- /dev/null +++ b/src/llamafactory/v1/utils/plugin.py @@ -0,0 +1,86 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +from . import logging + + +logger = logging.get_logger(__name__) + + +class BasePlugin: + """Base class for plugins. + + A plugin is a callable object that can be registered and called by name. + """ + + _registry: dict[str, Callable] = {} + + def __init__(self, name: Optional[str] = None): + """Initialize the plugin with a name. + + Args: + name (str): The name of the plugin. + """ + self.name = name + + @property + def register(self) -> Callable: + """Decorator to register a function as a plugin. + + Example usage: + ```python + @PrintPlugin("hello").register + def print_hello(): + print("Hello world!") + ``` + """ + if self.name is None: + raise ValueError("Plugin name is not specified.") + + if self.name in self._registry: + logger.warning_rank0_once(f"Plugin {self.name} is already registered.") + + def decorator(func: Callable) -> Callable: + self._registry[self.name] = func + return func + + return decorator + + def __call__(self, *args, **kwargs) -> Callable: + """Call the registered function with the given arguments.
+ + Example usage: + ```python + PrintPlugin("hello")() + ``` + """ + if self.name not in self._registry: + raise ValueError(f"Plugin {self.name} is not registered.") + + return self._registry[self.name](*args, **kwargs) + + +if __name__ == "__main__": + + class PrintPlugin(BasePlugin): + pass + + @PrintPlugin("hello").register + def print_hello(): + print("Hello world!") + + PrintPlugin("hello")() diff --git a/src/llamafactory/v1/extras/types.py b/src/llamafactory/v1/utils/types.py similarity index 82% rename from src/llamafactory/v1/extras/types.py rename to src/llamafactory/v1/utils/types.py index 16658c3d..eb81298d 100644 --- a/src/llamafactory/v1/extras/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -19,23 +19,27 @@ from typing_extensions import NotRequired if TYPE_CHECKING: import datasets + import numpy as np import torch import torch.utils.data import transformers + from torch.distributed.fsdp import FullyShardedDataParallel Tensor = torch.Tensor + TensorLike = Union[int, float, list[int], list[float], np.ndarray, Tensor] TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset] HFDataset = Union[datasets.Dataset, datasets.IterableDataset] DataCollator = transformers.DataCollator DataLoader = torch.utils.data.DataLoader HFConfig = transformers.PretrainedConfig HFModel = transformers.PreTrainedModel - DistModel = torch.nn.parallel.DistributedDataParallel + DistModel = Union[torch.nn.parallel.DistributedDataParallel, FullyShardedDataParallel] Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin] Optimizer = torch.optim.Optimizer Scheduler = torch.optim.lr_scheduler.LRScheduler else: Tensor = None + TensorLike = None TorchDataset = None HFDataset = None DataCollator = None @@ -49,12 +53,10 @@ else: class DatasetInfo(TypedDict, total=False): - hf_hub_url: NotRequired[str] - """HF hub dataset uri.""" - file_name: NotRequired[str] + path: str """Local file path.""" - dataset_dir: NotRequired[str] - """Dataset directory, default to args.dataset_dir.""" + source: NotRequired[Literal["hf_hub", "ms_hub", "local"]] + """Dataset source, default to "hf_hub".""" split: NotRequired[str] """Dataset split, default to "train".""" converter: NotRequired[str] @@ -68,12 +70,12 @@ class DatasetInfo(TypedDict, total=False): class Content(TypedDict): - type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"] + type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"] value: str class Message(TypedDict): - role: Literal["system", "user", "assistant"] + role: Literal["system", "user", "assistant", "tool"] content: list[Content] loss_weight: float diff --git a/src/llamafactory/v1/accelerator/distributed.py b/tests_v1/accelerator/test_interface.py similarity index 51% rename from src/llamafactory/v1/accelerator/distributed.py rename to tests_v1/accelerator/test_interface.py index 0094bf5a..ab39915b 100644 --- a/src/llamafactory/v1/accelerator/distributed.py +++ b/tests_v1/accelerator/test_interface.py @@ -12,25 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional -from torch.distributed.device_mesh import DeviceMesh +import os + +from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy -class DeviceMeshManager: - """Device mesh manager.""" - - _instance: Optional["DeviceMeshManager"] = None - _initialized: bool = False - - def __new__(cls) -> "DeviceMeshManager": - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self) -> None: - if self._initialized: - return - - self.device_mesh: Optional[DeviceMesh] = None - self._initialized = True +def test_distributed_interface(): + DistributedInterface(DistributedStrategy()) + assert DistributedInterface.rank == int(os.getenv("RANK", "0")) + assert DistributedInterface.world_size == int(os.getenv("WORLD_SIZE", "1")) diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py index e64b8c25..cebdaec7 100644 --- a/tests_v1/plugins/data_plugins/test_converter.py +++ b/tests_v1/plugins/data_plugins/test_converter.py @@ -19,7 +19,7 @@ from datasets import load_dataset from llamafactory.v1.config.data_args import DataArguments from llamafactory.v1.core.data_engine import DataEngine -from llamafactory.v1.plugins.data_plugins.converter import get_converter +from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin @pytest.mark.parametrize("num_samples", [16]) @@ -49,99 +49,27 @@ def test_alpaca_converter(num_samples: int): assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data} -def test_sharegpt_converter_invalid(): +def test_sharegpt_converter(): example = { "conversations": [ - { - "from": "system", - "value": "Processes historical market data to generate trading signals " - "based on specified technical indicators.", - }, - { - "from": "human", - "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. " - "Could you proceed with these function calls to assist me with the task?", - }, - { - "from": "gpt", - "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', " - "'name': 'backtest_trading_signals'}```\n", - }, - { - "from": "tool", - "value": '\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", ' - '"symbol": "AAPL", "signal": "Buy"}]}}}\n\n', - }, + {"from": "system", "value": "System"}, + {"from": "human", "value": "User"}, + {"from": "gpt", "value": "Assistant"}, ] } - dataset_converter = get_converter("sharegpt") - assert dataset_converter(example) == {"messages": []} - - -def test_sharegpt_converter_valid(): - example = { - "conversations": [ - { - "from": "system", - "value": "Processes historical market data to generate trading signals based on " - "specified technical indicators.", - }, - { - "from": "human", - "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. " - "Could you proceed with these function calls to assist me with the task?", - }, - { - "from": "gpt", - "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', " - "'name': 'backtest_trading_signals'}```\n", - }, - ] - } - dataset_converter = get_converter("sharegpt") expected_data = { "messages": [ - { - "content": [ - { - "type": "text", - "value": "Processes historical market data to generate trading signals based on " - "specified technical indicators.", - } - ], - "loss_weight": 0.0, - "role": "system", - }, - { - "content": [ - { - "type": "text", - "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. 
" - "Could you proceed with these function calls to assist me with the task?", - } - ], - "loss_weight": 0.0, - "role": "user", - }, - { - "content": [ - { - "type": "text", - "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', " - "'name': 'backtest_trading_signals'}```\n", - } - ], - "loss_weight": 1.0, - "role": "assistant", - }, + {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"}, + {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"}, + {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"}, ] } - assert dataset_converter(example) == expected_data + assert DataConverterPlugin("sharegpt")(example) == expected_data @pytest.mark.parametrize("num_samples", [16]) def test_pair_converter(num_samples: int): - data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml") + data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml") data_engine = DataEngine(data_args) original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs") indexes = random.choices(range(len(data_engine)), k=num_samples) @@ -189,6 +117,5 @@ def test_pair_converter(num_samples: int): if __name__ == "__main__": test_alpaca_converter(1) - test_sharegpt_converter_invalid() - test_sharegpt_converter_valid() + test_sharegpt_converter() test_pair_converter(1) diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py index 06276c33..20c61b29 100644 --- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py +++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py @@ -17,10 +17,13 @@ from unittest.mock import MagicMock, patch from transformers import AutoModelForCausalLM +from llamafactory.v1.accelerator.helper import get_current_accelerator + class TestKernelPlugin(unittest.TestCase): @patch("torch.accelerator.current_accelerator") def test_apply_kernel(self, mock_get_accelerator): + get_current_accelerator.cache_clear() mock_device = MagicMock() mock_device.type = "npu" mock_get_accelerator.return_value = mock_device @@ -47,6 +50,7 @@ class TestKernelPlugin(unittest.TestCase): class Test_Use_V1_Kernels(unittest.TestCase): @patch("torch.accelerator.current_accelerator") def test_use_v1_kernels(self, mock_get_accelerator): + get_current_accelerator.cache_clear() mock_device = MagicMock() mock_device.type = "npu" mock_get_accelerator.return_value = mock_device