[v1] model loader (#9613)
@@ -17,11 +17,10 @@
 
 Init Phase:
 1. Init dataloader.
-2. Init model worker.
-3. Init optimizer (deepspeed).
-4. Shard model.
-5. Init optimizer (fsdp).
-6. Init scheduler.
+2. Init optimizer (deepspeed).
+3. Shard model.
+4. Init optimizer (fsdp).
+5. Init scheduler.
 
 Train Phase:
 1. Train Loop
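For context on the reordered steps: DeepSpeed expects its optimizer to exist before the engine wraps and shards the model, while FSDP can only build its optimizer after wrapping, from the sharded parameters. A minimal standalone sketch of that ordering (the `use_deepspeed` flag and the plain AdamW/LambdaLR choices are illustrative assumptions, not v1 API):

import torch

def init_training_components(model: torch.nn.Module, use_deepspeed: bool):
    """Sketch of the init order in the docstring above; real sharding is omitted."""
    optimizer = None
    if use_deepspeed:
        # 2. DeepSpeed builds the optimizer before the model is sharded.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # 3. Shard the model here (FSDP wrap / deepspeed.initialize) -- omitted.
    if not use_deepspeed:
        # 4. FSDP builds the optimizer afterwards, from the sharded parameters.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # 5. The scheduler always comes last, bound to whichever optimizer was built.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    return optimizer, scheduler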
@@ -29,8 +28,7 @@ Train Phase:
 """
 
 from ..config.training_args import TrainingArguments
-from ..utils.types import TorchDataset
-from .model_worker import ModelWorker
+from ..utils.types import HFModel, Processor, TorchDataset
 from .trainer_utils.data_collator import DataCollator
 
 
@@ -38,21 +36,20 @@ class BaseTrainer:
     def __init__(
         self,
         args: TrainingArguments,
+        model: HFModel,
+        processor: Processor,
         dataset: TorchDataset,
-        model_worker: ModelWorker,
+        data_collator: DataCollator,
     ) -> None:
         self.args = args
+        self.model = model
+        self.processor = processor
         self.dataset = dataset
-        self.model_worker = model_worker
-        self.data_collator = DataCollator()
+        self.data_collator = data_collator
         self.optimizer = None
         self.lr_scheduler = None
 
-    def init_model_and_optimizer(self) -> None:
-        self.model_worker.init_model_config()
-        # with self.dist_plugin.get_model_init_context():
-        #     self.model = self.model_worker.init_model(self.model_config)
-        pass
-
     def create_dataloader(self) -> None:
         pass
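Taken together, these trainer hunks move model construction out of the trainer: instead of holding a ModelWorker and lazily initializing it, BaseTrainer now receives a ready model and processor. A hedged wiring sketch using only names from this diff (the argument values are placeholders; whether callers build the trainer exactly this way is an assumption):

# Hypothetical call site; `training_args`, `model_args`, and `dataset` are placeholders.
loader = ModelLoader(model_args=model_args, is_train=True)  # builds processor/config/model eagerly
trainer = BaseTrainer(
    args=training_args,
    model=loader.model,
    processor=loader.processor,
    dataset=dataset,
    data_collator=DataCollator(),
)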
@@ -15,12 +15,12 @@
 from abc import ABC, abstractmethod
 
 from ..config.sample_args import SampleArguments, SampleBackend
-from .model_worker import ModelWorker
+from .model_loader import ModelLoader
 
 
 class BaseEngine(ABC):
     @abstractmethod
-    def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
+    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
 
     @abstractmethod
     async def generate(self):
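BaseEngine pins the constructor shape for every backend. A hedged stub of how a new backend would subclass it under the updated signature (the vLLM naming is illustrative, not part of this commit):

class VLLMEngine(BaseEngine):  # illustrative backend stub
    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None:
        self.args = sample_args
        self.model_loader = model_loader

    async def generate(self):
        raise NotImplementedError()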
@@ -32,15 +32,13 @@ class BaseEngine(ABC):
 
 
 class HuggingFaceEngine(BaseEngine):
-    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
-        self.model = model_worker.get_model()
-        self.processor = model_worker.get_processor()
+    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        self.args = sample_args
 
 
 class ChatSampler:
-    def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
+    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
         if sample_args.sample_backend == SampleBackend.HF:
-            self.engine = HuggingFaceEngine(model_worker, sample_args)
+            self.engine = HuggingFaceEngine(model_loader, sample_args)
         else:
             raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
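The sampler change is mechanical: every ModelWorker reference becomes a ModelLoader. A hedged usage sketch (the argument objects are placeholders):

# Assumes SampleArguments(sample_backend=SampleBackend.HF); other fields are placeholders.
model_loader = ModelLoader(model_args=model_args)
sampler = ChatSampler(model_loader, sample_args)  # dispatches to HuggingFaceEngine for the HF backend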
@@ -26,7 +26,7 @@ Get Data Sample:
 """
 
 import os
-from collections.abc import AsyncIterable, Iterable
+from collections.abc import Iterable
 from typing import Any, Union
 
 from huggingface_hub import hf_hub_download
@@ -38,7 +38,11 @@ from ..utils.types import DatasetInfo, HFDataset, Sample
 
 
 class DataEngine(Dataset):
-    """Data engine."""
+    """Data engine.
+
+    Args:
+        data_args: Data arguments.
+    """
 
     def __init__(self, data_args: DataArguments) -> None:
         self.args = data_args
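The new docstring documents the single constructor argument; the interesting state is `self.dataset_infos`, which the hunks below read with the keys `path`, `split`, `streaming`, `source`, `size`, and `weight`. A sketch of its shape for the hub-id fallback (the field values are illustrative):

# Shape produced by _get_dataset_info() when `--dataset` is a hub id rather than a YAML file;
# the optional keys mirror the .get() calls in the hunks below.
dataset_infos = {
    "default": {
        "path": "llamafactory/v1-sft-demo",  # required: hub id or local path
        "split": "train",                    # optional, defaults to "train"
        "streaming": False,                  # optional, defaults to False
        "source": "hf_hub",                  # optional, defaults to "hf_hub"
    }
}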
@@ -51,11 +55,11 @@ class DataEngine(Dataset):
         """List of (dataset_name, sample_index)"""
         self.streaming: bool = False
         """Whether dataset is streaming."""
-        self.get_dataset_info()
-        self.load_dataset()
-        self.build_data_index()
+        self._get_dataset_info()
+        self._load_dataset()
+        self._build_data_index()
 
-    def get_dataset_info(self) -> None:
+    def _get_dataset_info(self) -> None:
         """Get dataset info from data arguments."""
         if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):  # local file
             self.dataset_infos = OmegaConf.load(self.args.dataset)
@@ -68,31 +72,32 @@ class DataEngine(Dataset):
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
             self.dataset_infos = {"default": {"path": self.args.dataset}}
 
-    def load_dataset(self) -> None:
+    def _load_dataset(self) -> None:
         """Load datasets according to dataset info."""
-        for key, value in self.dataset_infos.items():
-            split = value.get("split", "train")
-            streaming = value.get("streaming", False)
+        for dataset_name, dataset_info in self.dataset_infos.items():
+            split = dataset_info.get("split", "train")
+            streaming = dataset_info.get("streaming", False)
             self.streaming |= streaming
-            if value.get("source", "hf_hub") == "hf_hub":
+            if dataset_info.get("source", "hf_hub") == "hf_hub":
                 from datasets import load_dataset
 
-                self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
+                self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
             else:  # data loader plugin
                 from ..plugins.data_plugins.loader import DataLoaderPlugin
 
-                self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
+                self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info)
 
-    def build_data_index(self) -> None:
+    def _build_data_index(self) -> None:
         """Build dataset index."""
         for dataset_name, dataset in self.datasets.items():
-            size = self.dataset_infos[dataset_name].get("size")
-            weight = self.dataset_infos[dataset_name].get("weight")
-            if self.streaming:
+            streaming = self.dataset_infos[dataset_name].get("streaming", False)
+            if streaming:
                 data_index = [(dataset_name, -1) for _ in range(1000)]
             else:
                 data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
 
+            size = self.dataset_infos[dataset_name].get("size")
+            weight = self.dataset_infos[dataset_name].get("weight")
             if size or weight:  # data index plugin
                 from ..plugins.data_plugins.loader import DataIndexPlugin
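To make the index layout concrete: `_build_data_index` emits `(dataset_name, sample_index)` pairs per dataset, with 1000 `-1` placeholders standing in for streamed samples. A self-contained sketch of that layout, leaving out the plugin path (the stand-in datasets are illustrative):

# Standalone illustration of the index built above; lists replace real HF datasets.
datasets = {"alpaca": list(range(3)), "oasst": list(range(2))}
streaming = {"alpaca": False, "oasst": False}

data_index: list[tuple[str, int]] = []
for name, dataset in datasets.items():
    if streaming[name]:
        data_index.extend((name, -1) for _ in range(1000))  # -1 means "take next from stream"
    else:
        data_index.extend((name, i) for i in range(len(dataset)))

print(data_index)  # [('alpaca', 0), ('alpaca', 1), ('alpaca', 2), ('oasst', 0), ('oasst', 1)]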
@@ -144,7 +149,7 @@ class DataEngine(Dataset):
         if isinstance(index, int):
             dataset_name, sample_index = self.data_index[index]
             return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
-        else:
+        else:  # data selector plugin
             from ..plugins.data_plugins.loader import DataSelectorPlugin
 
             selected_index = DataSelectorPlugin().select(self.data_index, index)
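The new comment marks the non-int branch as a plugin extension point; integer indices bypass it entirely. A hedged usage sketch (assuming `engine` is a constructed DataEngine and that slices are among the non-int index types the selector accepts):

sample = engine[0]    # int: resolved directly through data_index
batch = engine[0:8]   # non-int: routed through DataSelectorPlugin().select(...)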
@@ -163,30 +168,18 @@ class DataEngine(Dataset):
         Returns:
             Iterable[Sample]: Dataset iterator.
         """
-        if self.streaming:
-            pass
-        else:
-            # TODO: add shuffle here
-            pass
-
-        raise NotImplementedError()
-
-    async def __aiter__(self) -> AsyncIterable[Sample]:
-        """Get dataset async iterator.
-
-        Returns:
-            AsyncIterable[Sample]: Dataset async iterator.
-        """
-        if self.streaming:
-            pass
-        else:
-            # TODO: add shuffle here
-            pass
+        # NOTE: hf iterable dataset uses worker ids while map dataset does not
+        # NOTE: add worker id and shuffle to the map dataset
+        # https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214
 
         raise NotImplementedError()
 
 
 if __name__ == "__main__":
     """
     python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
     python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
     """
     from ..config.arg_parser import get_args
 
     data_args, *_ = get_args()
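A programmatic equivalent of the CLI smoke test above, as a sketch (the DataArguments module path is an assumption inferred from this diff's relative imports, not confirmed by it):

# Assumed import paths; the diff only shows `from ..config...`-style relative imports.
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine

engine = DataEngine(DataArguments(dataset="data/v1_sft_demo.yaml"))
print(len(engine.data_index), engine.data_index[:3])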
src/llamafactory/v1/core/model_loader.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The definition of model loader.
+
+Init Phase:
+1. Init processor.
+2. Init model config.
+3. Init model.
+4. Init adapter.
+"""
+
+import torch
+from transformers import AutoConfig, AutoProcessor
+
+from ..accelerator.interface import DistributedInterface
+from ..config.model_args import ModelArguments, ModelClass
+from ..utils import logging
+from ..utils.types import HFConfig, HFModel, Processor
+
+
+logger = logging.get_logger(__name__)
+
+
+class ModelLoader:
+    """Model loader.
+
+    Args:
+        model_args: Model arguments.
+        is_train: Whether to train the model.
+    """
+
+    def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
+        self.args = model_args
+        """Model arguments."""
+        self.is_train = is_train
+        """Whether to train the model."""
+        self.processor = self._init_processor()
+        """Tokenizer or multi-modal processor."""
+        self.model_config = self._init_model_config()
+        """Model configuration."""
+        self.model = self._init_model()
+        """HF model."""
+
+    def _init_processor(self) -> Processor:
+        """Init processor."""
+        return AutoProcessor.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+            use_fast=self.args.use_fast_processor,
+        )
+
+    def _init_model_config(self) -> HFConfig:
+        """Init model config."""
+        return AutoConfig.from_pretrained(
+            self.args.model,
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+    def _init_model(self) -> HFModel:
+        """Init model.
+
+        Let transformers handle the model init context.
+        https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
+        """
+        if self.args.model_class == ModelClass.LLM:
+            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
+
+            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
+                AutoClass = AutoModelForImageTextToText
+            else:
+                AutoClass = AutoModelForCausalLM
+        elif self.args.model_class == ModelClass.CLS:
+            from transformers import AutoModelForTokenClassification
+
+            AutoClass = AutoModelForTokenClassification
+        else:
+            from transformers import AutoModel
+
+            AutoClass = AutoModel
+
+        # map the entire model to the current accelerator
+        model = AutoClass.from_pretrained(
+            self.args.model,
+            config=self.model_config,
+            dtype="auto",
+            device_map=DistributedInterface.current_accelerator,
+            trust_remote_code=self.args.trust_remote_code,
+        )
+
+        if self.args.peft_config is None:
+            if self.is_train:
+                logger.info_rank0("Fine-tuning mode: full tuning")
+                model = model.to(torch.float32)
+            else:
+                logger.info_rank0("Inference the original model")
+        else:
+            from ..plugins.model_plugins.peft import PeftPlugin
+
+            model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
+
+        return model
+
+
+if __name__ == "__main__":
+    """
+    python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
+    """
+    from ..config.arg_parser import get_args
+
+    _, model_args, *_ = get_args()
+    model_loader = ModelLoader(model_args=model_args)
+    print(model_loader.processor)
+    print(model_loader.model_config)
+    print(model_loader.model)
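The core of `_init_model` is the auto-class resolution, which can be exercised standalone. A sketch using the tiny checkpoint from the `__main__` block above (the `_model_mapping` membership test is exactly the check used in the new file; `dtype="auto"` follows the diff's transformers v5 usage):

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText

config = AutoConfig.from_pretrained("llamafactory/tiny-random-qwen2.5")
if type(config) in AutoModelForImageTextToText._model_mapping.keys():
    auto_cls = AutoModelForImageTextToText  # image-text checkpoints (e.g. VLMs)
else:
    auto_cls = AutoModelForCausalLM         # plain causal LMs
model = auto_cls.from_pretrained("llamafactory/tiny-random-qwen2.5", config=config, dtype="auto")
print(type(model).__name__)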
@@ -1,119 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of model worker.
|
||||
|
||||
Init Phase:
|
||||
1. Init processor.
|
||||
2. Init model config.
|
||||
3. Init model.
|
||||
4. Init adapter.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
|
||||
from ..accelerator.helper import DeviceType
|
||||
from ..config.model_args import AutoClass, ModelArguments
|
||||
from ..utils.types import HFConfig, HFModel, Processor
|
||||
|
||||
|
||||
class ModelWorker:
|
||||
def __init__(self, model_args: ModelArguments) -> None:
|
||||
self.args = model_args
|
||||
"""Model arguments."""
|
||||
self.processor: Optional[Processor] = None
|
||||
"""Tokenizer or multi-modal processor."""
|
||||
self.model_config: Optional[HFConfig] = None
|
||||
"""Model configuration."""
|
||||
self.model: Optional[HFModel] = None
|
||||
"""HF model."""
|
||||
self.is_adapter = False
|
||||
"""Whether the model has adapter."""
|
||||
|
||||
def init_processor(self) -> None:
|
||||
if self.processor is not None:
|
||||
return
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
use_fast=self.args.use_fast_processor,
|
||||
)
|
||||
|
||||
def init_model_config(self) -> None:
|
||||
if self.model_config is not None:
|
||||
return
|
||||
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
def init_model(self) -> None:
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
self.init_model_config()
|
||||
|
||||
if self.args.auto_class == AutoClass.CAUSALLM:
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
|
||||
ModelClass = AutoModelForImageTextToText
|
||||
else:
|
||||
ModelClass = AutoModelForCausalLM
|
||||
elif self.args.auto_class == AutoClass.CLASSIFICATION:
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
ModelClass = AutoModelForTokenClassification
|
||||
else:
|
||||
from transformers import AutoModel
|
||||
|
||||
ModelClass = AutoModel
|
||||
|
||||
default_device_type = torch.get_default_device().type
|
||||
if default_device_type == DeviceType.META:
|
||||
self.model = ModelClass.from_config(self.model_config)
|
||||
else:
|
||||
self.model = ModelClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=default_device_type,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
def init_adapter(self) -> None:
|
||||
if self.is_adapter:
|
||||
return
|
||||
|
||||
if self.args.peft_config is not None:
|
||||
from ..plugins.model_plugins.peft import PeftPlugin
|
||||
|
||||
self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
|
||||
|
||||
self.is_adapter = True
|
||||
|
||||
def get_processor(self) -> Processor:
|
||||
return self.processor
|
||||
|
||||
def get_model_config(self) -> HFConfig:
|
||||
return self.model_config
|
||||
|
||||
def get_model(self) -> HFModel:
|
||||
return self.model
|
||||
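Net effect of the swap this commit makes: the deleted worker exposed lazy, idempotent `init_*` steps plus getters, while the new loader front-loads all four init phases into its constructor. A side-by-side sketch (both classes are the ones defined in this diff; `model_args` is a placeholder):

# Before: lazy init, callers drive each phase explicitly.
worker = ModelWorker(model_args)
worker.init_processor()
worker.init_model()      # also triggers init_model_config()
worker.init_adapter()
model = worker.get_model()

# After: eager init, everything happens in __init__.
loader = ModelLoader(model_args, is_train=False)
model = loader.model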