mirror of https://github.com/hiyouga/LLaMA-Factory.git
[v1] add data converter (#9263)
@@ -12,8 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
+
 from ..config.training_args import TrainingArguments
-from ..extras.types import DataCollator, Model, Processor, TorchDataset
+from ..extras.types import Model, Processor, Tensor, TorchDataset
+
+
+class DataCollator:
+    """Default Data collator."""
+
+    def __init__(self, processor: Processor) -> None:
+        self.processor = processor
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
+        """Collate features into a batch."""
+        for feature in features:
+            pass
+
+        # sft: messages
+        # dpo: chosen_messages, rejected_messages
 
 
 class BaseTrainer:
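The `DataCollator` added here is still a stub: `__call__` iterates over the features without doing anything yet, and the trailing comments only record which keys each training mode will supply (`messages` for SFT; `chosen_messages` and `rejected_messages` for DPO). For orientation, a minimal sketch of a collator in this shape follows; the `pad_token_id` argument and the `input_ids`/`labels` feature keys are illustrative assumptions, not part of this commit.

from typing import Any

import torch


class PaddingCollatorSketch:
    """Hypothetical collator: right-pads each feature to the longest sequence in the batch."""

    def __init__(self, pad_token_id: int) -> None:
        self.pad_token_id = pad_token_id  # assumed to come from the processor's tokenizer

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        max_len = max(len(feature["input_ids"]) for feature in features)
        input_ids, labels = [], []
        for feature in features:
            padding = max_len - len(feature["input_ids"])
            input_ids.append(feature["input_ids"] + [self.pad_token_id] * padding)
            labels.append(feature["labels"] + [-100] * padding)  # -100 is ignored by PyTorch cross-entropy
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }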
@@ -22,19 +22,7 @@ from omegaconf import OmegaConf
 from torch.utils.data import Dataset
 
 from ..config.data_args import DataArguments
-from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor
-from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin
-
-
-class DataCollator:
-    """Default Data collator."""
-
-    def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None:
-        self.processor = processor
-        self.dataset_info = dataset_info
-
-    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
-        pass
+from ..extras.types import DatasetInfo, HFDataset, Sample
 
 
 class DataEngine(Dataset):
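With the collator moved out of this module, the import switches from collator types to `Sample`. Judging from `_convert_data_sample` below, a `Sample` is a plain dict tagged with its source dataset; assuming the `messages` layout noted in the collator stub, a converted SFT sample might look like:

sample = {
    "_dataset_name": "default",  # injected by DataEngine._convert_data_sample
    "messages": [  # assumed layout, per the "sft: messages" comment in the collator stub
        {"role": "user", "content": "What is LLaMA-Factory?"},
        {"role": "assistant", "content": "An easy-to-use framework for fine-tuning LLMs."},
    ],
}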
@@ -45,73 +33,78 @@ class DataEngine(Dataset):
         """Data arguments."""
         self.datasets: dict[str, HFDataset] = {}
         """Dict of (dataset_name, dataset)"""
-        self.dataset_info: dict[str, DatasetInfo] = {}
+        self.dataset_infos: dict[str, DatasetInfo] = {}
         """Dict of (dataset_name, dataset_info)"""
+        self.streaming: bool = False
+        """Whether dataset is streaming."""
         self.data_index: list[tuple[str, int]] = []
         """List of (dataset_name, sample_index)"""
-        self.data_loader_plugin = DataLoaderPlugin(args=self.args)
-        """Data loader plugin."""
-        self.data_index_plugin = DataIndexPlugin()
-        """Data index plugin."""
-        self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index)
-        """Data getitem plugin."""
-        self.streaming: bool = False
-        """Whether dataset is streaming."""
         self.get_dataset_info()
         self.load_dataset()
         self.build_data_index()
 
     def get_dataset_info(self) -> None:
-        """Get dataset info."""
+        """Get dataset info from data arguments."""
         if self.args.dataset.endswith(".yaml") and os.path.isfile(
             os.path.join(self.args.dataset_dir, self.args.dataset)
         ):  # local file
-            self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
+            self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
         elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
             repo_id, filename = os.path.split(self.args.dataset)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
-            self.dataset_info = OmegaConf.load(filepath)
+            self.dataset_infos = OmegaConf.load(filepath)
         elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
-            self.dataset_info = {"default": {"file_name": self.args.dataset}}
+            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
-            self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}}
+            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
 
     def load_dataset(self) -> None:
-        """Load dataset from dataset info."""
-        for key, value in self.dataset_info.items():
+        """Load datasets according to dataset info."""
+        for key, value in self.dataset_infos.items():
             split = value.get("split", "train")
             streaming = value.get("streaming", False)
             self.streaming |= streaming
             if "hf_hub_url" in value:
                 self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
             else:  # data loader plugin
-                self.datasets[key] = self.data_loader_plugin.auto_load_data(value)
+                from ..plugins.data_plugins.loader import DataLoaderPlugin
+
+                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
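Between them, `get_dataset_info`, `load_dataset`, and `build_data_index` read the per-dataset keys `hf_hub_url` or `file_name`, plus the optional `split`, `streaming`, `size`, `weight`, and the new `converter`. A `dataset_info.yaml` exercising these keys might look like the sketch below; the dataset names, values, and the `alpaca` converter name are illustrative assumptions, but the snippet parses with OmegaConf exactly as `get_dataset_info` does:

from omegaconf import OmegaConf

# Hypothetical dataset_info.yaml covering the keys DataEngine reads.
yaml_text = """
sft_demo:
  hf_hub_url: llamafactory/v1-sft-demo  # loaded via datasets.load_dataset
  split: train
  converter: alpaca                     # resolved through get_converter(...)
  weight: 2.0                           # consumed by DataIndexPlugin
local_demo:
  file_name: demo.jsonl                 # routed to DataLoaderPlugin.auto_load_data
  streaming: false
  size: 1000
"""
dataset_infos = OmegaConf.create(yaml_text)
print(dataset_infos["sft_demo"]["converter"])  # -> alpaca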
     def build_data_index(self) -> None:
         """Build dataset index."""
         for dataset_name, dataset in self.datasets.items():
-            size = self.dataset_info[dataset_name].get("size")
-            weight = self.dataset_info[dataset_name].get("weight")
+            size = self.dataset_infos[dataset_name].get("size")
+            weight = self.dataset_infos[dataset_name].get("weight")
             if self.streaming:
                 data_index = [(dataset_name, -1) for _ in range(1000)]
             else:
                 data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
 
             if size or weight:  # data index plugin
-                data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight)
+                from ..plugins.data_plugins.loader import DataIndexPlugin
+
+                data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
 
             self.data_index.extend(data_index)
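`DataIndexPlugin.adjust_data_index` lives in `..plugins.data_plugins.loader` and is not part of this diff; all the call site fixes is how the plugin is instantiated. One plausible reading, sketched purely as an assumption: `size` pins a dataset to an exact sample count by repeating and truncating its index, while `weight` scales its contribution proportionally.

import math
from typing import Optional


def adjust_data_index_sketch(
    data_index: list[tuple[str, int]],
    size: Optional[int] = None,
    weight: Optional[float] = None,
) -> list[tuple[str, int]]:
    """Hypothetical re-implementation; the real plugin may behave differently."""
    if not data_index:
        return data_index
    if size is not None:  # repeat, then truncate to exactly `size` entries
        repeats = math.ceil(size / len(data_index))
        data_index = (data_index * repeats)[:size]
    if weight is not None:  # rescale the dataset's share of the joint index
        target = max(1, int(len(data_index) * weight))
        repeats = math.ceil(target / len(data_index))
        data_index = (data_index * repeats)[:target]
    return data_index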
-    def get_data_collator(self, processor: Processor) -> DataCollator:
-        """Get data collator.
+    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
+        """Convert dataset sample.
 
         Args:
-            processor (Processor): Processor.
+            raw_sample (dict[str, Any]): Raw dataset sample.
+            dataset_name (str): Dataset name.
 
         Returns:
-            DataCollator: Data collator.
+            Sample: Dataset sample.
         """
-        return DataCollator(processor=processor, dataset_info=self.dataset_info)
+        converter = self.dataset_infos[dataset_name].get("converter")
+        if converter is not None:
+            from ..plugins.data_plugins.converter import get_converter
+
+            return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
+        else:
+            return {"_dataset_name": dataset_name, **raw_sample}
 
     def __len__(self) -> int:
         """Get dataset length.
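This replacement is the heart of the commit: instead of handing out a collator, `DataEngine` now normalizes every raw sample through an optional per-dataset converter and tags it with `_dataset_name`. The registry itself (`get_converter` in `..plugins.data_plugins.converter`) is outside this diff; the contract visible here is just dict in, dict out. A minimal registry with an Alpaca-style converter could look as follows, where the `alpaca` name and its field mapping are assumptions for illustration:

from typing import Any, Callable

Converter = Callable[[dict[str, Any]], dict[str, Any]]
CONVERTERS: dict[str, Converter] = {}  # name -> converter; the real registry may differ


def register_converter(name: str) -> Callable[[Converter], Converter]:
    def wrapper(func: Converter) -> Converter:
        CONVERTERS[name] = func
        return func

    return wrapper


def get_converter(name: str) -> Converter:
    if name not in CONVERTERS:
        raise ValueError(f"Unknown converter: {name}")
    return CONVERTERS[name]


@register_converter("alpaca")
def alpaca_converter(raw_sample: dict[str, Any]) -> dict[str, Any]:
    """Map Alpaca-style instruction/input/output fields to the `messages` layout."""
    prompt = raw_sample["instruction"]
    if raw_sample.get("input"):
        prompt += "\n" + raw_sample["input"]
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": raw_sample["output"]},
        ]
    }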
@@ -124,23 +117,33 @@ class DataEngine(Dataset):
         else:
             return len(self.data_index)
 
-    def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]:
+    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
         """Get dataset item.
 
         Args:
             index (int): Dataset index.
 
         Returns:
-            dict: Dataset item.
+            Sample: Dataset item.
         """
         if self.streaming:
             raise ValueError("Streaming dataset does not support index access.")
 
         if isinstance(index, int):
             dataset_name, sample_index = self.data_index[index]
-            return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
+            return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
         else:
-            return self.data_getitem_plugin.get_data(index)
+            from ..plugins.data_plugins.loader import DataSelectorPlugin
+
+            selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
+            if isinstance(selected_index, list):
+                return [
+                    self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
+                    for dataset_name, sample_index in selected_index
+                ]
+            else:
+                dataset_name, sample_index = selected_index
+                return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
 
     def __iter__(self) -> Iterable:
         """Get dataset iterator.
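`DataSelectorPlugin` replaces the removed `DataGetItemPlugin`: `select` receives any non-integer index together with the flat `data_index` and, as the branches above show, returns either a single `(dataset_name, sample_index)` tuple or a list of them. A sketch consistent with that contract, assuming slices and integer lists are the supported fancy indices:

from typing import Any, Union


class DataSelectorPluginSketch:
    """Hypothetical selector that resolves fancy indices against the flat data index."""

    def __init__(self, data_index: list[tuple[str, int]]) -> None:
        self.data_index = data_index

    def select(
        self, index: Union[slice, list[int], Any]
    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
        if isinstance(index, slice):
            return self.data_index[index]  # slicing a list already yields a list
        if isinstance(index, list):
            return [self.data_index[i] for i in index]
        return self.data_index[int(index)]  # scalar-like input, e.g. a numpy integer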
@@ -156,7 +159,7 @@ class DataEngine(Dataset):
 
         raise NotImplementedError()
 
-    def __aiter__(self) -> AsyncIterable:
+    async def __aiter__(self) -> AsyncIterable:
        """Get dataset async iterator.
 
         Returns:
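One caveat worth noting: an `async def __aiter__` returns a coroutine rather than an async iterator, so once implemented this cannot be consumed with a plain `async for sample in engine`; the coroutine has to be awaited first. A consumption sketch under that assumption:

async def consume(engine: "DataEngine") -> None:
    iterator = await engine.__aiter__()  # the async __aiter__ must be awaited explicitly
    while True:
        try:
            sample = await iterator.__anext__()  # drive the async iterator manually
        except StopAsyncIteration:
            break
        print(sample)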
@@ -169,3 +172,11 @@ class DataEngine(Dataset):
-        pass
+        raise NotImplementedError()
 
 
+if __name__ == "__main__":
+    from ..config.parser import get_args
+
+    data_args, *_ = get_args()
+    data_engine = DataEngine(data_args=data_args)
+    print(data_engine[0])