diff --git a/requirements.txt b/requirements.txt index 0a273440f..59105a9dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # core deps transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10' -transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10' +transformers>=4.49.0,<=4.57.3,!=4.52.0,!=4.57.0; python_version >= '3.10' datasets>=2.16.0,<=4.0.0 accelerate>=1.3.0,<=1.11.0 peft>=0.14.0,<=0.17.1 diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 5c4c24787..dc934bd26 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None: def check_dependencies() -> None: r"""Check the version of the required packages.""" - check_version("transformers>=4.49.0,<=4.57.1") + check_version("transformers>=4.49.0,<=4.57.3") check_version("datasets>=2.16.0,<=4.0.0") check_version("accelerate>=1.3.0,<=1.11.0") check_version("peft>=0.14.0,<=0.17.1") diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py index 8a7d68697..22da4916e 100644 --- a/src/llamafactory/v1/accelerator/helper.py +++ b/src/llamafactory/v1/accelerator/helper.py @@ -19,17 +19,13 @@ import os from contextlib import contextmanager from enum import Enum, unique from functools import lru_cache -from typing import TYPE_CHECKING, Optional +from typing import Optional import numpy as np import torch import torch.distributed as dist -from ..utils.types import Tensor, TensorLike - - -if TYPE_CHECKING: - from torch.distributed import ProcessGroup +from ..utils.types import ProcessGroup, Tensor, TensorLike @unique @@ -107,7 +103,7 @@ def is_torch_xpu_available(): return get_current_accelerator().type == DeviceType.XPU -def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor: +def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: """Gathers the tensor from all ranks and concats them along the first dim.""" world_size = get_world_size() device = get_current_accelerator() @@ -116,7 +112,7 @@ def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor return output_tensor.view(-1, *tensor.size()[1:]) -def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike: +def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike: """Performs all reduce in the given process group.""" device = get_current_accelerator() is_ndarray = isinstance(data, np.ndarray) diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index 47a878f26..b24833186 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -16,12 +16,14 @@ # limitations under the License. from dataclasses import dataclass +from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional +from torch.distributed import init_process_group from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from ..utils.types import Tensor, TensorLike +from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike from .helper import ( ReduceOp, all_gather, @@ -35,10 +37,6 @@ from .helper import ( ) -if TYPE_CHECKING: - from torch.distributed import ProcessGroup - - class Dim(str, Enum): """Dimension names.""" @@ -130,21 +128,33 @@ class DistributedInterface: return cls._instance - def __init__(self, strategy: DistributedStrategy) -> None: + def __init__(self, config: Optional[DistributedConfig] = None) -> None: if self._initialized: return - self.strategy = strategy + if config is None: + self.strategy = DistributedStrategy() + timeout = 18000 + else: + self.strategy = DistributedStrategy( + mp_replicate_size=config.get("mp_replicate_size", 1), + mp_shard_size=config.get("mp_shard_size", None), + dp_size=config.get("dp_size", None), + cp_size=config.get("cp_size", 1), + ) + timeout = config.get("timeout", 18000) + if self._is_distributed: + init_process_group(timeout=timedelta(seconds=timeout)) self.model_device_mesh = init_device_mesh( device_type=self.current_accelerator.type, - mesh_shape=strategy.model_mesh_shape, - mesh_dim_names=strategy.model_mesh_dim_names, + mesh_shape=self.strategy.model_mesh_shape, + mesh_dim_names=self.strategy.model_mesh_dim_names, ) self.data_device_mesh = init_device_mesh( device_type=self.current_accelerator.type, - mesh_shape=strategy.data_mesh_shape, - mesh_dim_names=strategy.data_mesh_dim_names, + mesh_shape=self.strategy.data_mesh_shape, + mesh_dim_names=self.strategy.data_mesh_dim_names, ) else: self.model_device_mesh = None @@ -172,7 +182,7 @@ class DistributedInterface: return cls.model_device_mesh[dim.value] @classmethod - def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]: + def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]: """Get process group for specified dimension.""" if cls.model_device_mesh is None or dim is None: return None diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py index ca6103a9c..05b4f69f3 100644 --- a/src/llamafactory/v1/config/arg_parser.py +++ b/src/llamafactory/v1/config/arg_parser.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import json import sys from pathlib import Path @@ -28,6 +27,9 @@ from .sample_args import SampleArguments from .training_args import TrainingArguments +InputArgument = Optional[Union[dict[str, Any], list[str]]] + + def validate_args( data_args: DataArguments, model_args: ModelArguments, @@ -43,9 +45,7 @@ def validate_args( raise ValueError("Quantization is not supported with deepspeed backend.") -def get_args( - args: Optional[Union[dict[str, Any], list[str]]] = None, -) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]: +def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]: """Parse arguments from command line or config file.""" parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments]) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS") diff --git a/src/llamafactory/v1/config/arg_utils.py b/src/llamafactory/v1/config/arg_utils.py index f2c0183d9..43c708b0a 100644 --- a/src/llamafactory/v1/config/arg_utils.py +++ b/src/llamafactory/v1/config/arg_utils.py @@ -18,7 +18,7 @@ import json from enum import Enum, unique -from typing import Any, Optional, Union +from typing import Optional, Union class PluginConfig(dict): @@ -32,25 +32,16 @@ class PluginConfig(dict): return self["name"] - def __getattr__(self, key: str) -> Any: - try: - return self[key] - except KeyError: - raise AttributeError(f"Attribute {key} not found.") - - def __setattr__(self, key: str, value: Any): - self[key] = value - PluginArgument = Optional[Union[PluginConfig, dict, str]] @unique -class AutoClass(str, Enum): +class ModelClass(str, Enum): """Auto class for model config.""" - CAUSALLM = "llm" - CLASSIFICATION = "cls" + LLM = "llm" + CLS = "cls" OTHER = "other" diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py index 1dfe962b1..87d1a160c 100644 --- a/src/llamafactory/v1/config/model_args.py +++ b/src/llamafactory/v1/config/model_args.py @@ -14,8 +14,9 @@ from dataclasses import dataclass, field +from typing import Optional -from .arg_utils import AutoClass, PluginConfig, get_plugin_config +from .arg_utils import ModelClass, PluginConfig, get_plugin_config @dataclass @@ -31,19 +32,19 @@ class ModelArguments: default=True, metadata={"help": "Use fast processor from Hugging Face."}, ) - auto_class: AutoClass = field( - default=AutoClass.CAUSALLM, + model_class: ModelClass = field( + default=ModelClass.LLM, metadata={"help": "Model class from Hugging Face."}, ) - peft_config: PluginConfig = field( + peft_config: Optional[PluginConfig] = field( default=None, metadata={"help": "PEFT configuration for the model."}, ) - kernel_config: PluginConfig = field( + kernel_config: Optional[PluginConfig] = field( default=None, metadata={"help": "Kernel configuration for the model."}, ) - quant_config: PluginConfig = field( + quant_config: Optional[PluginConfig] = field( default=None, metadata={"help": "Quantization configuration for the model."}, ) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 5a14cab66..d1afaf483 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import os from dataclasses import dataclass, field +from typing import Optional +from uuid import uuid4 -from .arg_utils import PluginArgument, get_plugin_config +from .arg_utils import PluginConfig, get_plugin_config @dataclass class TrainingArguments: output_dir: str = field( - default="", + default=os.path.join("outputs", str(uuid4())), metadata={"help": "Path to the output directory."}, ) micro_batch_size: int = field( @@ -40,7 +42,7 @@ class TrainingArguments: default=False, metadata={"help": "Use bf16 for training."}, ) - dist_config: PluginArgument = field( + dist_config: Optional[PluginConfig] = field( default=None, metadata={"help": "Distribution configuration for training."}, ) diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index 331c14173..22c106a8e 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -17,11 +17,10 @@ Init Phase: 1. Init dataloader. -2. Init model worker. -3. Init optimizer (deepspeed). -4. Shard model. -5. Init optimizer (fsdp). -6. Init scheduler. +2. Init optimizer (deepspeed). +3. Shard model. +4. Init optimizer (fsdp). +5. Init scheduler. Train Phase: 1. Train Loop @@ -29,8 +28,7 @@ Train Phase: """ from ..config.training_args import TrainingArguments -from ..utils.types import TorchDataset -from .model_worker import ModelWorker +from ..utils.types import HFModel, Processor, TorchDataset from .trainer_utils.data_collator import DataCollator @@ -38,21 +36,20 @@ class BaseTrainer: def __init__( self, args: TrainingArguments, + model: HFModel, + processor: Processor, dataset: TorchDataset, - data_collator: DataCollator, - model_worker: ModelWorker, ) -> None: self.args = args + self.model = model + self.processor = processor self.dataset = dataset - self.data_collator = data_collator - self.model_worker = model_worker + self.data_collator = DataCollator() self.optimizer = None self.lr_scheduler = None def init_model_and_optimizer(self) -> None: - self.model_worker.init_model_config() - # with self.dist_plugin.get_model_init_context(): - # self.model = self.model_worker.init_model(self.model_config) + pass def create_dataloader(self) -> None: pass diff --git a/src/llamafactory/v1/core/chat_sampler.py b/src/llamafactory/v1/core/chat_sampler.py index 412185b5f..a4dc9d6da 100644 --- a/src/llamafactory/v1/core/chat_sampler.py +++ b/src/llamafactory/v1/core/chat_sampler.py @@ -15,12 +15,12 @@ from abc import ABC, abstractmethod from ..config.sample_args import SampleArguments, SampleBackend -from .model_worker import ModelWorker +from .model_loader import ModelLoader class BaseEngine(ABC): @abstractmethod - def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ... + def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ... @abstractmethod async def generate(self): @@ -32,15 +32,13 @@ class BaseEngine(ABC): class HuggingFaceEngine(BaseEngine): - def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None: - self.model = model_worker.get_model() - self.processor = model_worker.get_processor() + def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None: self.args = sample_args class ChatSampler: - def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None: + def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None: if sample_args.sample_backend == SampleBackend.HF: - self.engine = HuggingFaceEngine(model_worker, sample_args) + self.engine = HuggingFaceEngine(model_loader, sample_args) else: raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}") diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index d7d6cbc31..98d59d424 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -26,7 +26,7 @@ Get Data Sample: """ import os -from collections.abc import AsyncIterable, Iterable +from collections.abc import Iterable from typing import Any, Union from huggingface_hub import hf_hub_download @@ -38,7 +38,11 @@ from ..utils.types import DatasetInfo, HFDataset, Sample class DataEngine(Dataset): - """Data engine.""" + """Data engine. + + Args: + data_args: Data arguments. + """ def __init__(self, data_args: DataArguments) -> None: self.args = data_args @@ -51,11 +55,11 @@ class DataEngine(Dataset): """List of (dataset_name, sample_index)""" self.streaming: bool = False """Whether dataset is streaming.""" - self.get_dataset_info() - self.load_dataset() - self.build_data_index() + self._get_dataset_info() + self._load_dataset() + self._build_data_index() - def get_dataset_info(self) -> None: + def _get_dataset_info(self) -> None: """Get dataset info from data arguments.""" if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file self.dataset_infos = OmegaConf.load(self.args.dataset) @@ -68,31 +72,32 @@ class DataEngine(Dataset): else: # hf hub dataset, e.g. llamafactory/v1-sft-demo self.dataset_infos = {"default": {"path": self.args.dataset}} - def load_dataset(self) -> None: + def _load_dataset(self) -> None: """Load datasets according to dataset info.""" - for key, value in self.dataset_infos.items(): - split = value.get("split", "train") - streaming = value.get("streaming", False) + for dataset_name, dataset_info in self.dataset_infos.items(): + split = dataset_info.get("split", "train") + streaming = dataset_info.get("streaming", False) self.streaming |= streaming - if value.get("source", "hf_hub") == "hf_hub": + if dataset_info.get("source", "hf_hub") == "hf_hub": from datasets import load_dataset - self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming) + self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming) else: # data loader plugin from ..plugins.data_plugins.loader import DataLoaderPlugin - self.datasets[key] = DataLoaderPlugin(value["source"]).load(value) + self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info) - def build_data_index(self) -> None: + def _build_data_index(self) -> None: """Build dataset index.""" for dataset_name, dataset in self.datasets.items(): - size = self.dataset_infos[dataset_name].get("size") - weight = self.dataset_infos[dataset_name].get("weight") - if self.streaming: + streaming = self.dataset_infos[dataset_name].get("streaming", False) + if streaming: data_index = [(dataset_name, -1) for _ in range(1000)] else: data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))] + size = self.dataset_infos[dataset_name].get("size") + weight = self.dataset_infos[dataset_name].get("weight") if size or weight: # data index plugin from ..plugins.data_plugins.loader import DataIndexPlugin @@ -144,7 +149,7 @@ class DataEngine(Dataset): if isinstance(index, int): dataset_name, sample_index = self.data_index[index] return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) - else: + else: # data selector plugin from ..plugins.data_plugins.loader import DataSelectorPlugin selected_index = DataSelectorPlugin().select(self.data_index, index) @@ -163,30 +168,18 @@ class DataEngine(Dataset): Returns: Iterable[Sample]: Dataset iterator. """ - if self.streaming: - pass - else: - # TODO: add shuffle here - pass - - raise NotImplementedError() - - async def __aiter__(self) -> AsyncIterable[Sample]: - """Get dataset async iterator. - - Returns: - AsyncIterable[Sample]: Dataset async iterator. - """ - if self.streaming: - pass - else: - # TODO: add shuffle here - pass + # NOTE: hf iterable dataset uses worker ids while map dataset does not + # NOTE: add worker id and shuffle to the map dataset + # https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214 raise NotImplementedError() if __name__ == "__main__": + """ + python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml + python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml + """ from ..config.arg_parser import get_args data_args, *_ = get_args() diff --git a/src/llamafactory/v1/core/model_loader.py b/src/llamafactory/v1/core/model_loader.py new file mode 100644 index 000000000..870f31f0f --- /dev/null +++ b/src/llamafactory/v1/core/model_loader.py @@ -0,0 +1,128 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The definition of model loader. + +Init Phase: +1. Init processor. +2. Init model config. +3. Init model. +4. Init adapter. + +""" + +import torch +from transformers import AutoConfig, AutoProcessor + +from ..accelerator.interface import DistributedInterface +from ..config.model_args import ModelArguments, ModelClass +from ..utils import logging +from ..utils.types import HFConfig, HFModel, Processor + + +logger = logging.get_logger(__name__) + + +class ModelLoader: + """Model loader. + + Args: + model_args: Model arguments. + is_trainable: Whether to train the model. + """ + + def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None: + self.args = model_args + """Model arguments.""" + self.is_train = is_train + """Whether to train the model.""" + self.processor = self._init_processor() + """Tokenizer or multi-modal processor.""" + self.model_config = self._init_model_config() + """Model configuration.""" + self.model = self._init_model() + """HF model.""" + + def _init_processor(self) -> Processor: + """Init processor.""" + return AutoProcessor.from_pretrained( + self.args.model, + trust_remote_code=self.args.trust_remote_code, + use_fast=self.args.use_fast_processor, + ) + + def _init_model_config(self) -> HFConfig: + """Init model config.""" + return AutoConfig.from_pretrained( + self.args.model, + trust_remote_code=self.args.trust_remote_code, + ) + + def _init_model(self) -> HFModel: + """Init model. + + Let transformers handle the model init context. + https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538 + """ + if self.args.model_class == ModelClass.LLM: + from transformers import AutoModelForCausalLM, AutoModelForImageTextToText + + if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys(): + AutoClass = AutoModelForImageTextToText + else: + AutoClass = AutoModelForCausalLM + + elif self.args.model_class == ModelClass.CLS: + from transformers import AutoModelForTokenClassification + + AutoClass = AutoModelForTokenClassification + else: + from transformers import AutoModel + + AutoClass = AutoModel + + # map the entire model to the current accelerator + model = AutoClass.from_pretrained( + self.args.model, + config=self.model_config, + dtype="auto", + device_map=DistributedInterface.current_accelerator, + trust_remote_code=self.args.trust_remote_code, + ) + + if self.args.peft_config is None: + if self.is_train: + logger.info_rank0("Fine-tuning mode: full tuning") + model = model.to(torch.float32) + else: + logger.info_rank0("Inference the original model") + else: + from ..plugins.model_plugins.peft import PeftPlugin + + model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train) + + return model + + +if __name__ == "__main__": + """ + python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5 + """ + from ..config.arg_parser import get_args + + _, model_args, *_ = get_args() + model_loader = ModelLoader(model_args=model_args) + print(model_loader.processor) + print(model_loader.model_config) + print(model_loader.model) diff --git a/src/llamafactory/v1/core/model_worker.py b/src/llamafactory/v1/core/model_worker.py deleted file mode 100644 index 183e76807..000000000 --- a/src/llamafactory/v1/core/model_worker.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The definition of model worker. - -Init Phase: -1. Init processor. -2. Init model config. -3. Init model. -4. Init adapter. - -""" - -from typing import Optional - -import torch -from transformers import AutoConfig, AutoProcessor - -from ..accelerator.helper import DeviceType -from ..config.model_args import AutoClass, ModelArguments -from ..utils.types import HFConfig, HFModel, Processor - - -class ModelWorker: - def __init__(self, model_args: ModelArguments) -> None: - self.args = model_args - """Model arguments.""" - self.processor: Optional[Processor] = None - """Tokenizer or multi-modal processor.""" - self.model_config: Optional[HFConfig] = None - """Model configuration.""" - self.model: Optional[HFModel] = None - """HF model.""" - self.is_adapter = False - """Whether the model has adapter.""" - - def init_processor(self) -> None: - if self.processor is not None: - return - - self.processor = AutoProcessor.from_pretrained( - self.args.model, - trust_remote_code=self.args.trust_remote_code, - use_fast=self.args.use_fast_processor, - ) - - def init_model_config(self) -> None: - if self.model_config is not None: - return - - self.model_config = AutoConfig.from_pretrained( - self.args.model, - trust_remote_code=self.args.trust_remote_code, - ) - - def init_model(self) -> None: - if self.model is not None: - return - - self.init_model_config() - - if self.args.auto_class == AutoClass.CAUSALLM: - from transformers import AutoModelForCausalLM, AutoModelForImageTextToText - - if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys(): - ModelClass = AutoModelForImageTextToText - else: - ModelClass = AutoModelForCausalLM - elif self.args.auto_class == AutoClass.CLASSIFICATION: - from transformers import AutoModelForTokenClassification - - ModelClass = AutoModelForTokenClassification - else: - from transformers import AutoModel - - ModelClass = AutoModel - - default_device_type = torch.get_default_device().type - if default_device_type == DeviceType.META: - self.model = ModelClass.from_config(self.model_config) - else: - self.model = ModelClass.from_pretrained( - self.args.model, - config=self.model_config, - dtype="auto", - device_map=default_device_type, - trust_remote_code=self.args.trust_remote_code, - ) - - def init_adapter(self) -> None: - if self.is_adapter: - return - - if self.args.peft_config is not None: - from ..plugins.model_plugins.peft import PeftPlugin - - self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config) - - self.is_adapter = True - - def get_processor(self) -> Processor: - return self.processor - - def get_model_config(self) -> HFConfig: - return self.model_config - - def get_model(self) -> HFModel: - return self.model diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py index 279c5928d..0dc29ba88 100644 --- a/src/llamafactory/v1/plugins/model_plugins/peft.py +++ b/src/llamafactory/v1/plugins/model_plugins/peft.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, TypedDict +from typing import Literal, Optional, TypedDict from peft import LoraConfig, PeftModel, get_peft_model @@ -31,12 +31,27 @@ class LoraConfigDict(TypedDict, total=False): """Target modules.""" +class FreezeConfigDict(TypedDict, total=False): + name: Literal["freeze"] + """Plugin name.""" + freeze_trainable_layers: int + """Freeze trainable layers.""" + freeze_trainable_modules: Optional[list[str]] + """Freeze trainable modules.""" + + class PeftPlugin(BasePlugin): - pass + def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel: + return super().__call__(model, config) @PeftPlugin("lora").register -def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel: +def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel: peft_config = LoraConfig(**config) model = get_peft_model(model, peft_config) return model + + +@PeftPlugin("freeze").register +def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel: + raise NotImplementedError() diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py index 781e0a7bb..c48c1092f 100644 --- a/src/llamafactory/v1/trainers/sft_trainer.py +++ b/src/llamafactory/v1/trainers/sft_trainer.py @@ -13,11 +13,11 @@ # limitations under the License. -from ..accelerator.interface import DistributedInterface, DistributedStrategy +from ..accelerator.interface import DistributedInterface from ..config.arg_parser import get_args from ..core.base_trainer import BaseTrainer from ..core.data_engine import DataEngine -from ..core.model_worker import ModelWorker +from ..core.model_loader import ModelLoader class SFTTrainer(BaseTrainer): @@ -26,8 +26,13 @@ class SFTTrainer(BaseTrainer): def run_sft(user_args): model_args, data_args, training_args, _ = get_args(user_args) - DistributedInterface(DistributedStrategy()) + DistributedInterface(training_args.dist_config) data_engine = DataEngine(data_args) - model_worker = ModelWorker(model_args) - trainer = SFTTrainer(training_args, model_worker, data_engine) + model_loader = ModelLoader(model_args) + trainer = SFTTrainer( + args=training_args, + model=model_loader.model, + processor=model_loader.processor, + dataset=data_engine, + ) trainer.fit() diff --git a/src/llamafactory/v1/utils/dtype.py b/src/llamafactory/v1/utils/dtype.py new file mode 100644 index 000000000..09dfb0112 --- /dev/null +++ b/src/llamafactory/v1/utils/dtype.py @@ -0,0 +1,92 @@ +# Copyright 2025 Bytedance Ltd. and the LlamaFactory team. +# +# This code is inspired by the Bytedance's verl library. +# https://github.com/volcengine/verl/blob/v0.6.1/verl/utils/torch_dtypes.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from typing import Union + +import torch +from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device + +from ..accelerator.interface import DistributedInterface + + +class DtypeRegistry: + HALF_LIST = ["fp16", "float16", "half", torch.float16] + FLOAT_LIST = ["fp32", "float32", "float", torch.float32] + BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16] + + +class DtypeInterface: + """Type of precision used.""" + + _is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator) + _is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator) + _is_fp32_available = True + + @staticmethod + def is_available(precision: Union[str, torch.dtype]) -> bool: + if precision in DtypeRegistry.HALF_LIST: + return DtypeInterface._is_fp16_available + elif precision in DtypeRegistry.FLOAT_LIST: + return DtypeInterface._is_fp32_available + elif precision in DtypeRegistry.BFLOAT_LIST: + return DtypeInterface._is_bf16_available + else: + raise RuntimeError(f"Unexpected precision: {precision}") + + @staticmethod + def is_fp16(precision: Union[str, torch.dtype]) -> bool: + return precision in DtypeRegistry.HALF_LIST + + @staticmethod + def is_fp32(precision: Union[str, torch.dtype]) -> bool: + return precision in DtypeRegistry.FLOAT_LIST + + @staticmethod + def is_bf16(precision: Union[str, torch.dtype]) -> bool: + return precision in DtypeRegistry.BFLOAT_LIST + + @staticmethod + def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype: + if precision in DtypeRegistry.HALF_LIST: + return torch.float16 + elif precision in DtypeRegistry.FLOAT_LIST: + return torch.float32 + elif precision in DtypeRegistry.BFLOAT_LIST: + return torch.bfloat16 + else: + raise RuntimeError(f"Unexpected precision: {precision}") + + @staticmethod + def to_str(precision: torch.dtype) -> str: + if precision == torch.float16: + return "float16" + elif precision == torch.float32: + return "float32" + elif precision == torch.bfloat16: + return "bfloat16" + else: + raise RuntimeError(f"Unexpected precision: {precision}") + + @contextmanager + def set_dtype(self, precision: Union[str, torch.dtype]): + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(self.to_dtype(precision)) + try: + yield + finally: + torch.set_default_dtype(original_dtype) diff --git a/src/llamafactory/v1/utils/logging.py b/src/llamafactory/v1/utils/logging.py index b2f9362f1..ebb890986 100644 --- a/src/llamafactory/v1/utils/logging.py +++ b/src/llamafactory/v1/utils/logging.py @@ -29,7 +29,7 @@ _default_log_level: "logging._Level" = logging.INFO class _Logger(logging.Logger): - r"""A logger that supports rank0 logging.""" + """A logger that supports rank0 logging.""" def info_rank0(self, *args, **kwargs) -> None: self.info(*args, **kwargs) @@ -42,7 +42,7 @@ class _Logger(logging.Logger): def _get_default_logging_level() -> "logging._Level": - r"""Return the default logging level.""" + """Return the default logging level.""" env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None) if env_level_str: if env_level_str.upper() in logging._nameToLevel: @@ -62,7 +62,7 @@ def _get_library_root_logger() -> "_Logger": def _configure_library_root_logger() -> None: - r"""Configure root logger using a stdout stream handler with an explicit format.""" + """Configure root logger using a stdout stream handler with an explicit format.""" global _default_handler with _thread_lock: @@ -82,7 +82,7 @@ def _configure_library_root_logger() -> None: def get_logger(name: Optional[str] = None) -> "_Logger": - r"""Return a logger with the specified name. It it not supposed to be accessed externally.""" + """Return a logger with the specified name. It it not supposed to be accessed externally.""" if name is None: name = _get_library_name() @@ -91,13 +91,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger": def add_handler(handler: "logging.Handler") -> None: - r"""Add a handler to the root logger.""" + """Add a handler to the root logger.""" _configure_library_root_logger() _get_library_root_logger().addHandler(handler) def remove_handler(handler: logging.Handler) -> None: - r"""Remove a handler to the root logger.""" + """Remove a handler to the root logger.""" _configure_library_root_logger() _get_library_root_logger().removeHandler(handler) diff --git a/src/llamafactory/v1/utils/plugin.py b/src/llamafactory/v1/utils/plugin.py index 0cbf25765..2c1138a87 100644 --- a/src/llamafactory/v1/utils/plugin.py +++ b/src/llamafactory/v1/utils/plugin.py @@ -38,7 +38,7 @@ class BasePlugin: self.name = name @property - def register(self) -> Callable: + def register(self): """Decorator to register a function as a plugin. Example usage: @@ -60,7 +60,7 @@ class BasePlugin: return decorator - def __call__(self, *args, **kwargs) -> Callable: + def __call__(self, *args, **kwargs): """Call the registered function with the given arguments. Example usage: @@ -75,6 +75,9 @@ class BasePlugin: if __name__ == "__main__": + """ + python -m llamafactory.v1.utils.plugin + """ class PrintPlugin(BasePlugin): pass diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index eb81298de..5d1899609 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: import torch import torch.utils.data import transformers + from torch.distributed import ProcessGroup from torch.distributed.fsdp import FullyShardedDataParallel Tensor = torch.Tensor @@ -37,6 +38,7 @@ if TYPE_CHECKING: Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin] Optimizer = torch.optim.Optimizer Scheduler = torch.optim.lr_scheduler.LRScheduler + ProcessGroup = ProcessGroup else: Tensor = None TensorLike = None @@ -50,6 +52,7 @@ else: Processor = None Optimizer = None Scheduler = None + ProcessGroup = None class DatasetInfo(TypedDict, total=False): @@ -69,6 +72,19 @@ class DatasetInfo(TypedDict, total=False): """Is streaming dataset, default to False.""" +class DistributedConfig(TypedDict, total=False): + mp_replicate_size: NotRequired[int] + """Model parallel replicate size, default to 1.""" + mp_shard_size: NotRequired[int] + """Model parallel shard size, default to world_size // mp_replicate_size.""" + dp_size: NotRequired[int] + """Data parallel size, default to world_size // cp_size.""" + cp_size: NotRequired[int] + """Context parallel size, default to 1.""" + timeout: NotRequired[int] + """Timeout for distributed communication, default to 600.""" + + class Content(TypedDict): type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"] value: str diff --git a/tests/conftest.py b/tests/conftest.py index ab436b5da..9ad5bf959 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers. """ import pytest +from pytest import Config, Item from llamafactory.extras.misc import get_current_device, is_env_enabled from llamafactory.train.test_utils import patch_valuehead_model @@ -29,21 +30,15 @@ except Exception: CURRENT_DEVICE = "cpu" -def pytest_configure(config): +def pytest_configure(config: Config): """Register custom pytest markers.""" config.addinivalue_line( "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)" ) - config.addinivalue_line( - "markers", "skip_on_devices: skip test on specified devices, e.g., @pytest.mark.skip_on_devices('npu', 'xpu')" - ) - config.addinivalue_line( - "markers", "require_device: test requires specific device, e.g., @pytest.mark.require_device('cuda')" - ) config.addinivalue_line("markers", "runs_on: test requires specific device, e.g., @pytest.mark.runs_on(['cpu'])") -def _handle_runs_on(items): +def _handle_runs_on(items: list[Item]): """Skip tests on specified devices based on runs_on marker. Usage: @@ -68,7 +63,7 @@ def _handle_runs_on(items): ) -def _handle_slow_tests(items): +def _handle_slow_tests(items: list[Item]): """Skip slow tests unless RUN_SLOW environment variable is set. Usage: @@ -85,51 +80,9 @@ def _handle_slow_tests(items): item.add_marker(skip_slow) -def _handle_device_skips(items): - """Skip tests on specified devices based on skip_on_devices marker. - - Usage: - @pytest.mark.skip_on_devices("npu", "xpu") - def test_something(): - pass - """ - for item in items: - skip_marker = item.get_closest_marker("skip_on_devices") - if skip_marker: - skip_devices = skip_marker.args - if CURRENT_DEVICE in skip_devices: - item.add_marker( - pytest.mark.skip( - reason=f"test skipped on {CURRENT_DEVICE.upper()} (skip list: {', '.join(skip_devices)})" - ) - ) - - -def _handle_device_requirements(items): - """Skip tests that require a specific device when running on other devices. - - Usage: - @pytest.mark.require_device("cuda") - def test_gpu_only(): - pass - """ - for item in items: - require_marker = item.get_closest_marker("require_device") - if require_marker: - required_device = require_marker.args[0] if require_marker.args else None - if required_device and CURRENT_DEVICE != required_device: - item.add_marker( - pytest.mark.skip( - reason=f"test requires {required_device.upper()} (current: {CURRENT_DEVICE.upper()})" - ) - ) - - -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config: Config, items: list[Item]): """Modify test collection based on markers and environment.""" _handle_slow_tests(items) - _handle_device_skips(items) - _handle_device_requirements(items) _handle_runs_on(items) diff --git a/tests/model/test_base.py b/tests/model/test_base.py index e40f84fe2..5c3ed801f 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -31,15 +31,13 @@ INFER_ARGS = { @pytest.mark.runs_on(["cpu", "npu"]) -@pytest.mark.skip_on_devices("npu") def test_base(): model = load_infer_model(**INFER_ARGS) ref_model = load_reference_model(TINY_LLAMA3) compare_model(model, ref_model) -@pytest.mark.runs_on(["cpu", "npu"]) -@pytest.mark.skip_on_devices("npu") +@pytest.mark.runs_on(["cpu"]) @pytest.mark.usefixtures("fix_valuehead_cpu_loading") def test_valuehead(): model = load_infer_model(add_valuehead=True, **INFER_ARGS) diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py index 3b238edb9..abbc7fc21 100644 --- a/tests/model/test_lora.py +++ b/tests/model/test_lora.py @@ -104,7 +104,6 @@ def test_lora_train_valuehead(): @pytest.mark.runs_on(["cpu", "npu"]) -@pytest.mark.skip_on_devices("npu") def test_lora_inference(): model = load_infer_model(**INFER_ARGS) ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload() diff --git a/tests/version.txt b/tests/version.txt index be83783ea..a9b60f523 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.4.103 +0.9.4.104 diff --git a/tests_v1/accelerator/test_interface.py b/tests_v1/accelerator/test_interface.py index 2651ebf78..be2485f92 100644 --- a/tests_v1/accelerator/test_interface.py +++ b/tests_v1/accelerator/test_interface.py @@ -15,11 +15,11 @@ import os -from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy +from llamafactory.v1.accelerator.interface import DistributedInterface def test_distributed_interface(): - DistributedInterface(DistributedStrategy()) + DistributedInterface() assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0")) assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1")) assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0")) diff --git a/tests_v1/conftest.py b/tests_v1/conftest.py new file mode 100644 index 000000000..455b33497 --- /dev/null +++ b/tests_v1/conftest.py @@ -0,0 +1,29 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pytest import Config, Item + +from llamafactory.v1.utils.packages import is_transformers_version_greater_than + + +def pytest_collection_modifyitems(config: Config, items: list[Item]): + if is_transformers_version_greater_than("4.57.0"): + return + + skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests") + + for item in items: + if "tests_v1" in str(item.fspath): + item.add_marker(skip_bc) diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py new file mode 100644 index 000000000..cee2e9f79 --- /dev/null +++ b/tests_v1/core/test_model_loader.py @@ -0,0 +1,33 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.v1.config.model_args import ModelArguments +from llamafactory.v1.core.model_loader import ModelLoader + + +def test_tiny_qwen(): + from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast + + model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5") + model_loader = ModelLoader(model_args) + assert isinstance(model_loader.processor, Qwen2TokenizerFast) + assert isinstance(model_loader.model.config, Qwen2Config) + assert isinstance(model_loader.model, Qwen2ForCausalLM) + assert model_loader.model.dtype == torch.bfloat16 + + +if __name__ == "__main__": + test_tiny_qwen() diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py index cebdaec76..0ad40aa19 100644 --- a/tests_v1/plugins/data_plugins/test_converter.py +++ b/tests_v1/plugins/data_plugins/test_converter.py @@ -24,7 +24,7 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin @pytest.mark.parametrize("num_samples", [16]) def test_alpaca_converter(num_samples: int): - data_args = DataArguments(dataset="llamafactory/v1-sft-demo/dataset_info.yaml") + data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml") data_engine = DataEngine(data_args) original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train") indexes = random.choices(range(len(data_engine)), k=num_samples) @@ -54,6 +54,8 @@ def test_sharegpt_converter(): "conversations": [ {"from": "system", "value": "System"}, {"from": "human", "value": "User"}, + {"from": "function_call", "value": "Tool"}, + {"from": "observation", "value": "Observation"}, {"from": "gpt", "value": "Assistant"}, ] } @@ -61,6 +63,8 @@ def test_sharegpt_converter(): "messages": [ {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"}, {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"}, + {"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"}, + {"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"}, {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"}, ] } @@ -69,7 +73,7 @@ def test_sharegpt_converter(): @pytest.mark.parametrize("num_samples", [16]) def test_pair_converter(num_samples: int): - data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml") + data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml") data_engine = DataEngine(data_args) original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs") indexes = random.choices(range(len(data_engine)), k=num_samples) @@ -112,7 +116,7 @@ def test_pair_converter(num_samples: int): }, ], } - assert data_engine[index] == {"_dataset_name": "dpo_zh_demo", **expected_data} + assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data} if __name__ == "__main__":