[v1] add models & accelerator (#9579)

This commit is contained in:
Yaowei Zheng
2025-12-08 02:30:25 +08:00
committed by GitHub
parent 739954910a
commit 5744f1ea94
27 changed files with 335 additions and 105 deletions

View File

@@ -12,44 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
"""The definition of trainer.
Init Phase:
1. Init dataloader.
2. Init model worker.
3. Init optimizer (deepspeed).
4. Shard model.
5. Init optimizer (fsdp).
6. Init scheduler.
Train Phase:
1. Train Loop
"""
from ..config.training_args import TrainingArguments
from ..extras.types import Model, Processor, Tensor, TorchDataset
class DataCollator:
"""Default Data collator."""
def __init__(self, processor: Processor) -> None:
self.processor = processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
for feature in features:
pass
# sft: messages
# dpo: chosen_messages, rejected_messages
from ..extras.types import TorchDataset
from .model_worker import ModelWorker
from .trainer_utils.data_collator import DataCollator
class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: Model,
processor: Processor,
dataset: TorchDataset,
data_collator: DataCollator,
model_worker: ModelWorker,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.dataset = dataset
self.data_collator = data_collator
self.model_worker = model_worker
self.optimizer = None
self.lr_scheduler = None
def init_device_mesh(self) -> None:
pass
def init_model_and_optimizer(self) -> None:
self.model_config = self.model_worker.get_model_config()
# with self.dist_plugin.get_model_init_context():
# self.model = self.model_worker.get_model(self.model_config)
def create_dataloader(self) -> None:
pass

View File

@@ -12,9 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.sample_args import SampleArguments
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_worker import ModelWorker
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_worker: ModelWorker) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
self.model = model_worker.get_model()
self.processor = model_worker.get_processor()
self.args = sample_args
class ChatSampler:
def __init__(self, sample_args: SampleArguments) -> None:
self.args = sample_args
def __init__(self, model_worker: ModelWorker, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_worker, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -12,11 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of data engine.
Init Data engine:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
"""
import os
from collections.abc import AsyncIterable, Iterable
from typing import Any, Union
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset
@@ -45,15 +57,13 @@ class DataEngine(Dataset):
def get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(
os.path.join(self.args.dataset_dir, self.args.dataset)
): # local file
self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset)
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)): # local file(s)
elif os.path.exists(self.args.dataset): # local file(s)
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
@@ -65,11 +75,13 @@ class DataEngine(Dataset):
streaming = value.get("streaming", False)
self.streaming |= streaming
if "hf_hub_url" in value:
from datasets import load_dataset
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
def build_data_index(self) -> None:
"""Build dataset index."""
@@ -145,11 +157,11 @@ class DataEngine(Dataset):
dataset_name, sample_index = selected_index
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
def __iter__(self) -> Iterable:
def __iter__(self) -> Iterable[Sample]:
"""Get dataset iterator.
Returns:
Iterable: Dataset iterator.
Iterable[Sample]: Dataset iterator.
"""
if self.streaming:
pass
@@ -159,11 +171,11 @@ class DataEngine(Dataset):
raise NotImplementedError()
async def __aiter__(self) -> AsyncIterable:
async def __aiter__(self) -> AsyncIterable[Sample]:
"""Get dataset async iterator.
Returns:
AsyncIterable: Dataset async iterator.
AsyncIterable[Sample]: Dataset async iterator.
"""
if self.streaming:
pass

View File

@@ -1,27 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..config.model_args import ModelArguments
from ..extras.types import Model, Processor
class ModelEngine:
def __init__(self, model_args: ModelArguments) -> None:
self.args = model_args
def get_model(self) -> Model:
pass
def get_processor(self) -> Processor:
pass

View File

@@ -0,0 +1,98 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model worker.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""
from typing import Optional
from transformers import AutoConfig, AutoProcessor
from ..config.model_args import ModelArguments
from ..extras.types import DistModel, HFConfig, HFModel, Processor
class ModelWorker:
def __init__(self, model_args: ModelArguments) -> None:
self.args = model_args
"""Model arguments."""
self.processor: Optional[Processor] = None
"""Tokenizer or multi-modal processor."""
self.model_config: Optional[HFConfig] = None
"""Model configuration."""
self.unwrapped_model: Optional[HFModel] = None
"""Unwrapped model."""
self.model: Optional[DistModel] = None
"""Distributed model."""
self.init_processor()
self.init_model_config()
self.init_model()
self.init_adapter()
def init_processor(self) -> None:
self.processor = AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def init_model_config(self) -> None:
self.model_config = AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def init_model(self) -> None:
if self.args.auto_model_class == "causallm":
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.auto_model_class == "classification":
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
self.unwrapped_model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map="cpu",
trust_remote_code=self.args.trust_remote_code,
)
def init_adapter(self) -> None:
pass
def get_processor(self) -> Processor:
return self.processor
def get_model_config(self) -> HFConfig:
return self.model_config
def get_model(self) -> HFModel:
return self.unwrapped_model

View File

@@ -0,0 +1,47 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from ...extras.types import Processor, Tensor, TorchDataset
class DataCollator:
"""Default Data collator."""
def __init__(self, processor: Processor) -> None:
self.processor = processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
for feature in features:
pass
# sft: messages
# dpo: chosen_messages, rejected_messages
class DataLoader:
"""Default DataLoader."""
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic