mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 01:20:35 +08:00
[v1] add init plugin (#9716)
@@ -34,10 +34,14 @@ from typing import Any, Optional
 from torch.distributed import barrier, destroy_process_group, init_process_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

 from ..utils import logging
 from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
 from . import helper


 logger = logging.get_logger(__name__)


 class Dim(str, Enum):
     """Dimension names."""
@@ -157,6 +161,7 @@ class DistributedInterface:
         self.data_device_mesh = None

         self._initialized = True
         logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")

     def __str__(self) -> str:
         return (
@@ -0,0 +1,32 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .arg_parser import InputArgument, get_args
+from .arg_utils import ModelClass, SampleBackend
+from .data_args import DataArguments
+from .model_args import ModelArguments
+from .sample_args import SampleArguments
+from .training_args import TrainingArguments
+
+
+__all__ = [
+    "DataArguments",
+    "InputArgument",
+    "ModelArguments",
+    "ModelClass",
+    "SampleArguments",
+    "SampleBackend",
+    "TrainingArguments",
+    "get_args",
+]
@@ -27,14 +27,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Trust remote code from Hugging Face."},
     )
     use_fast_processor: bool = field(
         default=True,
         metadata={"help": "Use fast processor from Hugging Face."},
     )
     model_class: ModelClass = field(
         default=ModelClass.LLM,
         metadata={"help": "Model class from Hugging Face."},
     )
     init_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Initialization configuration for the model."},
     )
     peft_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +49,7 @@ class ModelArguments:
     )

     def __post_init__(self) -> None:
+        self.init_config = get_plugin_config(self.init_config)
         self.peft_config = get_plugin_config(self.peft_config)
         self.kernel_config = get_plugin_config(self.kernel_config)
         self.quant_config = get_plugin_config(self.quant_config)
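Taken together with the new init_config field, a run can now select an init plugin by name. A minimal sketch of how this flows through the arg parser (values are illustrative and mirror the tests added later in this commit; get_plugin_config is assumed to normalize the plain dict into a PluginConfig):

    # Illustrative sketch; mirrors tests_v1/plugins/model_plugins/test_init_plugin.py.
    from llamafactory.v1.config.arg_parser import get_args

    _, model_args, *_ = get_args(
        dict(
            model="llamafactory/tiny-random-qwen2.5",
            init_config={"name": "init_on_meta"},  # normalized to a PluginConfig in __post_init__
        )
    )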
@@ -22,7 +22,7 @@ from .arg_utils import PluginConfig, get_plugin_config
 @dataclass
 class TrainingArguments:
     output_dir: str = field(
-        default=os.path.join("outputs", str(uuid4())),
+        default=os.path.join("outputs", str(uuid4().hex)),
         metadata={"help": "Path to the output directory."},
     )
     micro_batch_size: int = field(
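Side note on the output_dir change above: str(uuid4()) yields the dash-separated 36-character form, while uuid4().hex is the same 128-bit value as 32 hex characters without dashes; .hex already returns a str, so the surrounding str() call is a no-op. A quick illustration:

    from uuid import uuid4

    u = uuid4()
    print(str(u))  # 36 chars, dash-separated, e.g. 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
    print(u.hex)   # 32 hex chars, no dashes
    assert u.hex == str(u).replace("-", "")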
src/llamafactory/v1/core/base_sampler.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+from ..config import ModelArguments, SampleArguments, SampleBackend
+from ..utils.types import HFModel, Processor, TorchDataset
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel = None,
+        processor: Processor = None,
+    ) -> None:
+        """Initialize the engine.
+
+        Args:
+            args: Sample arguments.
+            model_args: Model arguments.
+            model: Model.
+            processor: Processor.
+        """
+        ...
+
+    @abstractmethod
+    async def generate(self, messages):
+        pass
+
+    @abstractmethod
+    async def batch_infer(self, data: TorchDataset) -> None:
+        pass
+
+
+class HuggingFaceEngine(BaseEngine):
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel,
+        processor: Processor,
+    ) -> None:
+        self.args = args
+
+
+class BaseSampler:
+    def __init__(
+        self,
+        args: SampleArguments,
+        model_args: ModelArguments,
+        model: HFModel,
+        processor: Processor,
+    ) -> None:
+        if args.sample_backend == SampleBackend.HF:
+            self.engine = HuggingFaceEngine(args, model_args, model, processor)
+        else:
+            raise ValueError(f"Unknown sample backend: {args.sample_backend}")
+
+    async def generate(self, messages):
+        return await self.engine.generate(messages)
+
+    async def batch_infer(self, data: TorchDataset) -> None:
+        return await self.engine.batch_infer(data)
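A usage sketch for the new BaseSampler API (hypothetical driver code, not part of this commit; the message format is an assumption, and generate/batch_infer are coroutines, so they need an event loop):

    import asyncio

    from llamafactory.v1.config import get_args
    from llamafactory.v1.core.base_sampler import BaseSampler
    from llamafactory.v1.core.model_loader import ModelLoader

    # Argument order follows run_chat() in cli_sampler.py below.
    data_args, model_args, _, sample_args = get_args(dict(model="llamafactory/tiny-random-qwen2.5"))
    model_loader = ModelLoader(model_args)
    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)

    # BaseSampler.generate awaits the selected engine's coroutine.
    asyncio.run(sampler.generate([{"role": "user", "content": "Hello!"}]))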
@@ -1,44 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from abc import ABC, abstractmethod
-
-from ..config.sample_args import SampleArguments, SampleBackend
-from .model_loader import ModelLoader
-
-
-class BaseEngine(ABC):
-    @abstractmethod
-    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
-
-    @abstractmethod
-    async def generate(self):
-        pass
-
-    @abstractmethod
-    async def batch_infer(self):
-        pass
-
-
-class HuggingFaceEngine(BaseEngine):
-    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
-        self.args = sample_args
-
-
-class ChatSampler:
-    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
-        if sample_args.sample_backend == SampleBackend.HF:
-            self.engine = HuggingFaceEngine(model_loader, sample_args)
-        else:
-            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
@@ -14,17 +14,24 @@

 """The definition of model loader.

-Init Phase:
+How to use:
     model_loader = ModelLoader(model_args, is_trainable=True)
     model_loader.processor: Get the tokenizer or multi-modal processor.
     model_loader.model_config: Get the model configuration.
     model_loader.model: Get the HF model.

+Init Workflow:
+    1. Init processor.
+    2. Init model config.
+    3. Init model.
+    4. Init adapter.

 """

 import torch
 from accelerate import init_empty_weights
 from transformers import AutoConfig, AutoProcessor

+from ..accelerator.helper import DeviceType
+from ..accelerator.interface import DistributedInterface
 from ..config.model_args import ModelArguments, ModelClass
 from ..utils import logging
@@ -55,11 +62,14 @@ class ModelLoader:
     """HF model."""

     def _init_processor(self) -> Processor:
-        """Init processor."""
+        """Init processor.
+
+        NOTE: Transformers v5 always uses the fast tokenizer.
+        https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
+        """
         return AutoProcessor.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
             use_fast=self.args.use_fast_processor,
         )

     def _init_model_config(self) -> HFConfig:
@@ -92,12 +102,22 @@ class ModelLoader:

         AutoClass = AutoModel

         # map the entire model to the current accelerator
+        if self.args.init_config is not None:
+            from ..plugins.model_plugins.initialization import InitPlugin
+
+            init_device = InitPlugin(self.args.init_config.name)()
+        else:
+            init_device = DistributedInterface().current_accelerator
+
+        if init_device.type == DeviceType.META:
+            with init_empty_weights():
+                model = AutoClass.from_config(self.model_config)
+        else:
             model = AutoClass.from_pretrained(
                 self.args.model,
                 config=self.model_config,
                 dtype="auto",
-                device_map=DistributedInterface().current_accelerator,
+                device_map=init_device,
                 trust_remote_code=self.args.trust_remote_code,
             )
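For background on the META branch above, a minimal standalone sketch of meta-device initialization (assuming only transformers and accelerate; the model name is illustrative): init_empty_weights() builds the module structure with parameters on the meta device, i.e. shapes and dtypes only with no memory allocated, which is what makes init_on_meta cheap on ranks that do not load weights.

    import torch
    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("llamafactory/tiny-random-qwen2.5")

    # Parameters live on the meta device: no storage is allocated.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    assert model.device.type == "meta"
    # Real weights must be loaded or materialized before the model can run.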
@@ -0,0 +1,43 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from ...accelerator.helper import DeviceType
+from ...accelerator.interface import DistributedInterface
+from ...utils.plugin import BasePlugin
+
+
+class InitPlugin(BasePlugin):
+    def __call__(self) -> torch.device:
+        return super().__call__()
+
+
+@InitPlugin("init_on_meta").register
+def init_on_meta() -> torch.device:
+    return torch.device(DeviceType.META.value)
+
+
+@InitPlugin("init_on_rank0").register
+def init_on_rank0() -> torch.device:
+    if DistributedInterface().get_rank() == 0:
+        return torch.device(DeviceType.CPU.value)
+    else:
+        return torch.device(DeviceType.META.value)
+
+
+@InitPlugin("init_on_default").register
+def init_on_default() -> torch.device:
+    return DistributedInterface().current_accelerator
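The registry behind @InitPlugin("name").register is not part of this diff; a hedged sketch of how such a BasePlugin could work (names and internals are assumptions, not the actual llamafactory.v1.utils.plugin implementation):

    from typing import Any, Callable


    class BasePlugin:
        # Hypothetical sketch: a shared name -> callable registry.
        # The real implementation may scope registries per plugin subclass.
        _registry: dict[str, Callable[..., Any]] = {}

        def __init__(self, name: str) -> None:
            self.name = name

        def register(self, func: Callable[..., Any]) -> Callable[..., Any]:
            # Decorator usage: @InitPlugin("init_on_meta").register
            self._registry[self.name] = func
            return func

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Resolve and invoke the callable registered under this name.
            return self._registry[self.name](*args, **kwargs)

With a registry like this, InitPlugin("init_on_rank0")() looks up the registered function by name and returns its torch.device, which is how ModelLoader consumes init_config.name in the hunk above.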
src/llamafactory/v1/samplers/cli_sampler.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..config import InputArgument, SampleBackend, get_args
+from ..core.base_sampler import BaseSampler
+from ..core.model_loader import ModelLoader
+
+
+def run_chat(args: InputArgument = None):
+    data_args, model_args, _, sample_args = get_args(args)
+    if sample_args.sample_backend != SampleBackend.HF:
+        model_args.init_plugin = {"name": "init_on_meta"}
+
+    model_loader = ModelLoader(model_args)
+    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
+    if data_args.dataset is not None:
+        sampler.batch_infer()
+    else:
+        sampler.generate()
+
+
+if __name__ == "__main__":
+    run_chat()
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.

 Contains shared fixtures, pytest configuration, and custom markers.
 """
@@ -110,10 +110,9 @@ def _handle_device_visibility(items: list[Item]):
 def pytest_collection_modifyitems(config: Config, items: list[Item]):
     """Modify test collection based on markers and environment."""
-    # Handle version compatibility (from HEAD)
-    if not is_transformers_version_greater_than("4.57.0"):
     skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
     for item in items:
-        if "tests_v1" in str(item.fspath):
+        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
             item.add_marker(skip_bc)

     _handle_slow_tests(items)
@@ -156,6 +155,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
         monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
     else:
         monkeypatch.setenv(env_key, "0")

     if CURRENT_DEVICE == "cuda":
         monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
     elif CURRENT_DEVICE == "npu":
@@ -24,7 +24,6 @@ def test_get_args_from_yaml(tmp_path: pathlib.Path):
 ### model
 model: "llamafactory/tiny-random-qwen2.5"
 trust_remote_code: true
 use_fast_processor: true
 model_class: "llm"
 kernel_config:
   name: "auto"
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""LLaMA-Factory test configuration.
+"""LlamaFactory test configuration.

 Contains shared fixtures, pytest configuration, and custom markers.
 """
@@ -22,6 +22,7 @@ import sys

 import pytest
 import torch
+import torch.distributed as dist
 from pytest import Config, FixtureRequest, Item, MonkeyPatch

 from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,10 +110,9 @@ def _handle_device_visibility(items: list[Item]):
 def pytest_collection_modifyitems(config: Config, items: list[Item]):
     """Modify test collection based on markers and environment."""
-    # Handle version compatibility (from HEAD)
-    if not is_transformers_version_greater_than("4.57.0"):
     skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
     for item in items:
-        if "tests_v1" in str(item.fspath):
+        if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
             item.add_marker(skip_bc)

     _handle_slow_tests(items)
@@ -120,6 +120,14 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
     _handle_device_visibility(items)


+@pytest.fixture(autouse=True)
+def _cleanup_distributed_state():
+    """Cleanup distributed state after each test."""
+    yield
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
 @pytest.fixture(autouse=True)
 def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
     """Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
         monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
     else:
         monkeypatch.setenv(env_key, "0")

     if CURRENT_DEVICE == "cuda":
         monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
     elif CURRENT_DEVICE == "npu":
tests_v1/plugins/model_plugins/test_init_plugin.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from llamafactory.v1.accelerator.interface import DistributedInterface
+from llamafactory.v1.config.arg_parser import get_args
+from llamafactory.v1.core.model_loader import ModelLoader
+
+
+def test_init_on_meta():
+    _, model_args, *_ = get_args(
+        dict(
+            model="llamafactory/tiny-random-qwen2.5",
+            init_config={"name": "init_on_meta"},
+        )
+    )
+    model_loader = ModelLoader(model_args=model_args)
+    assert model_loader.model.device.type == "meta"
+
+
+@pytest.mark.runs_on(["cuda", "npu"])
+def test_init_on_rank0():
+    _, model_args, *_ = get_args(
+        dict(
+            model="llamafactory/tiny-random-qwen2.5",
+            init_config={"name": "init_on_rank0"},
+        )
+    )
+    model_loader = ModelLoader(model_args=model_args)
+    if DistributedInterface().get_rank() == 0:
+        assert model_loader.model.device.type == "cpu"
+    else:
+        assert model_loader.model.device.type == "meta"
+
+
+def test_init_on_default():
+    _, model_args, *_ = get_args(
+        dict(
+            model="llamafactory/tiny-random-qwen2.5",
+            init_config={"name": "init_on_default"},
+        )
+    )
+    model_loader = ModelLoader(model_args=model_args)
+    assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
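Assuming a standard pytest setup, the new tests can be run directly; the runs_on marker on test_init_on_rank0 is the custom marker handled by the conftest changes above:

    pytest tests_v1/plugins/model_plugins/test_init_plugin.py -v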