Former-commit-id: 6a5074e46695378b76d58aac8ad7768b6b034b9c
hoshi-hiyouga 2024-12-04 22:08:27 +08:00 committed by GitHub
parent ae09c6c214
commit 92940817e7


@@ -14,13 +14,11 @@
 import os
 import sys
-import zstandard as zstd
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union

 import numpy as np
-from datasets import DatasetDict, Dataset, load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version
-from typing import Union

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
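In the hunk above, the duplicate `from typing import Union` and the unused `zstandard` import are dropped, and `Dataset` is removed from the runtime `datasets` import; the annotations later in this diff use string forward references (`Union["Dataset", "DatasetDict"]`), so `Dataset` presumably stays visible to type checkers through a `TYPE_CHECKING` guard elsewhere in the module. A minimal sketch of that pattern (the function name here is illustrative, not from the file):

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Imported only for static type checking; no runtime import of these symbols.
    from datasets import Dataset, DatasetDict


def load_tokenized(path: str) -> Union["Dataset", "DatasetDict"]:
    # String annotations are resolved lazily, so the guarded import above is enough.
    from datasets import load_from_disk

    return load_from_disk(path)
```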
@@ -54,7 +52,6 @@ def _load_single_dataset(
     Loads a single dataset and aligns it to the standard format.
     """
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
-
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
@@ -122,7 +119,6 @@ def _load_single_dataset(
             streaming=data_args.streaming,
         )
     else:
-        logger.info(f"Loading dataset {data_path}, {data_name}, {data_dir}, {data_files},{dataset_attr.split},{model_args.cache_dir},{model_args.hf_hub_token},{data_args.streaming}...")
         dataset = load_dataset(
             path=data_path,
             name=data_name,
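The deleted `logger.info` was a debug dump of the same arguments that are passed to the `load_dataset` call immediately below it (the call itself is truncated by the diff context). For reference, a self-contained sketch of how those argument names map onto the standard `datasets.load_dataset` signature, with placeholder values rather than the file's actual variables:

```python
from datasets import load_dataset

# Placeholders stand in for data_path, data_name, data_dir, data_files,
# dataset_attr.split, model_args.cache_dir, model_args.hf_hub_token and
# data_args.streaming named in the removed debug message.
dataset = load_dataset(
    path="json",                     # data_path: builder name, script, or repo id
    name=None,                       # data_name: dataset configuration name
    data_dir=None,                   # data_dir
    data_files=["data/train.json"],  # data_files: local files to load
    split="train",                   # dataset_attr.split
    cache_dir=None,                  # model_args.cache_dir
    token=None,                      # model_args.hf_hub_token
    streaming=False,                 # data_args.streaming
)
```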
@@ -243,30 +239,24 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: Union[DatasetDict, Dataset] = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
             dataset_module: Dict[str, "Dataset"] = {}
-            if isinstance(dataset_dict, DatasetDict):
-                print(f"Loaded DatasetDict with keys: {dataset_dict.keys()}")
-                if "train" in dataset_dict:
-                    dataset_module["train_dataset"] = dataset_dict["train"]
-                if "validation" in dataset_dict:
-                    dataset_module["eval_dataset"] = dataset_dict["validation"]
-                if data_args.streaming:
-                    dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
-                logger.info_rank0(f"Finished tokenized dataset load from {data_args.tokenized_path}.")
-                return dataset_module
-            elif isinstance(dataset_dict, Dataset):
-                logger.info_rank0(f"Loaded single Dataset with {len(dataset_dict)} samples.")
-                dataset_module["train_dataset"] = dataset_dict
-                return dataset_module
-            else:
-                raise ValueError("Unknown dataset type loaded!")
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data
+
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+            return dataset_module

         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")
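After this hunk, the tokenized-path branch handles both shapes `load_from_disk` can return: a `DatasetDict` is split into `train_dataset`/`eval_dataset`, while a plain `Dataset` becomes the training split, and the optional streaming conversion then applies uniformly. A rough standalone sketch of that mapping (the wrapper function and its arguments are assumptions for illustration, not the file's API):

```python
from typing import Dict, Union

from datasets import Dataset, DatasetDict, load_from_disk


def build_dataset_module(tokenized_path: str, streaming: bool = False) -> Dict[str, Dataset]:
    tokenized_data: Union[Dataset, DatasetDict] = load_from_disk(tokenized_path)

    dataset_module: Dict[str, Dataset] = {}
    if isinstance(tokenized_data, DatasetDict):
        # Saved with explicit splits: pick up "train" and "validation" if present.
        if "train" in tokenized_data:
            dataset_module["train_dataset"] = tokenized_data["train"]
        if "validation" in tokenized_data:
            dataset_module["eval_dataset"] = tokenized_data["validation"]
    else:  # a single Dataset is treated as the training split
        dataset_module["train_dataset"] = tokenized_data

    if streaming:
        # Convert map-style datasets into iterable ones for streaming training.
        dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}

    return dataset_module
```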