Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[v1] model loader (#9613)
@@ -1,6 +1,6 @@
# core deps
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'
transformers>=4.49.0,<=4.57.3,!=4.52.0,!=4.57.0; python_version >= '3.10'
datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

def check_dependencies() -> None:
    r"""Check the version of the required packages."""
    check_version("transformers>=4.49.0,<=4.57.1")
    check_version("transformers>=4.49.0,<=4.57.3")
    check_version("datasets>=2.16.0,<=4.0.0")
    check_version("accelerate>=1.3.0,<=1.11.0")
    check_version("peft>=0.14.0,<=0.17.1")

@@ -19,17 +19,13 @@ import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist

from ..utils.types import Tensor, TensorLike


if TYPE_CHECKING:
    from torch.distributed import ProcessGroup
from ..utils.types import ProcessGroup, Tensor, TensorLike


@unique

@@ -107,7 +103,7 @@ def is_torch_xpu_available():
    return get_current_accelerator().type == DeviceType.XPU


def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
    """Gathers the tensor from all ranks and concats them along the first dim."""
    world_size = get_world_size()
    device = get_current_accelerator()

@@ -116,7 +112,7 @@ def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor
    return output_tensor.view(-1, *tensor.size()[1:])


def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
    """Performs all reduce in the given process group."""
    device = get_current_accelerator()
    is_ndarray = isinstance(data, np.ndarray)

@@ -16,12 +16,14 @@
# limitations under the License.

from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional

from torch.distributed import init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from ..utils.types import Tensor, TensorLike
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from .helper import (
    ReduceOp,
    all_gather,

@@ -35,10 +37,6 @@ from .helper import (
)


if TYPE_CHECKING:
    from torch.distributed import ProcessGroup


class Dim(str, Enum):
    """Dimension names."""

@@ -130,21 +128,33 @@ class DistributedInterface:

        return cls._instance

    def __init__(self, strategy: DistributedStrategy) -> None:
    def __init__(self, config: Optional[DistributedConfig] = None) -> None:
        if self._initialized:
            return

        self.strategy = strategy
        if config is None:
            self.strategy = DistributedStrategy()
            timeout = 18000
        else:
            self.strategy = DistributedStrategy(
                mp_replicate_size=config.get("mp_replicate_size", 1),
                mp_shard_size=config.get("mp_shard_size", None),
                dp_size=config.get("dp_size", None),
                cp_size=config.get("cp_size", 1),
            )
            timeout = config.get("timeout", 18000)

        if self._is_distributed:
            init_process_group(timeout=timedelta(seconds=timeout))
            self.model_device_mesh = init_device_mesh(
                device_type=self.current_accelerator.type,
                mesh_shape=strategy.model_mesh_shape,
                mesh_dim_names=strategy.model_mesh_dim_names,
                mesh_shape=self.strategy.model_mesh_shape,
                mesh_dim_names=self.strategy.model_mesh_dim_names,
            )
            self.data_device_mesh = init_device_mesh(
                device_type=self.current_accelerator.type,
                mesh_shape=strategy.data_mesh_shape,
                mesh_dim_names=strategy.data_mesh_dim_names,
                mesh_shape=self.strategy.data_mesh_shape,
                mesh_dim_names=self.strategy.data_mesh_dim_names,
            )
        else:
            self.model_device_mesh = None

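To make the new config-driven signature concrete, a minimal sketch of passing a DistributedConfig-style dict (the values are illustrative; unset sizes are derived from the world size):

    # illustrative values; keys mirror DistributedConfig and the config.get() calls above
    dist_config = {"mp_replicate_size": 1, "cp_size": 1, "timeout": 1800}
    DistributedInterface(dist_config)  # with config=None it falls back to DistributedStrategy() and an 18000s timeout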
@@ -172,7 +182,7 @@ class DistributedInterface:
        return cls.model_device_mesh[dim.value]

    @classmethod
    def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
    def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
        """Get process group for specified dimension."""
        if cls.model_device_mesh is None or dim is None:
            return None

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import sys
from pathlib import Path

@@ -28,6 +27,9 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments


InputArgument = Optional[Union[dict[str, Any], list[str]]]


def validate_args(
    data_args: DataArguments,
    model_args: ModelArguments,

@@ -43,9 +45,7 @@ def validate_args(
        raise ValueError("Quantization is not supported with deepspeed backend.")


def get_args(
    args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")

@@ -18,7 +18,7 @@

import json
from enum import Enum, unique
from typing import Any, Optional, Union
from typing import Optional, Union


class PluginConfig(dict):

@@ -32,25 +32,16 @@ class PluginConfig(dict):

        return self["name"]

    def __getattr__(self, key: str) -> Any:
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"Attribute {key} not found.")

    def __setattr__(self, key: str, value: Any):
        self[key] = value


PluginArgument = Optional[Union[PluginConfig, dict, str]]


@unique
class AutoClass(str, Enum):
class ModelClass(str, Enum):
    """Auto class for model config."""

    CAUSALLM = "llm"
    CLASSIFICATION = "cls"
    LLM = "llm"
    CLS = "cls"
    OTHER = "other"

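A rough sketch of how the dict-backed PluginConfig behaves (the extra keys here are illustrative, not part of the diff):

    cfg = PluginConfig(name="lora", lora_rank=8)   # a plain dict with attribute access
    assert cfg.name == "lora" and cfg.lora_rank == 8
    cfg.lora_alpha = 16                            # __setattr__ writes through to the dict
    # cfg.missing would raise AttributeError("Attribute missing not found.")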
@@ -14,8 +14,9 @@


from dataclasses import dataclass, field
from typing import Optional

from .arg_utils import AutoClass, PluginConfig, get_plugin_config
from .arg_utils import ModelClass, PluginConfig, get_plugin_config


@dataclass

@@ -31,19 +32,19 @@ class ModelArguments:
        default=True,
        metadata={"help": "Use fast processor from Hugging Face."},
    )
    auto_class: AutoClass = field(
        default=AutoClass.CAUSALLM,
    model_class: ModelClass = field(
        default=ModelClass.LLM,
        metadata={"help": "Model class from Hugging Face."},
    )
    peft_config: PluginConfig = field(
    peft_config: Optional[PluginConfig] = field(
        default=None,
        metadata={"help": "PEFT configuration for the model."},
    )
    kernel_config: PluginConfig = field(
    kernel_config: Optional[PluginConfig] = field(
        default=None,
        metadata={"help": "Kernel configuration for the model."},
    )
    quant_config: PluginConfig = field(
    quant_config: Optional[PluginConfig] = field(
        default=None,
        metadata={"help": "Quantization configuration for the model."},
    )

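For reference, a hypothetical instantiation using the renamed field (the model path is the one used in tests_v1/core/test_model_loader.py below):

    # illustrative; model_class=ModelClass.LLM is also the default
    args = ModelArguments(model="llamafactory/tiny-random-qwen2.5", model_class=ModelClass.LLM)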
@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4

from .arg_utils import PluginArgument, get_plugin_config
from .arg_utils import PluginConfig, get_plugin_config


@dataclass
class TrainingArguments:
    output_dir: str = field(
        default="",
        default=os.path.join("outputs", str(uuid4())),
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(

@@ -40,7 +42,7 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
    dist_config: PluginArgument = field(
    dist_config: Optional[PluginConfig] = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
    )

@@ -17,11 +17,10 @@
Init Phase:

1. Init dataloader.
2. Init model worker.
3. Init optimizer (deepspeed).
4. Shard model.
5. Init optimizer (fsdp).
6. Init scheduler.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.

Train Phase:
1. Train Loop

@@ -29,8 +28,7 @@ Train Phase:
"""

from ..config.training_args import TrainingArguments
from ..utils.types import TorchDataset
from .model_worker import ModelWorker
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator


@@ -38,21 +36,20 @@ class BaseTrainer:
    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        processor: Processor,
        dataset: TorchDataset,
        data_collator: DataCollator,
        model_worker: ModelWorker,
    ) -> None:
        self.args = args
        self.model = model
        self.processor = processor
        self.dataset = dataset
        self.data_collator = data_collator
        self.model_worker = model_worker
        self.data_collator = DataCollator()
        self.optimizer = None
        self.lr_scheduler = None

    def init_model_and_optimizer(self) -> None:
        self.model_worker.init_model_config()
        # with self.dist_plugin.get_model_init_context():
        # self.model = self.model_worker.init_model(self.model_config)
        pass

    def create_dataloader(self) -> None:
        pass

@@ -15,12 +15,12 @@
from abc import ABC, abstractmethod

from ..config.sample_args import SampleArguments, SampleBackend
from .model_worker import ModelWorker
from .model_loader import ModelLoader


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...

    @abstractmethod
    async def generate(self):

@@ -32,15 +32,13 @@ class BaseEngine(ABC):


class HuggingFaceEngine(BaseEngine):
    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
        self.model = model_worker.get_model()
        self.processor = model_worker.get_processor()
    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        self.args = sample_args


class ChatSampler:
    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        if sample_args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(model_worker, sample_args)
            self.engine = HuggingFaceEngine(model_loader, sample_args)
        else:
            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

@@ -26,7 +26,7 @@ Get Data Sample:
"""

import os
from collections.abc import AsyncIterable, Iterable
from collections.abc import Iterable
from typing import Any, Union

from huggingface_hub import hf_hub_download

@@ -38,7 +38,11 @@ from ..utils.types import DatasetInfo, HFDataset, Sample


class DataEngine(Dataset):
    """Data engine."""
    """Data engine.

    Args:
        data_args: Data arguments.
    """

    def __init__(self, data_args: DataArguments) -> None:
        self.args = data_args

@@ -51,11 +55,11 @@ class DataEngine(Dataset):
        """List of (dataset_name, sample_index)"""
        self.streaming: bool = False
        """Whether dataset is streaming."""
        self.get_dataset_info()
        self.load_dataset()
        self.build_data_index()
        self._get_dataset_info()
        self._load_dataset()
        self._build_data_index()

    def get_dataset_info(self) -> None:
    def _get_dataset_info(self) -> None:
        """Get dataset info from data arguments."""
        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
            self.dataset_infos = OmegaConf.load(self.args.dataset)

@@ -68,31 +72,32 @@ class DataEngine(Dataset):
        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
            self.dataset_infos = {"default": {"path": self.args.dataset}}

    def load_dataset(self) -> None:
    def _load_dataset(self) -> None:
        """Load datasets according to dataset info."""
        for key, value in self.dataset_infos.items():
            split = value.get("split", "train")
            streaming = value.get("streaming", False)
        for dataset_name, dataset_info in self.dataset_infos.items():
            split = dataset_info.get("split", "train")
            streaming = dataset_info.get("streaming", False)
            self.streaming |= streaming
            if value.get("source", "hf_hub") == "hf_hub":
            if dataset_info.get("source", "hf_hub") == "hf_hub":
                from datasets import load_dataset

                self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
                self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
            else:  # data loader plugin
                from ..plugins.data_plugins.loader import DataLoaderPlugin

                self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
                self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info)

    def build_data_index(self) -> None:
    def _build_data_index(self) -> None:
        """Build dataset index."""
        for dataset_name, dataset in self.datasets.items():
            size = self.dataset_infos[dataset_name].get("size")
            weight = self.dataset_infos[dataset_name].get("weight")
            if self.streaming:
            streaming = self.dataset_infos[dataset_name].get("streaming", False)
            if streaming:
                data_index = [(dataset_name, -1) for _ in range(1000)]
            else:
                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]

            size = self.dataset_infos[dataset_name].get("size")
            weight = self.dataset_infos[dataset_name].get("weight")
            if size or weight:  # data index plugin
                from ..plugins.data_plugins.loader import DataIndexPlugin

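The dataset info these helpers consume boils down to a mapping like the following (keys match the .get() calls above; the values are illustrative):

    dataset_infos = {
        "default": {"path": "llamafactory/v1-sft-demo", "split": "train", "streaming": False},
        # "source" defaults to "hf_hub"; a "size" or "weight" entry routes through DataIndexPlugin
    }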
@@ -144,7 +149,7 @@ class DataEngine(Dataset):
        if isinstance(index, int):
            dataset_name, sample_index = self.data_index[index]
            return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
        else:
        else:  # data selector plugin
            from ..plugins.data_plugins.loader import DataSelectorPlugin

            selected_index = DataSelectorPlugin().select(self.data_index, index)

@@ -163,30 +168,18 @@ class DataEngine(Dataset):
        Returns:
            Iterable[Sample]: Dataset iterator.
        """
        if self.streaming:
            pass
        else:
            # TODO: add shuffle here
            pass

        raise NotImplementedError()

    async def __aiter__(self) -> AsyncIterable[Sample]:
        """Get dataset async iterator.

        Returns:
            AsyncIterable[Sample]: Dataset async iterator.
        """
        if self.streaming:
            pass
        else:
            # TODO: add shuffle here
            pass
        # NOTE: hf iterable dataset uses worker ids while map dataset does not
        # NOTE: add worker id and shuffle to the map dataset
        # https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214

        raise NotImplementedError()


if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
    python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
    """
    from ..config.arg_parser import get_args

    data_args, *_ = get_args()

src/llamafactory/v1/core/model_loader.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of model loader.

Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.

"""

import torch
from transformers import AutoConfig, AutoProcessor

from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor


logger = logging.get_logger(__name__)


class ModelLoader:
    """Model loader.

    Args:
        model_args: Model arguments.
        is_trainable: Whether to train the model.
    """

    def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
        self.args = model_args
        """Model arguments."""
        self.is_train = is_train
        """Whether to train the model."""
        self.processor = self._init_processor()
        """Tokenizer or multi-modal processor."""
        self.model_config = self._init_model_config()
        """Model configuration."""
        self.model = self._init_model()
        """HF model."""

    def _init_processor(self) -> Processor:
        """Init processor."""
        return AutoProcessor.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
            use_fast=self.args.use_fast_processor,
        )

    def _init_model_config(self) -> HFConfig:
        """Init model config."""
        return AutoConfig.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
        )

    def _init_model(self) -> HFModel:
        """Init model.

        Let transformers handle the model init context.
        https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
        """
        if self.args.model_class == ModelClass.LLM:
            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
                AutoClass = AutoModelForImageTextToText
            else:
                AutoClass = AutoModelForCausalLM

        elif self.args.model_class == ModelClass.CLS:
            from transformers import AutoModelForTokenClassification

            AutoClass = AutoModelForTokenClassification
        else:
            from transformers import AutoModel

            AutoClass = AutoModel

        # map the entire model to the current accelerator
        model = AutoClass.from_pretrained(
            self.args.model,
            config=self.model_config,
            dtype="auto",
            device_map=DistributedInterface.current_accelerator,
            trust_remote_code=self.args.trust_remote_code,
        )

        if self.args.peft_config is None:
            if self.is_train:
                logger.info_rank0("Fine-tuning mode: full tuning")
                model = model.to(torch.float32)
            else:
                logger.info_rank0("Inference the original model")
        else:
            from ..plugins.model_plugins.peft import PeftPlugin

            model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)

        return model


if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
    """
    from ..config.arg_parser import get_args

    _, model_args, *_ = get_args()
    model_loader = ModelLoader(model_args=model_args)
    print(model_loader.processor)
    print(model_loader.model_config)
    print(model_loader.model)

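A short usage sketch of the loader in training mode (the checkpoint is the tiny test model; without a peft_config the model takes the full-tuning path, assuming the checkpoint is reachable):

    args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
    loader = ModelLoader(args, is_train=True)   # no peft_config: full tuning
    assert loader.model.dtype == torch.float32  # _init_model upcasts when is_train is set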
@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of model worker.

Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.

"""

from typing import Optional

import torch
from transformers import AutoConfig, AutoProcessor

from ..accelerator.helper import DeviceType
from ..config.model_args import AutoClass, ModelArguments
from ..utils.types import HFConfig, HFModel, Processor


class ModelWorker:
    def __init__(self, model_args: ModelArguments) -> None:
        self.args = model_args
        """Model arguments."""
        self.processor: Optional[Processor] = None
        """Tokenizer or multi-modal processor."""
        self.model_config: Optional[HFConfig] = None
        """Model configuration."""
        self.model: Optional[HFModel] = None
        """HF model."""
        self.is_adapter = False
        """Whether the model has adapter."""

    def init_processor(self) -> None:
        if self.processor is not None:
            return

        self.processor = AutoProcessor.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
            use_fast=self.args.use_fast_processor,
        )

    def init_model_config(self) -> None:
        if self.model_config is not None:
            return

        self.model_config = AutoConfig.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
        )

    def init_model(self) -> None:
        if self.model is not None:
            return

        self.init_model_config()

        if self.args.auto_class == AutoClass.CAUSALLM:
            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
                ModelClass = AutoModelForImageTextToText
            else:
                ModelClass = AutoModelForCausalLM
        elif self.args.auto_class == AutoClass.CLASSIFICATION:
            from transformers import AutoModelForTokenClassification

            ModelClass = AutoModelForTokenClassification
        else:
            from transformers import AutoModel

            ModelClass = AutoModel

        default_device_type = torch.get_default_device().type
        if default_device_type == DeviceType.META:
            self.model = ModelClass.from_config(self.model_config)
        else:
            self.model = ModelClass.from_pretrained(
                self.args.model,
                config=self.model_config,
                dtype="auto",
                device_map=default_device_type,
                trust_remote_code=self.args.trust_remote_code,
            )

    def init_adapter(self) -> None:
        if self.is_adapter:
            return

        if self.args.peft_config is not None:
            from ..plugins.model_plugins.peft import PeftPlugin

            self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)

        self.is_adapter = True

    def get_processor(self) -> Processor:
        return self.processor

    def get_model_config(self) -> HFConfig:
        return self.model_config

    def get_model(self) -> HFModel:
        return self.model

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, TypedDict
from typing import Literal, Optional, TypedDict

from peft import LoraConfig, PeftModel, get_peft_model

@@ -31,12 +31,27 @@ class LoraConfigDict(TypedDict, total=False):
    """Target modules."""


class FreezeConfigDict(TypedDict, total=False):
    name: Literal["freeze"]
    """Plugin name."""
    freeze_trainable_layers: int
    """Freeze trainable layers."""
    freeze_trainable_modules: Optional[list[str]]
    """Freeze trainable modules."""


class PeftPlugin(BasePlugin):
    pass
    def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
        return super().__call__(model, config)


@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
    peft_config = LoraConfig(**config)
    model = get_peft_model(model, peft_config)
    return model


@PeftPlugin("freeze").register
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
    raise NotImplementedError()

@@ -13,11 +13,11 @@
# limitations under the License.


from ..accelerator.interface import DistributedInterface, DistributedStrategy
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_worker import ModelWorker
from ..core.model_loader import ModelLoader


class SFTTrainer(BaseTrainer):

@@ -26,8 +26,13 @@ class SFTTrainer(BaseTrainer):

def run_sft(user_args):
    model_args, data_args, training_args, _ = get_args(user_args)
    DistributedInterface(DistributedStrategy())
    DistributedInterface(training_args.dist_config)
    data_engine = DataEngine(data_args)
    model_worker = ModelWorker(model_args)
    trainer = SFTTrainer(training_args, model_worker, data_engine)
    model_loader = ModelLoader(model_args)
    trainer = SFTTrainer(
        args=training_args,
        model=model_loader.model,
        processor=model_loader.processor,
        dataset=data_engine,
    )
    trainer.fit()

src/llamafactory/v1/utils/dtype.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/volcengine/verl/blob/v0.6.1/verl/utils/torch_dtypes.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Union

import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device

from ..accelerator.interface import DistributedInterface


class DtypeRegistry:
    HALF_LIST = ["fp16", "float16", "half", torch.float16]
    FLOAT_LIST = ["fp32", "float32", "float", torch.float32]
    BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]


class DtypeInterface:
    """Type of precision used."""

    _is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
    _is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
    _is_fp32_available = True

    @staticmethod
    def is_available(precision: Union[str, torch.dtype]) -> bool:
        if precision in DtypeRegistry.HALF_LIST:
            return DtypeInterface._is_fp16_available
        elif precision in DtypeRegistry.FLOAT_LIST:
            return DtypeInterface._is_fp32_available
        elif precision in DtypeRegistry.BFLOAT_LIST:
            return DtypeInterface._is_bf16_available
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def is_fp16(precision: Union[str, torch.dtype]) -> bool:
        return precision in DtypeRegistry.HALF_LIST

    @staticmethod
    def is_fp32(precision: Union[str, torch.dtype]) -> bool:
        return precision in DtypeRegistry.FLOAT_LIST

    @staticmethod
    def is_bf16(precision: Union[str, torch.dtype]) -> bool:
        return precision in DtypeRegistry.BFLOAT_LIST

    @staticmethod
    def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
        if precision in DtypeRegistry.HALF_LIST:
            return torch.float16
        elif precision in DtypeRegistry.FLOAT_LIST:
            return torch.float32
        elif precision in DtypeRegistry.BFLOAT_LIST:
            return torch.bfloat16
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def to_str(precision: torch.dtype) -> str:
        if precision == torch.float16:
            return "float16"
        elif precision == torch.float32:
            return "float32"
        elif precision == torch.bfloat16:
            return "bfloat16"
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @contextmanager
    def set_dtype(self, precision: Union[str, torch.dtype]):
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.to_dtype(precision))
        try:
            yield
        finally:
            torch.set_default_dtype(original_dtype)

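A small usage sketch of the helpers above:

    assert DtypeInterface.to_dtype("bf16") == torch.bfloat16
    assert DtypeInterface.to_str(torch.float16) == "float16"
    with DtypeInterface().set_dtype("fp16"):
        x = torch.zeros(2)  # created with the temporarily-set default dtype, i.e. float16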
@@ -29,7 +29,7 @@ _default_log_level: "logging._Level" = logging.INFO


class _Logger(logging.Logger):
    r"""A logger that supports rank0 logging."""
    """A logger that supports rank0 logging."""

    def info_rank0(self, *args, **kwargs) -> None:
        self.info(*args, **kwargs)

@@ -42,7 +42,7 @@ class _Logger(logging.Logger):


def _get_default_logging_level() -> "logging._Level":
    r"""Return the default logging level."""
    """Return the default logging level."""
    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
    if env_level_str:
        if env_level_str.upper() in logging._nameToLevel:

@@ -62,7 +62,7 @@ def _get_library_root_logger() -> "_Logger":


def _configure_library_root_logger() -> None:
    r"""Configure root logger using a stdout stream handler with an explicit format."""
    """Configure root logger using a stdout stream handler with an explicit format."""
    global _default_handler

    with _thread_lock:

@@ -82,7 +82,7 @@ def _configure_library_root_logger() -> None:


def get_logger(name: Optional[str] = None) -> "_Logger":
    r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
    """Return a logger with the specified name. It it not supposed to be accessed externally."""
    if name is None:
        name = _get_library_name()

@@ -91,13 +91,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":


def add_handler(handler: "logging.Handler") -> None:
    r"""Add a handler to the root logger."""
    """Add a handler to the root logger."""
    _configure_library_root_logger()
    _get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
    r"""Remove a handler to the root logger."""
    """Remove a handler to the root logger."""
    _configure_library_root_logger()
    _get_library_root_logger().removeHandler(handler)

@@ -38,7 +38,7 @@ class BasePlugin:
        self.name = name

    @property
    def register(self) -> Callable:
    def register(self):
        """Decorator to register a function as a plugin.

        Example usage:

@@ -60,7 +60,7 @@ class BasePlugin:

        return decorator

    def __call__(self, *args, **kwargs) -> Callable:
    def __call__(self, *args, **kwargs):
        """Call the registered function with the given arguments.

        Example usage:

@@ -75,6 +75,9 @@ class BasePlugin:


if __name__ == "__main__":
    """
    python -m llamafactory.v1.utils.plugin
    """

    class PrintPlugin(BasePlugin):
        pass

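Building on the __main__ example above, a hypothetical registration showing how register and __call__ pair up (the greet function is made up, following the PeftPlugin("lora") pattern from the peft plugin diff):

    @PrintPlugin("stdout").register
    def greet(name: str) -> None:
        print(f"hello {name}")

    PrintPlugin("stdout")("world")  # dispatches to greet and prints "hello world"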
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
    import torch
    import torch.utils.data
    import transformers
    from torch.distributed import ProcessGroup
    from torch.distributed.fsdp import FullyShardedDataParallel

    Tensor = torch.Tensor

@@ -37,6 +38,7 @@ if TYPE_CHECKING:
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
    Optimizer = torch.optim.Optimizer
    Scheduler = torch.optim.lr_scheduler.LRScheduler
    ProcessGroup = ProcessGroup
else:
    Tensor = None
    TensorLike = None

@@ -50,6 +52,7 @@ else:
    Processor = None
    Optimizer = None
    Scheduler = None
    ProcessGroup = None


class DatasetInfo(TypedDict, total=False):

@@ -69,6 +72,19 @@ class DatasetInfo(TypedDict, total=False):
    """Is streaming dataset, default to False."""


class DistributedConfig(TypedDict, total=False):
    mp_replicate_size: NotRequired[int]
    """Model parallel replicate size, default to 1."""
    mp_shard_size: NotRequired[int]
    """Model parallel shard size, default to world_size // mp_replicate_size."""
    dp_size: NotRequired[int]
    """Data parallel size, default to world_size // cp_size."""
    cp_size: NotRequired[int]
    """Context parallel size, default to 1."""
    timeout: NotRequired[int]
    """Timeout for distributed communication, default to 600."""


class Content(TypedDict):
    type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
    value: str

@@ -18,6 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import pytest
from pytest import Config, Item

from llamafactory.extras.misc import get_current_device, is_env_enabled
from llamafactory.train.test_utils import patch_valuehead_model

@@ -29,21 +30,15 @@ except Exception:
    CURRENT_DEVICE = "cpu"


def pytest_configure(config):
def pytest_configure(config: Config):
    """Register custom pytest markers."""
    config.addinivalue_line(
        "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)"
    )
    config.addinivalue_line(
        "markers", "skip_on_devices: skip test on specified devices, e.g., @pytest.mark.skip_on_devices('npu', 'xpu')"
    )
    config.addinivalue_line(
        "markers", "require_device: test requires specific device, e.g., @pytest.mark.require_device('cuda')"
    )
    config.addinivalue_line("markers", "runs_on: test requires specific device, e.g., @pytest.mark.runs_on(['cpu'])")


def _handle_runs_on(items):
def _handle_runs_on(items: list[Item]):
    """Skip tests on specified devices based on runs_on marker.

    Usage:

@@ -68,7 +63,7 @@ def _handle_runs_on(items):
            )


def _handle_slow_tests(items):
def _handle_slow_tests(items: list[Item]):
    """Skip slow tests unless RUN_SLOW environment variable is set.

    Usage:

@@ -85,51 +80,9 @@ def _handle_slow_tests(items):
            item.add_marker(skip_slow)


def _handle_device_skips(items):
    """Skip tests on specified devices based on skip_on_devices marker.

    Usage:
        @pytest.mark.skip_on_devices("npu", "xpu")
        def test_something():
            pass
    """
    for item in items:
        skip_marker = item.get_closest_marker("skip_on_devices")
        if skip_marker:
            skip_devices = skip_marker.args
            if CURRENT_DEVICE in skip_devices:
                item.add_marker(
                    pytest.mark.skip(
                        reason=f"test skipped on {CURRENT_DEVICE.upper()} (skip list: {', '.join(skip_devices)})"
                    )
                )


def _handle_device_requirements(items):
    """Skip tests that require a specific device when running on other devices.

    Usage:
        @pytest.mark.require_device("cuda")
        def test_gpu_only():
            pass
    """
    for item in items:
        require_marker = item.get_closest_marker("require_device")
        if require_marker:
            required_device = require_marker.args[0] if require_marker.args else None
            if required_device and CURRENT_DEVICE != required_device:
                item.add_marker(
                    pytest.mark.skip(
                        reason=f"test requires {required_device.upper()} (current: {CURRENT_DEVICE.upper()})"
                    )
                )


def pytest_collection_modifyitems(config, items):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
    """Modify test collection based on markers and environment."""
    _handle_slow_tests(items)
    _handle_device_skips(items)
    _handle_device_requirements(items)
    _handle_runs_on(items)

@@ -31,15 +31,13 @@ INFER_ARGS = {


@pytest.mark.runs_on(["cpu", "npu"])
@pytest.mark.skip_on_devices("npu")
def test_base():
    model = load_infer_model(**INFER_ARGS)
    ref_model = load_reference_model(TINY_LLAMA3)
    compare_model(model, ref_model)


@pytest.mark.runs_on(["cpu", "npu"])
@pytest.mark.skip_on_devices("npu")
@pytest.mark.runs_on(["cpu"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
    model = load_infer_model(add_valuehead=True, **INFER_ARGS)

@@ -104,7 +104,6 @@ def test_lora_train_valuehead():


@pytest.mark.runs_on(["cpu", "npu"])
@pytest.mark.skip_on_devices("npu")
def test_lora_inference():
    model = load_infer_model(**INFER_ARGS)
    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.103
0.9.4.104

@@ -15,11 +15,11 @@

import os

from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy
from llamafactory.v1.accelerator.interface import DistributedInterface


def test_distributed_interface():
    DistributedInterface(DistributedStrategy())
    DistributedInterface()
    assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
    assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
    assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))

tests_v1/conftest.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from pytest import Config, Item

from llamafactory.v1.utils.packages import is_transformers_version_greater_than


def pytest_collection_modifyitems(config: Config, items: list[Item]):
    if is_transformers_version_greater_than("4.57.0"):
        return

    skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")

    for item in items:
        if "tests_v1" in str(item.fspath):
            item.add_marker(skip_bc)
tests_v1/core/test_model_loader.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_loader import ModelLoader


def test_tiny_qwen():
    from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast

    model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
    model_loader = ModelLoader(model_args)
    assert isinstance(model_loader.processor, Qwen2TokenizerFast)
    assert isinstance(model_loader.model.config, Qwen2Config)
    assert isinstance(model_loader.model, Qwen2ForCausalLM)
    assert model_loader.model.dtype == torch.bfloat16


if __name__ == "__main__":
    test_tiny_qwen()

@@ -24,7 +24,7 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin

@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
    data_args = DataArguments(dataset="llamafactory/v1-sft-demo/dataset_info.yaml")
    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
    data_engine = DataEngine(data_args)
    original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
    indexes = random.choices(range(len(data_engine)), k=num_samples)

@@ -54,6 +54,8 @@ def test_sharegpt_converter():
        "conversations": [
            {"from": "system", "value": "System"},
            {"from": "human", "value": "User"},
            {"from": "function_call", "value": "Tool"},
            {"from": "observation", "value": "Observation"},
            {"from": "gpt", "value": "Assistant"},
        ]
    }

@@ -61,6 +63,8 @@ def test_sharegpt_converter():
        "messages": [
            {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
            {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
            {"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
            {"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
            {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
        ]
    }

@@ -69,7 +73,7 @@ def test_sharegpt_converter():

@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
    data_args = DataArguments(dataset="llamafactory/tiny-preference-dataset/dataset_info.yaml")
    data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
    data_engine = DataEngine(data_args)
    original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
    indexes = random.choices(range(len(data_engine)), k=num_samples)

@@ -112,7 +116,7 @@ def test_pair_converter(num_samples: int):
            },
        ],
    }
    assert data_engine[index] == {"_dataset_name": "dpo_zh_demo", **expected_data}
    assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}


if __name__ == "__main__":