mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-30 02:30:35 +08:00
[v1] add models & accelerator (#9579)
This commit is contained in:
@@ -12,44 +12,51 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
"""The definition of trainer.
|
||||
|
||||
Init Phase:
|
||||
|
||||
1. Init dataloader.
|
||||
2. Init model worker.
|
||||
3. Init optimizer (deepspeed).
|
||||
4. Shard model.
|
||||
5. Init optimizer (fsdp).
|
||||
6. Init scheduler.
|
||||
|
||||
Train Phase:
|
||||
1. Train Loop
|
||||
|
||||
"""
|
||||
|
||||
from ..config.training_args import TrainingArguments
|
||||
from ..extras.types import Model, Processor, Tensor, TorchDataset
|
||||
|
||||
|
||||
class DataCollator:
|
||||
"""Default Data collator."""
|
||||
|
||||
def __init__(self, processor: Processor) -> None:
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
|
||||
"""Collate features into a batch."""
|
||||
for feature in features:
|
||||
pass
|
||||
|
||||
# sft: messages
|
||||
# dpo: chosen_messages, rejected_messages
|
||||
from ..extras.types import TorchDataset
|
||||
from .model_worker import ModelWorker
|
||||
from .trainer_utils.data_collator import DataCollator
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
model: Model,
|
||||
processor: Processor,
|
||||
dataset: TorchDataset,
|
||||
data_collator: DataCollator,
|
||||
model_worker: ModelWorker,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.dataset = dataset
|
||||
self.data_collator = data_collator
|
||||
self.model_worker = model_worker
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
|
||||
def init_device_mesh(self) -> None:
|
||||
pass
|
||||
|
||||
def init_model_and_optimizer(self) -> None:
|
||||
self.model_config = self.model_worker.get_model_config()
|
||||
# with self.dist_plugin.get_model_init_context():
|
||||
# self.model = self.model_worker.get_model(self.model_config)
|
||||
|
||||
def create_dataloader(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@@ -12,9 +12,35 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..config.sample_args import SampleArguments
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..config.sample_args import SampleArguments, SampleBackend
|
||||
from .model_worker import ModelWorker
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def batch_infer(self):
|
||||
pass
|
||||
|
||||
|
||||
class HuggingFaceEngine(BaseEngine):
|
||||
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
|
||||
self.model = model_worker.get_model()
|
||||
self.processor = model_worker.get_processor()
|
||||
self.args = sample_args
|
||||
|
||||
|
||||
class ChatSampler:
|
||||
def __init__(self, sample_args: SampleArguments) -> None:
|
||||
self.args = sample_args
|
||||
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
|
||||
if sample_args.sample_backend == SampleBackend.HF:
|
||||
self.engine = HuggingFaceEngine(model_worker, sample_args)
|
||||
else:
|
||||
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
|
||||
|
||||
@@ -12,11 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of data engine.
|
||||
|
||||
Init Data engine:
|
||||
1. Parse dataset info from arguments.
|
||||
2. Load datasets according to dataset info.
|
||||
3. Build data index (and reweight samples if necessary).
|
||||
|
||||
Get Data Sample:
|
||||
1. Get sample from data index.
|
||||
2. Convert sample to standard format.
|
||||
3. Return sample.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncIterable, Iterable
|
||||
from typing import Any, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data import Dataset
|
||||
@@ -45,15 +57,13 @@ class DataEngine(Dataset):
|
||||
|
||||
def get_dataset_info(self) -> None:
|
||||
"""Get dataset info from data arguments."""
|
||||
if self.args.dataset.endswith(".yaml") and os.path.isfile(
|
||||
os.path.join(self.args.dataset_dir, self.args.dataset)
|
||||
): # local file
|
||||
self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
|
||||
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
|
||||
self.dataset_infos = OmegaConf.load(self.args.dataset)
|
||||
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
|
||||
repo_id, filename = os.path.split(self.args.dataset)
|
||||
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
|
||||
self.dataset_infos = OmegaConf.load(filepath)
|
||||
elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)): # local file(s)
|
||||
elif os.path.exists(self.args.dataset): # local file(s)
|
||||
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
|
||||
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
|
||||
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
|
||||
@@ -65,11 +75,13 @@ class DataEngine(Dataset):
|
||||
streaming = value.get("streaming", False)
|
||||
self.streaming |= streaming
|
||||
if "hf_hub_url" in value:
|
||||
from datasets import load_dataset
|
||||
|
||||
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
|
||||
else: # data loader plugin
|
||||
from ..plugins.data_plugins.loader import DataLoaderPlugin
|
||||
|
||||
self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
|
||||
self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
|
||||
|
||||
def build_data_index(self) -> None:
|
||||
"""Build dataset index."""
|
||||
@@ -145,11 +157,11 @@ class DataEngine(Dataset):
|
||||
dataset_name, sample_index = selected_index
|
||||
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
|
||||
def __iter__(self) -> Iterable:
|
||||
def __iter__(self) -> Iterable[Sample]:
|
||||
"""Get dataset iterator.
|
||||
|
||||
Returns:
|
||||
Iterable: Dataset iterator.
|
||||
Iterable[Sample]: Dataset iterator.
|
||||
"""
|
||||
if self.streaming:
|
||||
pass
|
||||
@@ -159,11 +171,11 @@ class DataEngine(Dataset):
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
async def __aiter__(self) -> AsyncIterable:
|
||||
async def __aiter__(self) -> AsyncIterable[Sample]:
|
||||
"""Get dataset async iterator.
|
||||
|
||||
Returns:
|
||||
AsyncIterable: Dataset async iterator.
|
||||
AsyncIterable[Sample]: Dataset async iterator.
|
||||
"""
|
||||
if self.streaming:
|
||||
pass
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..config.model_args import ModelArguments
|
||||
from ..extras.types import Model, Processor
|
||||
|
||||
|
||||
class ModelEngine:
|
||||
def __init__(self, model_args: ModelArguments) -> None:
|
||||
self.args = model_args
|
||||
|
||||
def get_model(self) -> Model:
|
||||
pass
|
||||
|
||||
def get_processor(self) -> Processor:
|
||||
pass
|
||||
98
src/llamafactory/v1/core/model_worker.py
Normal file
98
src/llamafactory/v1/core/model_worker.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of model worker.
|
||||
|
||||
Init Phase:
|
||||
1. Init processor.
|
||||
2. Init model config.
|
||||
3. Init model.
|
||||
4. Init adapter.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
|
||||
from ..config.model_args import ModelArguments
|
||||
from ..extras.types import DistModel, HFConfig, HFModel, Processor
|
||||
|
||||
|
||||
class ModelWorker:
|
||||
def __init__(self, model_args: ModelArguments) -> None:
|
||||
self.args = model_args
|
||||
"""Model arguments."""
|
||||
self.processor: Optional[Processor] = None
|
||||
"""Tokenizer or multi-modal processor."""
|
||||
self.model_config: Optional[HFConfig] = None
|
||||
"""Model configuration."""
|
||||
self.unwrapped_model: Optional[HFModel] = None
|
||||
"""Unwrapped model."""
|
||||
self.model: Optional[DistModel] = None
|
||||
"""Distributed model."""
|
||||
self.init_processor()
|
||||
self.init_model_config()
|
||||
self.init_model()
|
||||
self.init_adapter()
|
||||
|
||||
def init_processor(self) -> None:
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
use_fast=self.args.use_fast_processor,
|
||||
)
|
||||
|
||||
def init_model_config(self) -> None:
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
def init_model(self) -> None:
|
||||
if self.args.auto_model_class == "causallm":
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
|
||||
AutoClass = AutoModelForImageTextToText
|
||||
else:
|
||||
AutoClass = AutoModelForCausalLM
|
||||
elif self.args.auto_model_class == "classification":
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
AutoClass = AutoModelForTokenClassification
|
||||
else:
|
||||
from transformers import AutoModel
|
||||
|
||||
AutoClass = AutoModel
|
||||
|
||||
self.unwrapped_model = AutoClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map="cpu",
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
def init_adapter(self) -> None:
|
||||
pass
|
||||
|
||||
def get_processor(self) -> Processor:
|
||||
return self.processor
|
||||
|
||||
def get_model_config(self) -> HFConfig:
|
||||
return self.model_config
|
||||
|
||||
def get_model(self) -> HFModel:
|
||||
return self.unwrapped_model
|
||||
0
src/llamafactory/v1/core/trainer_utils/__init__.py
Normal file
0
src/llamafactory/v1/core/trainer_utils/__init__.py
Normal file
0
src/llamafactory/v1/core/trainer_utils/callback.py
Normal file
0
src/llamafactory/v1/core/trainer_utils/callback.py
Normal file
47
src/llamafactory/v1/core/trainer_utils/data_collator.py
Normal file
47
src/llamafactory/v1/core/trainer_utils/data_collator.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ...extras.types import Processor, Tensor, TorchDataset
|
||||
|
||||
|
||||
class DataCollator:
|
||||
"""Default Data collator."""
|
||||
|
||||
def __init__(self, processor: Processor) -> None:
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
|
||||
"""Collate features into a batch."""
|
||||
for feature in features:
|
||||
pass
|
||||
|
||||
# sft: messages
|
||||
# dpo: chosen_messages, rejected_messages
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""Default DataLoader."""
|
||||
|
||||
def __init__(self, dataset: TorchDataset) -> None:
|
||||
self.dataset = dataset
|
||||
# 1. Init stateful dataloader (tokenize)
|
||||
# 2. Add to buffer (2 * max seq len per device)
|
||||
# 3. Yield batch indexes (micro batch * grad acc)
|
||||
# a ) non pack + non dynamic
|
||||
# b ) non pack + dynamic
|
||||
# c ) pack + non dynamic
|
||||
# d ) pack + dynamic
|
||||
Reference in New Issue
Block a user