[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -29,7 +29,7 @@ Train Phase:
"""
from ..config.training_args import TrainingArguments
from ..extras.types import TorchDataset
from ..utils.types import TorchDataset
from .model_worker import ModelWorker
from .trainer_utils.data_collator import DataCollator
@@ -49,13 +49,10 @@ class BaseTrainer:
self.optimizer = None
self.lr_scheduler = None
def init_device_mesh(self) -> None:
pass
def init_model_and_optimizer(self) -> None:
self.model_config = self.model_worker.get_model_config()
self.model_worker.init_model_config()
# with self.dist_plugin.get_model_init_context():
# self.model = self.model_worker.get_model(self.model_config)
# self.model = self.model_worker.init_model(self.model_config)
def create_dataloader(self) -> None:
pass

View File

@@ -34,7 +34,7 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..extras.types import DatasetInfo, HFDataset, Sample
from ..utils.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
@@ -64,9 +64,9 @@ class DataEngine(Dataset):
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(self.args.dataset): # local file(s)
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
self.dataset_infos = {"default": {"path": self.args.dataset}}
def load_dataset(self) -> None:
"""Load datasets according to dataset info."""
@@ -74,14 +74,14 @@ class DataEngine(Dataset):
split = value.get("split", "train")
streaming = value.get("streaming", False)
self.streaming |= streaming
if "hf_hub_url" in value:
if value.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)
def build_data_index(self) -> None:
"""Build dataset index."""
@@ -112,9 +112,9 @@ class DataEngine(Dataset):
"""
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
from ..plugins.data_plugins.converter import get_converter
from ..plugins.data_plugins.converter import DataConverterPlugin
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
else:
return {"_dataset_name": dataset_name, **raw_sample}
@@ -147,7 +147,7 @@ class DataEngine(Dataset):
else:
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
selected_index = DataSelectorPlugin().select(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
@@ -187,7 +187,7 @@ class DataEngine(Dataset):
if __name__ == "__main__":
from ..config.parser import get_args
from ..config.arg_parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)

View File

@@ -24,10 +24,12 @@ Init Phase:
from typing import Optional
import torch
from transformers import AutoConfig, AutoProcessor
from ..config.model_args import ModelArguments
from ..extras.types import DistModel, HFConfig, HFModel, Processor
from ..accelerator.helper import DeviceType
from ..config.model_args import AutoClass, ModelArguments
from ..utils.types import HFConfig, HFModel, Processor
class ModelWorker:
@@ -38,16 +40,15 @@ class ModelWorker:
"""Tokenizer or multi-modal processor."""
self.model_config: Optional[HFConfig] = None
"""Model configuration."""
self.unwrapped_model: Optional[HFModel] = None
"""Unwrapped model."""
self.model: Optional[DistModel] = None
"""Distributed model."""
self.init_processor()
self.init_model_config()
self.init_model()
self.init_adapter()
self.model: Optional[HFModel] = None
"""HF model."""
self.is_adapter = False
"""Whether the model has adapter."""
def init_processor(self) -> None:
if self.processor is not None:
return
self.processor = AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
@@ -55,38 +56,58 @@ class ModelWorker:
)
def init_model_config(self) -> None:
if self.model_config is not None:
return
self.model_config = AutoConfig.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
)
def init_model(self) -> None:
if self.args.auto_model_class == "causallm":
if self.model is not None:
return
self.init_model_config()
if self.args.auto_class == AutoClass.CAUSALLM:
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
AutoClass = AutoModelForImageTextToText
ModelClass = AutoModelForImageTextToText
else:
AutoClass = AutoModelForCausalLM
elif self.args.auto_model_class == "classification":
ModelClass = AutoModelForCausalLM
elif self.args.auto_class == AutoClass.CLASSIFICATION:
from transformers import AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification
ModelClass = AutoModelForTokenClassification
else:
from transformers import AutoModel
AutoClass = AutoModel
ModelClass = AutoModel
self.unwrapped_model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map="cpu",
trust_remote_code=self.args.trust_remote_code,
)
default_device_type = torch.get_default_device().type
if default_device_type == DeviceType.META:
self.model = ModelClass.from_config(self.model_config)
else:
self.model = ModelClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=default_device_type,
trust_remote_code=self.args.trust_remote_code,
)
def init_adapter(self) -> None:
pass
if self.is_adapter:
return
if self.args.peft_config is not None:
from ..plugins.model_plugins.peft import PeftPlugin
self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
self.is_adapter = True
def get_processor(self) -> Processor:
return self.processor
@@ -95,4 +116,4 @@ class ModelWorker:
return self.model_config
def get_model(self) -> HFModel:
return self.unwrapped_model
return self.model

View File

@@ -15,7 +15,7 @@
from typing import Any
from ...extras.types import Processor, Tensor, TorchDataset
from ...utils.types import Processor, Tensor, TorchDataset
class DataCollator: