Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-30 02:30:35 +08:00)

Commit: [v1] add accelerator (#9607)
@@ -29,7 +29,7 @@ Train Phase:
 """

 from ..config.training_args import TrainingArguments
-from ..extras.types import TorchDataset
+from ..utils.types import TorchDataset
 from .model_worker import ModelWorker
 from .trainer_utils.data_collator import DataCollator

@@ -49,13 +49,10 @@ class BaseTrainer:
         self.optimizer = None
         self.lr_scheduler = None

-    def init_device_mesh(self) -> None:
-        pass
-
     def init_model_and_optimizer(self) -> None:
-        self.model_config = self.model_worker.get_model_config()
+        self.model_worker.init_model_config()
         # with self.dist_plugin.get_model_init_context():
-        #     self.model = self.model_worker.get_model(self.model_config)
+        #     self.model = self.model_worker.init_model(self.model_config)

     def create_dataloader(self) -> None:
         pass
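Note: the commented-out lines above sketch where this refactor is heading: the worker will build the model inside a context supplied by a distributed plugin. A minimal sketch of what such a hook could return, tying this hunk to the meta-device branch added in ModelWorker.init_model() further down (the method body here is an assumption, not code from this commit):

    import torch

    class DistPlugin:
        # hypothetical plugin hook named in the comment above
        def get_model_init_context(self):
            # torch.device works as a context manager; entering the meta
            # device makes ModelWorker.init_model() take its from_config() branch
            return torch.device("meta")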
@@ -34,7 +34,7 @@ from omegaconf import OmegaConf
 from torch.utils.data import Dataset

 from ..config.data_args import DataArguments
-from ..extras.types import DatasetInfo, HFDataset, Sample
+from ..utils.types import DatasetInfo, HFDataset, Sample


 class DataEngine(Dataset):
@@ -64,9 +64,9 @@ class DataEngine(Dataset):
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_infos = OmegaConf.load(filepath)
         elif os.path.exists(self.args.dataset):  # local file(s)
-            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
+            self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
-            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
+            self.dataset_infos = {"default": {"path": self.args.dataset}}

     def load_dataset(self) -> None:
         """Load datasets according to dataset info."""
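The fallback entries now share a uniform schema: a path plus an optional source discriminator, replacing the format-specific file_name and hf_hub_url keys. A hedged sketch of a dataset-info mapping this logic would accept (keys beyond path, source, split, streaming, and converter do not appear in this diff):

    # hypothetical dataset_infos mapping, mirroring the inline fallbacks above
    dataset_infos = {
        "sft_demo": {
            "path": "llamafactory/v1-sft-demo",  # no "source" key -> defaults to "hf_hub"
            "split": "train",
            "streaming": False,
        },
        "local_file": {
            "path": "data/train.json",           # resolved via os.path.exists
            "source": "local",                   # routed to the data loader plugin
            "converter": "alpaca",               # assumed converter name, see below
        },
    }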
@@ -74,14 +74,14 @@ class DataEngine(Dataset):
             split = value.get("split", "train")
             streaming = value.get("streaming", False)
             self.streaming |= streaming
-            if "hf_hub_url" in value:
+            if value.get("source", "hf_hub") == "hf_hub":
                 from datasets import load_dataset

-                self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
+                self.datasets[key] = load_dataset(value["path"], split=split, streaming=streaming)
             else:  # data loader plugin
                 from ..plugins.data_plugins.loader import DataLoaderPlugin

-                self.datasets[key] = DataLoaderPlugin().auto_load_data(value)
+                self.datasets[key] = DataLoaderPlugin(value["source"]).load(value)

     def build_data_index(self) -> None:
         """Build dataset index."""
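DataLoaderPlugin is now constructed with the dataset's source name and exposes a generic load(), rather than auto-detecting the format inside auto_load_data(). Only the call shape comes from the diff; a minimal sketch of a registry that would satisfy it:

    from datasets import load_dataset

    # hypothetical registry-backed plugin matching DataLoaderPlugin(value["source"]).load(value)
    class DataLoaderPlugin:
        _loaders: dict = {}

        def __init__(self, source: str):
            self._load = self._loaders[source]

        @classmethod
        def register(cls, source: str):
            def decorator(fn):
                cls._loaders[source] = fn
                return fn
            return decorator

        def load(self, info: dict):
            return self._load(info)

    @DataLoaderPlugin.register("local")
    def load_local(info: dict):
        # infer the datasets builder ("json", "csv", ...) from the file extension
        builder = info["path"].rsplit(".", 1)[-1]
        return load_dataset(builder, data_files=info["path"], split=info.get("split", "train"))

Keying the constructor on source keeps load() format-agnostic, so a new backend only needs a registered loader function.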
@@ -112,9 +112,9 @@ class DataEngine(Dataset):
         """
         converter = self.dataset_infos[dataset_name].get("converter")
         if converter is not None:
-            from ..plugins.data_plugins.converter import get_converter
+            from ..plugins.data_plugins.converter import DataConverterPlugin

-            return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
+            return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
         else:
             return {"_dataset_name": dataset_name, **raw_sample}

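get_converter(converter) becomes DataConverterPlugin(converter): the converter is resolved by name at construction and the plugin instance is itself the callable. A hedged sketch (the alpaca-style mapping is an illustrative assumption; only the DataConverterPlugin(converter)(raw_sample) shape is in the diff):

    # hypothetical named-converter plugin
    class DataConverterPlugin:
        _converters: dict = {}

        def __init__(self, name: str):
            self._convert = self._converters[name]

        def __call__(self, raw_sample: dict) -> dict:
            return self._convert(raw_sample)

    def alpaca_converter(raw: dict) -> dict:
        # map instruction/input/output fields to a chat-style sample
        prompt = raw["instruction"] + ("\n" + raw["input"] if raw.get("input") else "")
        return {"messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": raw["output"]},
        ]}

    DataConverterPlugin._converters["alpaca"] = alpaca_converter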
@@ -147,7 +147,7 @@ class DataEngine(Dataset):
         else:
             from ..plugins.data_plugins.loader import DataSelectorPlugin

-            selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
+            selected_index = DataSelectorPlugin().select(self.data_index, index)
             if isinstance(selected_index, list):
                 return [
                     self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
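The selector loses its constructor state: data_index is passed to select() alongside the requested index, so one plugin instance can serve any engine. A minimal sketch consistent with the isinstance(selected_index, list) check that follows (the slice handling is an assumption):

    # hypothetical stateless selector matching .select(self.data_index, index)
    class DataSelectorPlugin:
        def select(self, data_index, index):
            if isinstance(index, slice):
                return list(data_index[index])  # list of (dataset_name, sample_index) entries
            return data_index[index]            # a single entry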
@@ -187,7 +187,7 @@ class DataEngine(Dataset):


 if __name__ == "__main__":
-    from ..config.parser import get_args
+    from ..config.arg_parser import get_args

     data_args, *_ = get_args()
     data_engine = DataEngine(data_args=data_args)
@@ -24,10 +24,12 @@ Init Phase:

 from typing import Optional

+import torch
 from transformers import AutoConfig, AutoProcessor

-from ..config.model_args import ModelArguments
-from ..extras.types import DistModel, HFConfig, HFModel, Processor
+from ..accelerator.helper import DeviceType
+from ..config.model_args import AutoClass, ModelArguments
+from ..utils.types import HFConfig, HFModel, Processor


 class ModelWorker:
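The auto_model_class string comparisons give way to an AutoClass enum imported from config.model_args. Only the CAUSALLM and CLASSIFICATION members are visible in this diff; a sketch of what the enum presumably looks like (member values are assumptions carried over from the old string literals):

    from enum import Enum

    # hypothetical definition in config.model_args
    class AutoClass(str, Enum):
        CAUSALLM = "causallm"
        CLASSIFICATION = "classification"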
@@ -38,16 +40,15 @@ class ModelWorker:
         """Tokenizer or multi-modal processor."""
         self.model_config: Optional[HFConfig] = None
         """Model configuration."""
-        self.unwrapped_model: Optional[HFModel] = None
-        """Unwrapped model."""
-        self.model: Optional[DistModel] = None
-        """Distributed model."""
-        self.init_processor()
-        self.init_model_config()
-        self.init_model()
-        self.init_adapter()
+        self.model: Optional[HFModel] = None
+        """HF model."""
+        self.is_adapter = False
+        """Whether the model has adapter."""

     def init_processor(self) -> None:
+        if self.processor is not None:
+            return
+
         self.processor = AutoProcessor.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
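__init__ no longer eagerly runs the whole init chain; each resource is built on first use, and every init_* method guards against re-entry, so callers can invoke them in any order and repeatedly. A hedged usage sketch (only the method names come from this diff):

    worker = ModelWorker(args=model_args)
    worker.init_processor()   # builds the tokenizer/processor once
    worker.init_model()       # runs init_model_config() itself, then loads the model
    worker.init_model()       # no-op: returns early because self.model is set
    worker.init_adapter()     # applies PEFT if configured, then flips is_adapter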
@@ -55,38 +56,58 @@ class ModelWorker:
         )

     def init_model_config(self) -> None:
+        if self.model_config is not None:
+            return
+
         self.model_config = AutoConfig.from_pretrained(
             self.args.model,
             trust_remote_code=self.args.trust_remote_code,
         )

     def init_model(self) -> None:
-        if self.args.auto_model_class == "causallm":
+        if self.model is not None:
+            return
+
+        self.init_model_config()
+
+        if self.args.auto_class == AutoClass.CAUSALLM:
             from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

             if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
-                AutoClass = AutoModelForImageTextToText
+                ModelClass = AutoModelForImageTextToText
             else:
-                AutoClass = AutoModelForCausalLM
-        elif self.args.auto_model_class == "classification":
+                ModelClass = AutoModelForCausalLM
+        elif self.args.auto_class == AutoClass.CLASSIFICATION:
             from transformers import AutoModelForTokenClassification

-            AutoClass = AutoModelForTokenClassification
+            ModelClass = AutoModelForTokenClassification
         else:
             from transformers import AutoModel

-            AutoClass = AutoModel
+            ModelClass = AutoModel

-        self.unwrapped_model = AutoClass.from_pretrained(
-            self.args.model,
-            config=self.model_config,
-            dtype="auto",
-            device_map="cpu",
-            trust_remote_code=self.args.trust_remote_code,
-        )
+        default_device_type = torch.get_default_device().type
+        if default_device_type == DeviceType.META:
+            self.model = ModelClass.from_config(self.model_config)
+        else:
+            self.model = ModelClass.from_pretrained(
+                self.args.model,
+                config=self.model_config,
+                dtype="auto",
+                device_map=default_device_type,
+                trust_remote_code=self.args.trust_remote_code,
+            )

     def init_adapter(self) -> None:
-        pass
+        if self.is_adapter:
+            return
+
+        if self.args.peft_config is not None:
+            from ..plugins.model_plugins.peft import PeftPlugin
+
+            self.model = PeftPlugin(self.args.peft_config.name)(self.model, self.args.peft_config)
+
+        self.is_adapter = True

     def get_processor(self) -> Processor:
         return self.processor
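This is the accelerator-aware core of the commit: instead of always materializing weights on CPU, init_model() asks torch for the current default device. Under a meta device the model is built as a weightless shell from its config, presumably so a distributed plugin can shard and load weights later; otherwise from_pretrained() targets the default device directly. A standalone sketch of the same meta-device technique with plain transformers (model id and materialization step are illustrative, not from this commit):

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any HF model id

    # build a weightless shell: meta tensors allocate no storage, so even
    # very large models can be constructed instantly on every rank
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

    # later, allocate real (uninitialized) storage on the target device;
    # a checkpoint loader or distributed plugin then fills in the weights
    model = model.to_empty(device="cuda" if torch.cuda.is_available() else "cpu")

init_adapter() now routes through a PeftPlugin selected by peft_config.name. The plugin body is not shown here; for a "lora" entry it would presumably delegate to the peft library's standard call (hyperparameters below are placeholders):

    from peft import LoraConfig, get_peft_model

    lora = LoraConfig(r=8, lora_alpha=16, target_modules="all-linear", task_type="CAUSAL_LM")
    model = get_peft_model(model, lora)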
@@ -95,4 +116,4 @@ class ModelWorker:
         return self.model_config

     def get_model(self) -> HFModel:
-        return self.unwrapped_model
+        return self.model
@@ -15,7 +15,7 @@

 from typing import Any

-from ...extras.types import Processor, Tensor, TorchDataset
+from ...utils.types import Processor, Tensor, TorchDataset


 class DataCollator:
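Only the types import changes here; the collator body stays as it was. For context, a minimal sketch of the kind of padding collator such a class typically implements (this body is illustrative, not the repository's implementation):

    from typing import Any

    import torch

    class PaddingCollator:
        # hypothetical stand-in for DataCollator's role
        def __init__(self, pad_token_id: int):
            self.pad_token_id = pad_token_id

        def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
            # right-pad input_ids to the longest sample and build the attention mask
            max_len = max(len(f["input_ids"]) for f in features)
            input_ids, attention_mask = [], []
            for f in features:
                pad = max_len - len(f["input_ids"])
                input_ids.append(f["input_ids"] + [self.pad_token_id] * pad)
                attention_mask.append([1] * len(f["input_ids"]) + [0] * pad)
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            }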