commit 92940817e7
parent ae09c6c214
Former-commit-id: 6a5074e46695378b76d58aac8ad7768b6b034b9c
Author: hoshi-hiyouga
Date: 2024-12-04 22:08:27 +08:00 (committed via GitHub)

src/llamafactory/data/loader.py

@@ -14,13 +14,11 @@
 import os
 import sys
-import zstandard as zstd
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
-from datasets import DatasetDict, Dataset, load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version
-from typing import Union
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import has_tokenized_data
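Note on the import cleanup above: `DatasetDict` must stay a runtime import because `get_dataset` calls `isinstance(tokenized_data, DatasetDict)`, while `Dataset` now appears only inside string annotations, which Python never evaluates at runtime. A minimal sketch of that pattern, assuming (as the `TYPE_CHECKING` import in this hunk suggests) a guarded block provides the name for static checkers; `build_module` is a hypothetical stand-in, not loader.py code:

    from typing import TYPE_CHECKING, Dict, Union

    if TYPE_CHECKING:
        # Seen only by type checkers; creates no runtime dependency.
        from datasets import Dataset, DatasetDict

    def build_module(data: Union["Dataset", "DatasetDict"]) -> Dict[str, "Dataset"]:
        # String annotations are not evaluated at runtime, so dropping
        # `Dataset` from the runtime `datasets` import is safe.
        return {"train_dataset": data}  # illustrative placeholder body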
@@ -54,7 +52,6 @@ def _load_single_dataset(
     Loads a single dataset and aligns it to the standard format.
     """
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
-
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
@@ -122,7 +119,6 @@ def _load_single_dataset(
             streaming=data_args.streaming,
         )
     else:
-        logger.info(f"Loading dataset {data_path}, {data_name}, {data_dir}, {data_files},{dataset_attr.split},{model_args.cache_dir},{model_args.hf_hub_token},{data_args.streaming}...")
         dataset = load_dataset(
             path=data_path,
             name=data_name,
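Note: the deleted debug log above duplicated information already carried by the surviving `load_dataset(...)` call. For reference, a minimal sketch of how the four routing variables map onto `datasets.load_dataset`, using hypothetical values for the local-file case (loader.py derives them from `dataset_attr` instead):

    from datasets import load_dataset

    # Local-file case: data_path becomes a builder name inferred from the
    # file extension, data_files lists the raw files; name/data_dir stay None.
    dataset = load_dataset(
        path="json",
        name=None,
        data_dir=None,
        data_files=["data/alpaca_en_demo.json"],  # hypothetical path
        split="train",
        cache_dir=None,
        streaming=False,
    )
    print(dataset)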
@@ -243,30 +239,24 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: Union[DatasetDict, Dataset] = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
+
             dataset_module: Dict[str, "Dataset"] = {}
-            if isinstance(dataset_dict, DatasetDict):
-                print(f"Loaded DatasetDict with keys: {dataset_dict.keys()}")
-                if "train" in dataset_dict:
-                    dataset_module["train_dataset"] = dataset_dict["train"]
-                if "validation" in dataset_dict:
-                    dataset_module["eval_dataset"] = dataset_dict["validation"]
-                if data_args.streaming:
-                    dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
-                logger.info_rank0(f"Finished tokenized dataset load from {data_args.tokenized_path}.")
-                return dataset_module
-            elif isinstance(dataset_dict, Dataset):
-                logger.info_rank0(f"Loaded single Dataset with {len(dataset_dict)} samples.")
-                dataset_module["train_dataset"] = dataset_dict
-                return dataset_module
-            else:
-                raise ValueError("Unknown dataset type loaded!")
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]
+
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data
+
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+            return dataset_module
 
         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")