From f60a6e3d015962198b7c626936f117e83260bde9 Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Sun, 4 Jan 2026 20:51:46 +0800 Subject: [PATCH] [v1] add init plugin (#9716) --- src/llamafactory/v1/accelerator/interface.py | 5 ++ src/llamafactory/v1/config/__init__.py | 32 ++++++++ src/llamafactory/v1/config/model_args.py | 9 ++- src/llamafactory/v1/config/training_args.py | 2 +- src/llamafactory/v1/core/base_sampler.py | 77 +++++++++++++++++++ src/llamafactory/v1/core/chat_sampler.py | 44 ----------- src/llamafactory/v1/core/model_loader.py | 44 ++++++++--- .../{added_token.py => add_token.py} | 0 .../plugins/model_plugins/initialization.py | 43 +++++++++++ src/llamafactory/v1/samplers/cli_sampler.py | 35 +++++++++ tests/conftest.py | 12 +-- tests_v1/config/test_args_parser.py | 1 - tests_v1/conftest.py | 21 +++-- .../plugins/model_plugins/test_init_plugin.py | 56 ++++++++++++++ 14 files changed, 307 insertions(+), 74 deletions(-) create mode 100644 src/llamafactory/v1/core/base_sampler.py delete mode 100644 src/llamafactory/v1/core/chat_sampler.py rename src/llamafactory/v1/plugins/model_plugins/{added_token.py => add_token.py} (100%) create mode 100644 src/llamafactory/v1/samplers/cli_sampler.py create mode 100644 tests_v1/plugins/model_plugins/test_init_plugin.py diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index f8a5856d2..c7cdf4f5f 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -34,10 +34,14 @@ from typing import Any, Optional from torch.distributed import barrier, destroy_process_group, init_process_group from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from ..utils import logging from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike from . import helper +logger = logging.get_logger(__name__) + + class Dim(str, Enum): """Dimension names.""" @@ -157,6 +161,7 @@ class DistributedInterface: self.data_device_mesh = None self._initialized = True + logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.") def __str__(self) -> str: return ( diff --git a/src/llamafactory/v1/config/__init__.py b/src/llamafactory/v1/config/__init__.py index e69de29bb..b9aceeb31 100644 --- a/src/llamafactory/v1/config/__init__.py +++ b/src/llamafactory/v1/config/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .arg_parser import InputArgument, get_args
+from .arg_utils import ModelClass, SampleBackend
+from .data_args import DataArguments
+from .model_args import ModelArguments
+from .sample_args import SampleArguments
+from .training_args import TrainingArguments
+
+
+__all__ = [
+    "DataArguments",
+    "InputArgument",
+    "ModelArguments",
+    "ModelClass",
+    "SampleArguments",
+    "SampleBackend",
+    "TrainingArguments",
+    "get_args",
+]
diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py
index 370ed02d1..b79ee86de 100644
--- a/src/llamafactory/v1/config/model_args.py
+++ b/src/llamafactory/v1/config/model_args.py
@@ -27,14 +27,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Trust remote code from Hugging Face."},
     )
-    use_fast_processor: bool = field(
-        default=True,
-        metadata={"help": "Use fast processor from Hugging Face."},
-    )
     model_class: ModelClass = field(
         default=ModelClass.LLM,
         metadata={"help": "Model class from Hugging Face."},
     )
+    init_config: PluginConfig | None = field(
+        default=None,
+        metadata={"help": "Initialization configuration for the model."},
+    )
     peft_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +49,7 @@ class ModelArguments:
     )
 
     def __post_init__(self) -> None:
+        self.init_config = get_plugin_config(self.init_config)
         self.peft_config = get_plugin_config(self.peft_config)
         self.kernel_config = get_plugin_config(self.kernel_config)
         self.quant_config = get_plugin_config(self.quant_config)
diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py
index 574ff015e..9bee3095b 100644
--- a/src/llamafactory/v1/config/training_args.py
+++ b/src/llamafactory/v1/config/training_args.py
@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
 @dataclass
 class TrainingArguments:
     output_dir: str = field(
-        default=os.path.join("outputs", str(uuid4())),
+        default=os.path.join("outputs", uuid4().hex),
         metadata={"help": "Path to the output directory."},
     )
     micro_batch_size: int = field(
diff --git a/src/llamafactory/v1/core/base_sampler.py b/src/llamafactory/v1/core/base_sampler.py
new file mode 100644
index 000000000..fbd36ab69
--- /dev/null
+++ b/src/llamafactory/v1/core/base_sampler.py
@@ -0,0 +1,77 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+from ..config import ModelArguments, SampleArguments, SampleBackend
+from ..utils.types import HFModel, Processor, TorchDataset
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel | None = None,
+        processor: Processor | None = None,
+    ) -> None:
+        """Initialize the engine.
+
+        Args:
+            args: Sample arguments.
+            model_args: Model arguments.
+            model: Preloaded HF model, if any.
+            processor: Tokenizer or multi-modal processor, if any.
+        """
+        ...
+ + @abstractmethod + async def generate(self, messages): + pass + + @abstractmethod + async def batch_infer(self, data: TorchDataset) -> None: + pass + + +class HuggingFaceEngine(BaseEngine): + def __init__( + self, + args: SampleArguments, + model_args: ModelArguments, + model: HFModel, + processor: Processor, + ) -> None: + self.args = args + + +class BaseSampler: + def __init__( + self, + args: SampleArguments, + model_args: ModelArguments, + model: HFModel, + processor: Processor, + ) -> None: + if args.sample_backend == SampleBackend.HF: + self.engine = HuggingFaceEngine(args, model_args, model, processor) + else: + raise ValueError(f"Unknown sample backend: {args.sample_backend}") + + async def generate(self, messages): + return await self.engine.generate(messages) + + async def batch_infer(self, data: TorchDataset) -> None: + return await self.engine.batch_infer(data) diff --git a/src/llamafactory/v1/core/chat_sampler.py b/src/llamafactory/v1/core/chat_sampler.py deleted file mode 100644 index a4dc9d6da..000000000 --- a/src/llamafactory/v1/core/chat_sampler.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - -from ..config.sample_args import SampleArguments, SampleBackend -from .model_loader import ModelLoader - - -class BaseEngine(ABC): - @abstractmethod - def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ... - - @abstractmethod - async def generate(self): - pass - - @abstractmethod - async def batch_infer(self): - pass - - -class HuggingFaceEngine(BaseEngine): - def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None: - self.args = sample_args - - -class ChatSampler: - def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None: - if sample_args.sample_backend == SampleBackend.HF: - self.engine = HuggingFaceEngine(model_loader, sample_args) - else: - raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}") diff --git a/src/llamafactory/v1/core/model_loader.py b/src/llamafactory/v1/core/model_loader.py index ef6ca9324..069292274 100644 --- a/src/llamafactory/v1/core/model_loader.py +++ b/src/llamafactory/v1/core/model_loader.py @@ -14,17 +14,24 @@ """The definition of model loader. -Init Phase: +How to use: +model_loader = ModelLoader(model_args, is_trainable=True) +model_loader.processor: Get the tokenizer or multi-modal processor. +model_loader.model_config: Get the model configuration. +model_loader.model: Get the HF model. + +Init Workflow: 1. Init processor. 2. Init model config. 3. Init model. 4. Init adapter. 
-
 """
 
 import torch
+from accelerate import init_empty_weights
 from transformers import AutoConfig, AutoProcessor
 
+from ..accelerator.helper import DeviceType
 from ..accelerator.interface import DistributedInterface
 from ..config.model_args import ModelArguments, ModelClass
 from ..utils import logging
@@ -55,11 +62,14 @@ class ModelLoader:
         """HF model."""
 
     def _init_processor(self) -> Processor:
-        """Init processor."""
+        """Init processor.
+
+        NOTE: Transformers v5 always uses the fast tokenizer.
+        https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
+        """
         return AutoProcessor.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
-            use_fast=self.args.use_fast_processor,
         )
 
     def _init_model_config(self) -> HFConfig:
@@ -92,14 +102,24 @@
 
         AutoClass = AutoModel
 
-        # map the entire model to the current accelerator
-        model = AutoClass.from_pretrained(
-            self.args.model,
-            config=self.model_config,
-            dtype="auto",
-            device_map=DistributedInterface().current_accelerator,
-            trust_remote_code=self.args.trust_remote_code,
-        )
+        if self.args.init_config is not None:
+            from ..plugins.model_plugins.initialization import InitPlugin
+
+            init_device = InitPlugin(self.args.init_config.name)()
+        else:
+            init_device = DistributedInterface().current_accelerator
+
+        if init_device.type == DeviceType.META:
+            with init_empty_weights():
+                model = AutoClass.from_config(self.model_config)
+        else:
+            model = AutoClass.from_pretrained(
+                self.args.model,
+                config=self.model_config,
+                dtype="auto",
+                device_map=init_device,
+                trust_remote_code=self.args.trust_remote_code,
+            )
 
         if self.args.peft_config is None:
             if self.is_train:
diff --git a/src/llamafactory/v1/plugins/model_plugins/added_token.py b/src/llamafactory/v1/plugins/model_plugins/add_token.py
similarity index 100%
rename from src/llamafactory/v1/plugins/model_plugins/added_token.py
rename to src/llamafactory/v1/plugins/model_plugins/add_token.py
diff --git a/src/llamafactory/v1/plugins/model_plugins/initialization.py b/src/llamafactory/v1/plugins/model_plugins/initialization.py
index e69de29bb..5e6c8bb99 100644
--- a/src/llamafactory/v1/plugins/model_plugins/initialization.py
+++ b/src/llamafactory/v1/plugins/model_plugins/initialization.py
@@ -0,0 +1,43 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from ...accelerator.helper import DeviceType
+from ...accelerator.interface import DistributedInterface
+from ...utils.plugin import BasePlugin
+
+
+class InitPlugin(BasePlugin):
+    def __call__(self) -> torch.device:
+        return super().__call__()
+
+
+@InitPlugin("init_on_meta").register
+def init_on_meta() -> torch.device:
+    return torch.device(DeviceType.META.value)
+
+
+@InitPlugin("init_on_rank0").register
+def init_on_rank0() -> torch.device:
+    if DistributedInterface().get_rank() == 0:
+        return torch.device(DeviceType.CPU.value)
+    else:
+        return torch.device(DeviceType.META.value)
+
+
+@InitPlugin("init_on_default").register
+def init_on_default() -> torch.device:
+    return DistributedInterface().current_accelerator
diff --git a/src/llamafactory/v1/samplers/cli_sampler.py b/src/llamafactory/v1/samplers/cli_sampler.py
new file mode 100644
index 000000000..c08bc838a
--- /dev/null
+++ b/src/llamafactory/v1/samplers/cli_sampler.py
@@ -0,0 +1,35 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..config import InputArgument, SampleBackend, get_args
+from ..core.base_sampler import BaseSampler
+from ..core.model_loader import ModelLoader
+
+
+def run_chat(args: InputArgument = None):
+    data_args, model_args, _, sample_args = get_args(args)
+    if sample_args.sample_backend != SampleBackend.HF:
+        model_args.init_config = {"name": "init_on_meta"}
+
+    model_loader = ModelLoader(model_args)
+    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
+    if data_args.dataset is not None:
+        sampler.batch_infer()
+    else:
+        sampler.generate()
+
+
+if __name__ == "__main__":
+    run_chat()
diff --git a/tests/conftest.py b/tests/conftest.py
index 65c779fc2..cd20a0d2b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.
 
 Contains shared fixtures, pytest configuration, and custom markers.
""" @@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]): def pytest_collection_modifyitems(config: Config, items: list[Item]): """Modify test collection based on markers and environment.""" # Handle version compatibility (from HEAD) - if not is_transformers_version_greater_than("4.57.0"): - skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests") - for item in items: - if "tests_v1" in str(item.fspath): - item.add_marker(skip_bc) + skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests") + for item in items: + if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"): + item.add_marker(skip_bc) _handle_slow_tests(items) _handle_runs_on(items) @@ -156,6 +155,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) - monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") else: monkeypatch.setenv(env_key, "0") + if CURRENT_DEVICE == "cuda": monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) elif CURRENT_DEVICE == "npu": diff --git a/tests_v1/config/test_args_parser.py b/tests_v1/config/test_args_parser.py index b39f4532c..37ba7bc36 100644 --- a/tests_v1/config/test_args_parser.py +++ b/tests_v1/config/test_args_parser.py @@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path): ### model model: "llamafactory/tiny-random-qwen2.5" trust_remote_code: true - use_fast_processor: true model_class: "llm" kernel_config: name: "auto" diff --git a/tests_v1/conftest.py b/tests_v1/conftest.py index 018d723a8..bf1a4d76a 100644 --- a/tests_v1/conftest.py +++ b/tests_v1/conftest.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LLaMA-Factory test configuration. +"""LlamaFactory test configuration. Contains shared fixtures, pytest configuration, and custom markers. 
""" @@ -22,6 +22,7 @@ import sys import pytest import torch +import torch.distributed as dist from pytest import Config, FixtureRequest, Item, MonkeyPatch from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count @@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]): def pytest_collection_modifyitems(config: Config, items: list[Item]): """Modify test collection based on markers and environment.""" # Handle version compatibility (from HEAD) - if not is_transformers_version_greater_than("4.57.0"): - skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests") - for item in items: - if "tests_v1" in str(item.fspath): - item.add_marker(skip_bc) + skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests") + for item in items: + if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"): + item.add_marker(skip_bc) _handle_slow_tests(items) _handle_runs_on(items) _handle_device_visibility(items) +@pytest.fixture(autouse=True) +def _cleanup_distributed_state(): + """Cleanup distributed state after each test.""" + yield + if dist.is_initialized(): + dist.destroy_process_group() + + @pytest.fixture(autouse=True) def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None: """Set environment variables for distributed tests if specific devices are requested.""" @@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) - monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") else: monkeypatch.setenv(env_key, "0") + if CURRENT_DEVICE == "cuda": monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) elif CURRENT_DEVICE == "npu": diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py b/tests_v1/plugins/model_plugins/test_init_plugin.py new file mode 100644 index 000000000..80e9d178b --- /dev/null +++ b/tests_v1/plugins/model_plugins/test_init_plugin.py @@ -0,0 +1,56 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from llamafactory.v1.accelerator.interface import DistributedInterface +from llamafactory.v1.config.arg_parser import get_args +from llamafactory.v1.core.model_loader import ModelLoader + + +def test_init_on_meta(): + _, model_args, *_ = get_args( + dict( + model="llamafactory/tiny-random-qwen2.5", + init_config={"name": "init_on_meta"}, + ) + ) + model_loader = ModelLoader(model_args=model_args) + assert model_loader.model.device.type == "meta" + + +@pytest.mark.runs_on(["cuda", "npu"]) +def test_init_on_rank0(): + _, model_args, *_ = get_args( + dict( + model="llamafactory/tiny-random-qwen2.5", + init_config={"name": "init_on_rank0"}, + ) + ) + model_loader = ModelLoader(model_args=model_args) + if DistributedInterface().get_rank() == 0: + assert model_loader.model.device.type == "cpu" + else: + assert model_loader.model.device.type == "meta" + + +def test_init_on_default(): + _, model_args, *_ = get_args( + dict( + model="llamafactory/tiny-random-qwen2.5", + init_config={"name": "init_on_default"}, + ) + ) + model_loader = ModelLoader(model_args=model_args) + assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
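
Usage note: the initialization strategy added by this patch is selected through `ModelArguments.init_config` and resolved by name inside `ModelLoader._init_model`. A minimal sketch of the dict-argument path, mirroring `tests_v1/plugins/model_plugins/test_init_plugin.py` (the tiny model repo is the one the tests use):

```python
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader

# "init_on_meta" builds the model from its config on the meta device,
# so no weight tensors are materialized at load time.
_, model_args, *_ = get_args(
    dict(
        model="llamafactory/tiny-random-qwen2.5",
        init_config={"name": "init_on_meta"},
    )
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == "meta"
```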
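The meta-device branch in `ModelLoader` relies on `accelerate.init_empty_weights`, which is plain Accelerate/Transformers API. A standalone sketch of what that branch does when the plugin returns a meta device (model name reused from the tests; `AutoModelForCausalLM` assumed here for a causal LM, where the loader uses a `model_class`-dependent auto class):

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("llamafactory/tiny-random-qwen2.5")
with init_empty_weights():
    # Parameters are created on the meta device: shapes and dtypes only,
    # no storage is allocated and no checkpoint is read.
    model = AutoModelForCausalLM.from_config(config)

assert model.device.type == "meta"
```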
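The registry is also open to user-defined strategies via the same `@InitPlugin(name).register` decorator the patch uses. A hypothetical example (the name `init_on_cpu` and this strategy are not part of the patch; the registration and lookup semantics are assumed from `BasePlugin`):

```python
import torch

from llamafactory.v1.plugins.model_plugins.initialization import InitPlugin


# Hypothetical strategy: materialize full weights on CPU on every rank,
# e.g. to keep accelerator memory free during preprocessing.
@InitPlugin("init_on_cpu").register
def init_on_cpu() -> torch.device:
    return torch.device("cpu")


# Resolution mirrors ModelLoader: InitPlugin(init_config.name)()
assert InitPlugin("init_on_cpu")() == torch.device("cpu")
```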