This commit is contained in:
hiyouga
2024-10-11 23:51:54 +08:00
parent 228dd1739e
commit 3af57795dd
12 changed files with 91 additions and 69 deletions

View File

@@ -53,7 +53,7 @@ def _load_single_dataset(
"""
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["om_hub", "hf_hub", "ms_hub"]:
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
@@ -84,24 +84,7 @@ def _load_single_dataset(
else:
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
if dataset_attr.load_from == "om_hub":
try:
from openmind import OmDataset
from openmind.utils.hub import OM_DATASETS_CACHE
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
)
except ImportError:
raise ImportError("Please install openmind via `pip install openmind -U`")
elif dataset_attr.load_from == "ms_hub":
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
@@ -119,6 +102,23 @@ def _load_single_dataset(
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset
from openmind.utils.hub import OM_DATASETS_CACHE
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
)
else:
dataset = load_dataset(
path=data_path,