[v1] add init plugin (#9716)

Yaowei Zheng, 2026-01-04 20:51:46 +08:00, committed by GitHub
parent 81b8a50aa5
commit f60a6e3d01
14 changed files with 307 additions and 74 deletions

View File

@@ -34,10 +34,14 @@ from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

+from ..utils import logging
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper

+logger = logging.get_logger(__name__)


class Dim(str, Enum):
    """Dimension names."""
@@ -157,6 +161,7 @@ class DistributedInterface:
        self.data_device_mesh = None
        self._initialized = True
+       logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")

    def __str__(self) -> str:
        return (

View File

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments
__all__ = [
"DataArguments",
"InputArgument",
"ModelArguments",
"ModelClass",
"SampleArguments",
"SampleBackend",
"TrainingArguments",
"get_args",
]
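With this new package __init__, callers can pull the argument classes and parser from the package root. A hedged usage sketch (the dict-style input and the four-value unpacking mirror the tests and run_chat elsewhere in this commit; the third return value is not named in this diff, so it is discarded here):

from llamafactory.v1.config import ModelArguments, SampleArguments, get_args

# get_args accepts dict-style input, as in this commit's tests.
data_args, model_args, _, sample_args = get_args(
    dict(model="llamafactory/tiny-random-qwen2.5")
)
assert isinstance(model_args, ModelArguments)
assert isinstance(sample_args, SampleArguments)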

View File

@@ -27,14 +27,14 @@ class ModelArguments:
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
    )
-   use_fast_processor: bool = field(
-       default=True,
-       metadata={"help": "Use fast processor from Hugging Face."},
-   )
    model_class: ModelClass = field(
        default=ModelClass.LLM,
        metadata={"help": "Model class from Hugging Face."},
    )
+   init_config: PluginConfig | None = field(
+       default=None,
+       metadata={"help": "Initialization configuration for the model."},
+   )
    peft_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +49,7 @@ class ModelArguments:
    )

    def __post_init__(self) -> None:
+       self.init_config = get_plugin_config(self.init_config)
        self.peft_config = get_plugin_config(self.peft_config)
        self.kernel_config = get_plugin_config(self.kernel_config)
        self.quant_config = get_plugin_config(self.quant_config)
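The new init_config field takes the same plugin-style mapping as peft_config, kernel_config, and quant_config, and __post_init__ normalizes it through get_plugin_config. A minimal sketch of selecting the meta-device init plugin, taken from the test added later in this commit:

from llamafactory.v1.config.arg_parser import get_args

# The raw mapping is converted to a PluginConfig by ModelArguments.__post_init__.
_, model_args, *_ = get_args(
    dict(
        model="llamafactory/tiny-random-qwen2.5",
        init_config={"name": "init_on_meta"},
    )
)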

View File

@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
    output_dir: str = field(
-       default=os.path.join("outputs", str(uuid4())),
+       default=os.path.join("outputs", str(uuid4().hex)),
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(
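The only behavioral change in this hunk is the default directory name: uuid4().hex drops the dashes from the UUID. A quick illustration (the concrete values are of course random):

import os
from uuid import uuid4

str(uuid4())      # e.g. "0f8c2d1e-3a4b-4c5d-8e9f-0a1b2c3d4e5f" (36 chars, with dashes)
str(uuid4().hex)  # e.g. "0f8c2d1e3a4b4c5d8e9f0a1b2c3d4e5f" (32 hex chars, no dashes)
os.path.join("outputs", str(uuid4().hex))  # the new default output_dir pattern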

View File

@@ -0,0 +1,77 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod

from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset


class BaseEngine(ABC):
    @abstractmethod
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel = None,
        processor: Processor = None,
    ) -> None:
        """Initialize the engine.

        Args:
            args: Sample arguments.
            model_args: Model arguments.
            model: Model.
            processor: Processor.
        """
        ...

    @abstractmethod
    async def generate(self, messages):
        pass

    @abstractmethod
    async def batch_infer(self, data: TorchDataset) -> None:
        pass


class HuggingFaceEngine(BaseEngine):
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        processor: Processor,
    ) -> None:
        self.args = args


class BaseSampler:
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        processor: Processor,
    ) -> None:
        if args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(args, model_args, model, processor)
        else:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

    async def generate(self, messages):
        return await self.engine.generate(messages)

    async def batch_infer(self, data: TorchDataset) -> None:
        return await self.engine.batch_infer(data)
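A hedged usage sketch of the new BaseSampler, mirroring the flow in run_chat later in this commit. Two assumptions are made that the diff does not confirm: HuggingFaceEngine implements the generate/batch_infer coroutines in the part of this 77-line file not shown in this view, and messages follow a chat-style dict format.

import asyncio

from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.base_sampler import BaseSampler
from llamafactory.v1.core.model_loader import ModelLoader

_, model_args, _, sample_args = get_args(dict(model="llamafactory/tiny-random-qwen2.5"))
model_loader = ModelLoader(model_args)
sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)

# generate() is a coroutine, so drive it with asyncio.run(); the message format
# below is an assumption, not pinned down by the engine API above.
asyncio.run(sampler.generate([{"role": "user", "content": "Hello"}]))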

View File

@@ -1,44 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod

from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...

    @abstractmethod
    async def generate(self):
        pass

    @abstractmethod
    async def batch_infer(self):
        pass


class HuggingFaceEngine(BaseEngine):
    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        self.args = sample_args


class ChatSampler:
    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        if sample_args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(model_loader, sample_args)
        else:
            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -14,17 +14,24 @@
"""The definition of model loader.

-Init Phase:
+How to use:
+    model_loader = ModelLoader(model_args, is_trainable=True)
+    model_loader.processor: Get the tokenizer or multi-modal processor.
+    model_loader.model_config: Get the model configuration.
+    model_loader.model: Get the HF model.
+
+Init Workflow:
    1. Init processor.
    2. Init model config.
    3. Init model.
    4. Init adapter.
"""

import torch
+from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor

+from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
@@ -55,11 +62,14 @@ class ModelLoader:
        """HF model."""

    def _init_processor(self) -> Processor:
-       """Init processor."""
+       """Init processor.
+
+       NOTE: Transformers v5 always use fast tokenizer.
+       https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
+       """
        return AutoProcessor.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
-           use_fast=self.args.use_fast_processor,
        )

    def _init_model_config(self) -> HFConfig:
@@ -92,14 +102,24 @@ class ModelLoader:
        AutoClass = AutoModel

-       # map the entire model to the current accelerator
-       model = AutoClass.from_pretrained(
-           self.args.model,
-           config=self.model_config,
-           dtype="auto",
-           device_map=DistributedInterface().current_accelerator,
-           trust_remote_code=self.args.trust_remote_code,
-       )
+       if self.args.init_config is not None:
+           from ..plugins.model_plugins.initialization import InitPlugin
+
+           init_device = InitPlugin(self.args.init_config.name)()
+       else:
+           init_device = DistributedInterface().current_accelerator
+
+       if init_device.type == DeviceType.META:
+           with init_empty_weights():
+               model = AutoClass.from_config(self.model_config)
+       else:
+           model = AutoClass.from_pretrained(
+               self.args.model,
+               config=self.model_config,
+               dtype="auto",
+               device_map=init_device,
+               trust_remote_code=self.args.trust_remote_code,
+           )

        if self.args.peft_config is None:
            if self.is_train:
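For reference, a standalone sketch of the dispatch added above, using only the accelerate and transformers APIs already imported by this file; the checkpoint name is the tiny test model used elsewhere in this commit, and hard-coding the device stands in for the plugin call:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "llamafactory/tiny-random-qwen2.5"
config = AutoConfig.from_pretrained(model_id)
init_device = torch.device("meta")  # what the "init_on_meta" plugin returns

if init_device.type == "meta":
    # Meta-device init: parameters get shapes and dtypes but no storage, so the
    # "load" is instant; real weights must be materialized later.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, config=config, dtype="auto", device_map=init_device
    )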

View File

@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin


class InitPlugin(BasePlugin):
    def __call__(self) -> torch.device:
        return super().__call__()


@InitPlugin("init_on_meta").register
def init_on_meta() -> torch.device:
    return torch.device(DeviceType.META.value)


@InitPlugin("init_on_rank0").register
def init_on_rank0() -> torch.device:
    if DistributedInterface().get_rank() == 0:
        return torch.device(DeviceType.CPU.value)
    else:
        return torch.device(DeviceType.META.value)


@InitPlugin("init_on_default").register
def init_on_default() -> torch.device:
    return DistributedInterface().current_accelerator
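The decorator pattern above also works for plugins registered outside this module. A hedged sketch, assuming BasePlugin resolves callables by their registered name; the "init_on_cpu" plugin is hypothetical and not part of this commit:

import torch

from llamafactory.v1.accelerator.helper import DeviceType
from llamafactory.v1.plugins.model_plugins.initialization import InitPlugin


@InitPlugin("init_on_cpu").register  # hypothetical name, not in this commit
def init_on_cpu() -> torch.device:
    # Always materialize weights on CPU, regardless of rank or accelerator.
    return torch.device(DeviceType.CPU.value)

Once registered, it would be selected like the built-ins, e.g. init_config={"name": "init_on_cpu"}.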

View File

@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config import InputArgument, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.model_loader import ModelLoader


def run_chat(args: InputArgument = None):
    data_args, model_args, _, sample_args = get_args(args)
    if sample_args.sample_backend != SampleBackend.HF:
        model_args.init_plugin = {"name": "init_on_meta"}

    model_loader = ModelLoader(model_args)
    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
    if data_args.dataset is not None:
        sampler.batch_infer()
    else:
        sampler.generate()


if __name__ == "__main__":
    run_chat()
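A hedged invocation sketch: run_chat accepts the same dict-style arguments that get_args consumes in this commit's tests. Only the model key below comes from the diff; with no dataset set, the call falls through to the interactive generate path.

# Programmatic equivalent of the __main__ entry point above.
run_chat(dict(model="llamafactory/tiny-random-qwen2.5"))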

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.

Contains shared fixtures, pytest configuration, and custom markers.
"""
@@ -110,11 +110,10 @@ def _handle_device_visibility(items: list[Item]):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    # Handle version compatibility (from HEAD)
-   if not is_transformers_version_greater_than("4.57.0"):
-       skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
-       for item in items:
-           if "tests_v1" in str(item.fspath):
-               item.add_marker(skip_bc)
+   skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
+   for item in items:
+       if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
+           item.add_marker(skip_bc)

    _handle_slow_tests(items)
    _handle_runs_on(items)
@@ -156,6 +155,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
        monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
    else:
        monkeypatch.setenv(env_key, "0")

    if CURRENT_DEVICE == "cuda":
        monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
    elif CURRENT_DEVICE == "npu":

View File

@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
-use_fast_processor: true
model_class: "llm"
kernel_config:
  name: "auto"

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.

Contains shared fixtures, pytest configuration, and custom markers.
"""
@@ -22,6 +22,7 @@ import sys
import pytest
import torch
+import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch

from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    # Handle version compatibility (from HEAD)
-   if not is_transformers_version_greater_than("4.57.0"):
-       skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
-       for item in items:
-           if "tests_v1" in str(item.fspath):
-               item.add_marker(skip_bc)
+   skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
+   for item in items:
+       if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
+           item.add_marker(skip_bc)

    _handle_slow_tests(items)
    _handle_runs_on(items)
    _handle_device_visibility(items)


+@pytest.fixture(autouse=True)
+def _cleanup_distributed_state():
+    """Cleanup distributed state after each test."""
+    yield
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
    """Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
        monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
    else:
        monkeypatch.setenv(env_key, "0")

    if CURRENT_DEVICE == "cuda":
        monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
    elif CURRENT_DEVICE == "npu":

View File

@@ -0,0 +1,56 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader


def test_init_on_meta():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_meta"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    assert model_loader.model.device.type == "meta"


@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_rank0"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    if DistributedInterface().get_rank() == 0:
        assert model_loader.model.device.type == "cpu"
    else:
        assert model_loader.model.device.type == "meta"


def test_init_on_default():
    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_default"},
        )
    )
    model_loader = ModelLoader(model_args=model_args)
    assert model_loader.model.device.type == DistributedInterface().current_accelerator.type