[v1] model loader (#9613)

Author: Yaowei Zheng
Date: 2025-12-14 11:50:52 +08:00
Committed by: GitHub
Parent: fdd24276ed
Commit: aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.57.1")
check_version("transformers>=4.49.0,<=4.57.3")
check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.14.0,<=0.17.1")

View File

@@ -19,17 +19,13 @@ import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from ..utils.types import Tensor, TensorLike
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from ..utils.types import ProcessGroup, Tensor, TensorLike
@unique
@@ -107,7 +103,7 @@ def is_torch_xpu_available():
return get_current_accelerator().type == DeviceType.XPU
def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor:
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
@@ -116,7 +112,7 @@ def all_gather(tensor: Tensor, group: Optional["ProcessGroup"] = None) -> Tensor
return output_tensor.view(-1, *tensor.size()[1:])
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional["ProcessGroup"] = None) -> TensorLike:
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)

View File

@@ -16,12 +16,14 @@
# limitations under the License.
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
from torch.distributed import init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import Tensor, TensorLike
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from .helper import (
ReduceOp,
all_gather,
@@ -35,10 +37,6 @@ from .helper import (
)
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
class Dim(str, Enum):
"""Dimension names."""
@@ -130,21 +128,33 @@ class DistributedInterface:
return cls._instance
def __init__(self, strategy: DistributedStrategy) -> None:
def __init__(self, config: Optional[DistributedConfig] = None) -> None:
if self._initialized:
return
self.strategy = strategy
if config is None:
self.strategy = DistributedStrategy()
timeout = 18000
else:
self.strategy = DistributedStrategy(
mp_replicate_size=config.get("mp_replicate_size", 1),
mp_shard_size=config.get("mp_shard_size", None),
dp_size=config.get("dp_size", None),
cp_size=config.get("cp_size", 1),
)
timeout = config.get("timeout", 18000)
if self._is_distributed:
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=strategy.model_mesh_shape,
mesh_dim_names=strategy.model_mesh_dim_names,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
mesh_shape=strategy.data_mesh_shape,
mesh_dim_names=strategy.data_mesh_dim_names,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
else:
self.model_device_mesh = None
@@ -172,7 +182,7 @@ class DistributedInterface:
return cls.model_device_mesh[dim.value]
@classmethod
def get_group(cls, dim: Optional[Dim] = None) -> Optional["ProcessGroup"]:
def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if cls.model_device_mesh is None or dim is None:
return None
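A hypothetical construction on 8 ranks (a sketch; Dim.DP is an assumed member name, since the enum body is elided from this diff):
config = {"mp_replicate_size": 2, "cp_size": 1, "timeout": 1800}
interface = DistributedInterface(config)            # mp_shard_size defaults to 8 // 2 = 4
dp_group = DistributedInterface.get_group(Dim.DP)   # None when not distributed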

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from pathlib import Path
@@ -28,6 +27,9 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments
InputArgument = Optional[Union[dict[str, Any], list[str]]]
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
@@ -43,9 +45,7 @@ def validate_args(
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(
args: Optional[Union[dict[str, Any], list[str]]] = None,
) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
"""Parse arguments from command line or config file."""
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
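Both InputArgument shapes can be exercised like this (a sketch; the flags mirror the __main__ examples later in this commit):
data_args, model_args, training_args, sample_args = get_args(
    {"model": "llamafactory/tiny-random-qwen2.5", "dataset": "data/v1_sft_demo.yaml"}
)
# or argv-style:
get_args(["--model", "llamafactory/tiny-random-qwen2.5", "--dataset", "data/v1_sft_demo.yaml"])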

View File

@@ -18,7 +18,7 @@
import json
from enum import Enum, unique
from typing import Any, Optional, Union
from typing import Optional, Union
class PluginConfig(dict):
@@ -32,25 +32,16 @@ class PluginConfig(dict):
return self["name"]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(f"Attribute {key} not found.")
def __setattr__(self, key: str, value: Any):
self[key] = value
PluginArgument = Optional[Union[PluginConfig, dict, str]]
@unique
class AutoClass(str, Enum):
class ModelClass(str, Enum):
"""Auto class for model config."""
CAUSALLM = "llm"
CLASSIFICATION = "cls"
LLM = "llm"
CLS = "cls"
OTHER = "other"

View File

@@ -14,8 +14,9 @@
from dataclasses import dataclass, field
from typing import Optional
from .arg_utils import AutoClass, PluginConfig, get_plugin_config
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
@@ -31,19 +32,19 @@ class ModelArguments:
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
auto_class: AutoClass = field(
default=AutoClass.CAUSALLM,
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: PluginConfig = field(
peft_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: PluginConfig = field(
kernel_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: PluginConfig = field(
quant_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)
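A hypothetical construction using the renamed field (the model field itself is defined outside this hunk):
args = ModelArguments(
    model="llamafactory/tiny-random-qwen2.5",
    model_class=ModelClass.LLM,              # formerly auto_class=AutoClass.CAUSALLM
    peft_config=PluginConfig(name="lora"),   # now Optional, defaults to None
)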

View File

@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4
from .arg_utils import PluginArgument, get_plugin_config
from .arg_utils import PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
output_dir: str = field(
default="",
default=os.path.join("outputs", str(uuid4())),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(
@@ -40,7 +42,7 @@ class TrainingArguments:
default=False,
metadata={"help": "Use bf16 for training."},
)
dist_config: PluginArgument = field(
dist_config: Optional[PluginConfig] = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
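Note that the new output_dir default is evaluated once at import time, so it is unique per process rather than per instantiation (a sketch):
args = TrainingArguments()
print(args.output_dir)   # e.g. "outputs/6f9b..."; every instance in one process shares it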

View File

@@ -17,11 +17,10 @@
Init Phase:
1. Init dataloader.
2. Init model worker.
3. Init optimizer (deepspeed).
4. Shard model.
5. Init optimizer (fsdp).
6. Init scheduler.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
Train Phase:
1. Train Loop
@@ -29,8 +28,7 @@ Train Phase:
"""
from ..config.training_args import TrainingArguments
from ..utils.types import TorchDataset
from .model_worker import ModelWorker
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
@@ -38,21 +36,20 @@ class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: HFModel,
processor: Processor,
dataset: TorchDataset,
data_collator: DataCollator,
model_worker: ModelWorker,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.dataset = dataset
self.data_collator = data_collator
self.model_worker = model_worker
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
def init_model_and_optimizer(self) -> None:
self.model_worker.init_model_config()
# with self.dist_plugin.get_model_init_context():
# self.model = self.model_worker.init_model(self.model_config)
pass
def create_dataloader(self) -> None:
pass

View File

@@ -15,12 +15,12 @@
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_worker import ModelWorker
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
@@ -32,15 +32,13 @@ class BaseEngine(ABC):
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
self.model = model_worker.get_model()
self.processor = model_worker.get_processor()
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_worker, sample_args)
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -26,7 +26,7 @@ Get Data Sample:
"""
import os
from collections.abc import AsyncIterable, Iterable
from collections.abc import Iterable
from typing import Any, Union
from huggingface_hub import hf_hub_download
@@ -38,7 +38,11 @@ from ..utils.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
"""Data engine."""
"""Data engine.
Args:
data_args: Data arguments.
"""
def __init__(self, data_args: DataArguments) -> None:
self.args = data_args
@@ -51,11 +55,11 @@ class DataEngine(Dataset):
"""List of (dataset_name, sample_index)"""
self.streaming: bool = False
"""Whether dataset is streaming."""
self.get_dataset_info()
self.load_dataset()
self.build_data_index()
self._get_dataset_info()
self._load_dataset()
self._build_data_index()
def get_dataset_info(self) -> None:
def _get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset)
@@ -68,31 +72,32 @@ class DataEngine(Dataset):
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"path": self.args.dataset}}
def load_dataset(self) -> None:
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if value.get("source", "hf_hub") == "hf_hub":
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info)
def build_data_index(self) -> None:
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if self.streaming:
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
@@ -144,7 +149,7 @@ class DataEngine(Dataset):
if isinstance(index, int):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else:
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin().select(self.data_index, index)
@@ -163,30 +168,18 @@ class DataEngine(Dataset):
Returns:
Iterable[Sample]: Dataset iterator.
"""
if self.streaming:
pass
else:
# TODO: add shuffle here
pass
raise NotImplementedError()
async def __aiter__(self) -> AsyncIterable[Sample]:
"""Get dataset async iterator.
Returns:
AsyncIterable[Sample]: Dataset async iterator.
"""
if self.streaming:
pass
else:
# TODO: add shuffle here
pass
# NOTE: hf iterable dataset uses worker ids while map dataset does not
# NOTE: add worker id and shuffle to the map dataset
# https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214
raise NotImplementedError()
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args
data_args, *_ = get_args()
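For reference, a dataset_infos mapping consumed by the private loaders above can look like this (a sketch; the "file" source name is hypothetical):
dataset_infos = {
    "demo": {"path": "llamafactory/v1-sft-demo", "split": "train", "streaming": False},
    "local": {"source": "file", "path": "data/sft.json", "size": 100, "weight": 2.0},
}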

View File

@@ -0,0 +1,128 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import torch
from transformers import AutoConfig, AutoProcessor
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
Args:
model_args: Model arguments.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
self.args = model_args
"""Model arguments."""
self.is_train = is_train
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor."""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def _init_model_config(self) -> HFConfig:
"""Init model config."""
return AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def _init_model(self) -> HFModel:
"""Init model.
Let transformers handle the model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.model_class == ModelClass.LLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
# map the entire model to the current accelerator
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface.current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.peft_config is None:
if self.is_train:
logger.info_rank0("Fine-tuning mode: full tuning")
model = model.to(torch.float32)
else:
logger.info_rank0("Inference the original model")
else:
from ..plugins.model_plugins.peft import PeftPlugin
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
return model
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)

View File

@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model worker.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""
from typing import Optional
import torch
from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..config.model_args import AutoClass, ModelArguments
from ..utils.types import HFConfig, HFModel, Processor
class ModelWorker:
def __init__(self, model_args: ModelArguments) -> None:
self.args = model_args
"""Model arguments."""
self.processor: Optional[Processor] = None
"""Tokenizer or multi-modal processor."""
self.model_config: Optional[HFConfig] = None
"""Model configuration."""
self.model: Optional[HFModel] = None
"""HF model."""
self.is_adapter = False
"""Whether the model has adapter."""
def init_processor(self) -> None:
if self.processor is not None:
return
self.processor = AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def init_model_config(self) -> None:
if self.model_config is not None:
return
self.model_config = AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def init_model(self) -> None:
if self.model is not None:
return
self.init_model_config()
if self.args.auto_class == AutoClass.CAUSALLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
ModelClass = AutoModelForImageTextToText
else:
ModelClass = AutoModelForCausalLM
elif self.args.auto_class == AutoClass.CLASSIFICATION:
from transformers import AutoModelForTokenClassification
ModelClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
ModelClass = AutoModel
default_device_type = torch.get_default_device().type
if default_device_type == DeviceType.META:
self.model = ModelClass.from_config(self.model_config)
else:
self.model = ModelClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=default_device_type,
trust_remote_code=self.args.trust_remote_code,
)
def init_adapter(self) -> None:
if self.is_adapter:
return
if self.args.peft_config is not None:
from ..plugins.model_plugins.peft import PeftPlugin
self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
self.is_adapter = True
def get_processor(self) -> Processor:
return self.processor
def get_model_config(self) -> HFConfig:
return self.model_config
def get_model(self) -> HFModel:
return self.model

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from typing import Literal, Optional, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
@@ -31,12 +31,27 @@ class LoraConfigDict(TypedDict, total=False):
"""Target modules."""
class FreezeConfigDict(TypedDict, total=False):
name: Literal["freeze"]
"""Plugin name."""
freeze_trainable_layers: int
"""Number of trainable layers for freeze fine-tuning."""
freeze_trainable_modules: Optional[list[str]]
"""Names of trainable modules for freeze fine-tuning."""
class PeftPlugin(BasePlugin):
pass
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
return super().__call__(model, config, is_train)
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
@PeftPlugin("freeze").register
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()
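A hypothetical LoRA call through the plugin (r and lora_alpha are standard peft LoraConfig fields, assumed to be among the elided LoraConfigDict keys):
config = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
model = PeftPlugin("lora")(model, config, True)   # wraps the model via get_peft_model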

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from ..accelerator.interface import DistributedInterface, DistributedStrategy
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_worker import ModelWorker
from ..core.model_loader import ModelLoader
class SFTTrainer(BaseTrainer):
@@ -26,8 +26,13 @@ class SFTTrainer(BaseTrainer):
def run_sft(user_args):
data_args, model_args, training_args, _ = get_args(user_args)
DistributedInterface(DistributedStrategy())
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_worker = ModelWorker(model_args)
trainer = SFTTrainer(training_args, model_worker, data_engine)
model_loader = ModelLoader(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_loader.model,
processor=model_loader.processor,
dataset=data_engine,
)
trainer.fit()
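An end-to-end invocation mirroring the __main__ examples elsewhere in this commit (a sketch):
run_sft(["--model", "llamafactory/tiny-random-qwen2.5", "--dataset", "data/v1_sft_demo.yaml"])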

View File

@@ -0,0 +1,92 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/volcengine/verl/blob/v0.6.1/verl/utils/torch_dtypes.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Union
import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from ..accelerator.interface import DistributedInterface
class DtypeRegistry:
HALF_LIST = ["fp16", "float16", "half", torch.float16]
FLOAT_LIST = ["fp32", "float32", "float", torch.float32]
BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]
class DtypeInterface:
"""Type of precision used."""
_is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
_is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
_is_fp32_available = True
@staticmethod
def is_available(precision: Union[str, torch.dtype]) -> bool:
if precision in DtypeRegistry.HALF_LIST:
return DtypeInterface._is_fp16_available
elif precision in DtypeRegistry.FLOAT_LIST:
return DtypeInterface._is_fp32_available
elif precision in DtypeRegistry.BFLOAT_LIST:
return DtypeInterface._is_bf16_available
else:
raise RuntimeError(f"Unexpected precision: {precision}")
@staticmethod
def is_fp16(precision: Union[str, torch.dtype]) -> bool:
return precision in DtypeRegistry.HALF_LIST
@staticmethod
def is_fp32(precision: Union[str, torch.dtype]) -> bool:
return precision in DtypeRegistry.FLOAT_LIST
@staticmethod
def is_bf16(precision: Union[str, torch.dtype]) -> bool:
return precision in DtypeRegistry.BFLOAT_LIST
@staticmethod
def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
if precision in DtypeRegistry.HALF_LIST:
return torch.float16
elif precision in DtypeRegistry.FLOAT_LIST:
return torch.float32
elif precision in DtypeRegistry.BFLOAT_LIST:
return torch.bfloat16
else:
raise RuntimeError(f"Unexpected precision: {precision}")
@staticmethod
def to_str(precision: torch.dtype) -> str:
if precision == torch.float16:
return "float16"
elif precision == torch.float32:
return "float32"
elif precision == torch.bfloat16:
return "bfloat16"
else:
raise RuntimeError(f"Unexpected precision: {precision}")
@contextmanager
def set_dtype(self, precision: Union[str, torch.dtype]):
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.to_dtype(precision))
try:
yield
finally:
torch.set_default_dtype(original_dtype)
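Hypothetical usage of the dtype helpers (a sketch; fp16 availability depends on the host device):
assert DtypeInterface.to_dtype("bf16") is torch.bfloat16
DtypeInterface.is_available("fp16")        # may be False on CPU-only hosts
with DtypeInterface().set_dtype("bf16"):   # new tensors default to bfloat16 inside
    w = torch.empty(4, 4)
assert w.dtype == torch.bfloat16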

View File

@@ -29,7 +29,7 @@ _default_log_level: "logging._Level" = logging.INFO
class _Logger(logging.Logger):
r"""A logger that supports rank0 logging."""
"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
@@ -42,7 +42,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level":
r"""Return the default logging level."""
"""Return the default logging level."""
env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
@@ -62,7 +62,7 @@ def _get_library_root_logger() -> "_Logger":
def _configure_library_root_logger() -> None:
r"""Configure root logger using a stdout stream handler with an explicit format."""
"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
@@ -82,7 +82,7 @@ def _configure_library_root_logger() -> None:
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
@@ -91,13 +91,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
def add_handler(handler: "logging.Handler") -> None:
r"""Add a handler to the root logger."""
"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""Remove a handler to the root logger."""
"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)

View File

@@ -38,7 +38,7 @@ class BasePlugin:
self.name = name
@property
def register(self) -> Callable:
def register(self):
"""Decorator to register a function as a plugin.
Example usage:
@@ -60,7 +60,7 @@ class BasePlugin:
return decorator
def __call__(self, *args, **kwargs) -> Callable:
def __call__(self, *args, **kwargs):
"""Call the registered function with the given arguments.
Example usage:
@@ -75,6 +75,9 @@ class BasePlugin:
if __name__ == "__main__":
"""
python -m llamafactory.v1.utils.plugin
"""
class PrintPlugin(BasePlugin):
pass
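The register/dispatch round trip, sketched with the PrintPlugin defined above (the function name is an assumption; the docstring examples are elided from this diff):
@PrintPlugin("hello").register
def print_hello() -> None:
    print("hello from plugin")
PrintPlugin("hello")()   # dispatches to print_hello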

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
import torch
import torch.utils.data
import transformers
from torch.distributed import ProcessGroup
from torch.distributed.fsdp import FullyShardedDataParallel
Tensor = torch.Tensor
@@ -37,6 +38,7 @@ if TYPE_CHECKING:
Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
Optimizer = torch.optim.Optimizer
Scheduler = torch.optim.lr_scheduler.LRScheduler
ProcessGroup = ProcessGroup
else:
Tensor = None
TensorLike = None
@@ -50,6 +52,7 @@ else:
Processor = None
Optimizer = None
Scheduler = None
ProcessGroup = None
class DatasetInfo(TypedDict, total=False):
@@ -69,6 +72,19 @@ class DatasetInfo(TypedDict, total=False):
"""Is streaming dataset, default to False."""
class DistributedConfig(TypedDict, total=False):
mp_replicate_size: NotRequired[int]
"""Model parallel replicate size, defaults to 1."""
mp_shard_size: NotRequired[int]
"""Model parallel shard size, defaults to world_size // mp_replicate_size."""
dp_size: NotRequired[int]
"""Data parallel size, defaults to world_size // cp_size."""
cp_size: NotRequired[int]
"""Context parallel size, defaults to 1."""
timeout: NotRequired[int]
"""Timeout in seconds for distributed communication, defaults to 18000."""
class Content(TypedDict):
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
value: str