[v1] model loader (#9613)

Yaowei Zheng
2025-12-14 11:50:52 +08:00
committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -17,11 +17,10 @@
Init Phase:
1. Init dataloader.
2. Init model worker.
3. Init optimizer (deepspeed).
4. Shard model.
5. Init optimizer (fsdp).
6. Init scheduler.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
Train Phase:
1. Train Loop
@@ -29,8 +28,7 @@ Train Phase:
"""
from ..config.training_args import TrainingArguments
from ..utils.types import TorchDataset
from .model_worker import ModelWorker
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
@@ -38,21 +36,20 @@ class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: HFModel,
processor: Processor,
dataset: TorchDataset,
data_collator: DataCollator,
model_worker: ModelWorker,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.dataset = dataset
self.data_collator = data_collator
self.model_worker = model_worker
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
def init_model_and_optimizer(self) -> None:
self.model_worker.init_model_config()
# with self.dist_plugin.get_model_init_context():
# self.model = self.model_worker.init_model(self.model_config)
pass
def create_dataloader(self) -> None:
pass
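Taken together, the trainer is now handed a ready model and processor instead of a ModelWorker. A minimal sketch of the new wiring (argument construction elided; ModelLoader is the new module further down, and both init methods are still stubs in this commit):

    loader = ModelLoader(model_args, is_train=True)
    trainer = BaseTrainer(
        args=training_args,
        model=loader.model,
        processor=loader.processor,
        dataset=dataset,
    )
    trainer.init_model_and_optimizer()  # stub in this commit
    trainer.create_dataloader()         # stub in this commit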

View File

@@ -15,12 +15,12 @@
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_worker import ModelWorker
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
@@ -32,15 +32,13 @@ class BaseEngine(ABC):
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
self.model = model_worker.get_model()
self.processor = model_worker.get_processor()
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_worker, sample_args)
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -26,7 +26,7 @@ Get Data Sample:
"""
import os
from collections.abc import AsyncIterable, Iterable
from collections.abc import Iterable
from typing import Any, Union
from huggingface_hub import hf_hub_download
@@ -38,7 +38,11 @@ from ..utils.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
"""Data engine."""
"""Data engine.
Args:
data_args: Data arguments.
"""
def __init__(self, data_args: DataArguments) -> None:
self.args = data_args
@@ -51,11 +55,11 @@ class DataEngine(Dataset):
"""List of (dataset_name, sample_index)"""
self.streaming: bool = False
"""Whether dataset is streaming."""
self.get_dataset_info()
self.load_dataset()
self.build_data_index()
self._get_dataset_info()
self._load_dataset()
self._build_data_index()
def get_dataset_info(self) -> None:
def _get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset)
@@ -68,31 +72,32 @@ class DataEngine(Dataset):
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"path": self.args.dataset}}
def load_dataset(self) -> None:
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if value.get("source", "hf_hub") == "hf_hub":
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info)
def build_data_index(self) -> None:
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if self.streaming:
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
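For illustration, the index built here is a flat list of (dataset_name, sample_index) pairs, with -1 as a placeholder index for streaming datasets. A self-contained sketch of the non-streaming, non-plugin path:

    def build_flat_index(datasets: dict) -> list[tuple[str, int]]:
        # mirrors _build_data_index without the size/weight plugin step
        index: list[tuple[str, int]] = []
        for name, dataset in datasets.items():
            index.extend((name, i) for i in range(len(dataset)))
        return index

    print(build_flat_index({"alpaca": [0, 1], "demo": [0]}))
    # [('alpaca', 0), ('alpaca', 1), ('demo', 0)]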
@@ -144,7 +149,7 @@ class DataEngine(Dataset):
if isinstance(index, int):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else:
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin().select(self.data_index, index)
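The lookup above has two paths; a usage sketch (the engine instance and the non-int index form, e.g. a slice, are assumptions for illustration):

    engine = DataEngine(data_args)
    sample = engine[3]    # int: resolves data_index[3] to (dataset_name, sample_index)
    batch = engine[0:4]   # non-int: delegated to DataSelectorPlugin().select(...)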
@@ -163,30 +168,18 @@ class DataEngine(Dataset):
Returns:
Iterable[Sample]: Dataset iterator.
"""
if self.streaming:
pass
else:
# TODO: add shuffle here
pass
raise NotImplementedError()
async def __aiter__(self) -> AsyncIterable[Sample]:
"""Get dataset async iterator.
Returns:
AsyncIterable[Sample]: Dataset async iterator.
"""
if self.streaming:
pass
else:
# TODO: add shuffle here
pass
# NOTE: hf iterable dataset uses worker ids while map dataset does not
# NOTE: add worker id and shuffle to the map dataset
# https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214
raise NotImplementedError()
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args
data_args, *_ = get_args()
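For reference, the dataset_info keys read above (path, split, streaming, source, size, weight) compose entries like the following; a hypothetical data/v1_sft_demo.yaml expressed as the mapping OmegaConf would yield:

    dataset_infos = {
        "default": {
            "path": "llamafactory/v1-sft-demo",  # hf hub repo
            "split": "train",        # defaults to "train"
            "streaming": False,      # defaults to False
            # "source": "hf_hub",    # anything else routes to DataLoaderPlugin
            # "size": 1000,          # optional; triggers DataIndexPlugin
            # "weight": 2.0,         # optional; triggers DataIndexPlugin
        }
    }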

View File

@@ -0,0 +1,128 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import torch
from transformers import AutoConfig, AutoProcessor
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
Args:
model_args: Model arguments.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
self.args = model_args
"""Model arguments."""
self.is_train = is_train
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor."""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def _init_model_config(self) -> HFConfig:
"""Init model config."""
return AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def _init_model(self) -> HFModel:
"""Init model.
Let transformers handle the model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.model_class == ModelClass.LLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
# map the entire model to the current accelerator
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface.current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.peft_config is None:
if self.is_train:
logger.info_rank0("Fine-tuning mode: full tuning")
model = model.to(torch.float32)
else:
logger.info_rank0("Inference the original model")
else:
from ..plugins.model_plugins.peft import PeftPlugin
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
return model
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)
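The AutoClass dispatch in _init_model can be checked in isolation; a small sketch using the tiny model from the example above (_model_mapping is an internal transformers attribute, used here exactly as the loader uses it):

    from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText

    config = AutoConfig.from_pretrained("llamafactory/tiny-random-qwen2.5")
    if type(config) in AutoModelForImageTextToText._model_mapping.keys():
        auto_class = AutoModelForImageTextToText  # vision-language model
    else:
        auto_class = AutoModelForCausalLM         # text-only model, e.g. Qwen2.5
    print(auto_class.__name__)  # AutoModelForCausalLM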

View File

@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model worker.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""
from typing import Optional
import torch
from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..config.model_args import AutoClass, ModelArguments
from ..utils.types import HFConfig, HFModel, Processor
class ModelWorker:
def __init__(self, model_args: ModelArguments) -> None:
self.args = model_args
"""Model arguments."""
self.processor: Optional[Processor] = None
"""Tokenizer or multi-modal processor."""
self.model_config: Optional[HFConfig] = None
"""Model configuration."""
self.model: Optional[HFModel] = None
"""HF model."""
self.is_adapter = False
"""Whether the model has adapter."""
def init_processor(self) -> None:
if self.processor is not None:
return
self.processor = AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def init_model_config(self) -> None:
if self.model_config is not None:
return
self.model_config = AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def init_model(self) -> None:
if self.model is not None:
return
self.init_model_config()
if self.args.auto_class == AutoClass.CAUSALLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
ModelClass = AutoModelForImageTextToText
else:
ModelClass = AutoModelForCausalLM
elif self.args.auto_class == AutoClass.CLASSIFICATION:
from transformers import AutoModelForTokenClassification
ModelClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
ModelClass = AutoModel
default_device_type = torch.get_default_device().type
if default_device_type == DeviceType.META:
self.model = ModelClass.from_config(self.model_config)
else:
self.model = ModelClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=default_device_type,
trust_remote_code=self.args.trust_remote_code,
)
def init_adapter(self) -> None:
if self.is_adapter:
return
if self.args.peft_config is not None:
from ..plugins.model_plugins.peft import PeftPlugin
self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
self.is_adapter = True
def get_processor(self) -> Processor:
return self.processor
def get_model_config(self) -> HFConfig:
return self.model_config
def get_model(self) -> HFModel:
return self.model
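For contrast, the removed worker required staged, idempotent init calls before the model was usable (method names taken from the deleted file above):

    worker = ModelWorker(model_args)
    worker.init_processor()
    worker.init_model()    # calls init_model_config() internally
    worker.init_adapter()  # no-op unless peft_config is set
    model = worker.get_model()

The new ModelLoader collapses this into eager initialization in __init__, exposing processor, model_config, and model directly as attributes.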