support ms (ModelScope) datasets

Former-commit-id: 98638b35dc24045ac17b9b01d08d3a02372acef3
yuze.zyz 2023-12-08 18:00:57 +08:00
parent a1ec668b70
commit 596f496f19
2 changed files with 30 additions and 12 deletions
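
Taken together, the two changes below gate dataset downloads behind the USE_MODELSCOPE_HUB environment variable: when it is set to a non-zero value and a dataset declares an ms_hub_url, the loader fetches it through modelscope.MsDataset instead of datasets.load_dataset. A minimal sketch of opting in before training starts (setting the variable from Python here is purely illustrative; it only needs to be set before the data loader runs):

    import os

    # The loader reads the flag with int(os.environ.get("USE_MODELSCOPE_HUB", "0")),
    # so any non-zero integer string enables the ModelScope path.
    os.environ["USE_MODELSCOPE_HUB"] = "1"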

View File

@@ -24,7 +24,7 @@ def get_dataset(
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
 
-        if dataset_attr.load_from == "hf_hub":
+        if dataset_attr.load_from in ("hf_hub", "ms_hub"):
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
             data_files = None
@@ -53,15 +53,22 @@ def get_dataset(
         else:
             raise NotImplementedError
 
-        dataset = load_dataset(
-            path=data_path,
-            name=data_name,
-            data_files=data_files,
-            split=data_args.split,
-            cache_dir=model_args.cache_dir,
-            token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
-        )
+        if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and dataset_attr.load_from == "ms_hub":
+            from modelscope import MsDataset
+            dataset = MsDataset.load(
+                dataset_name=data_path,
+                subset_name=data_name,
+            ).to_hf_dataset()
+        else:
+            dataset = load_dataset(
+                path=data_path,
+                name=data_name,
+                data_files=data_files,
+                split=data_args.split,
+                cache_dir=model_args.cache_dir,
+                token=model_args.hf_hub_token,
+                streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            )
 
         if data_args.streaming and (dataset_attr.load_from == "file"):
             dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
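
A standalone sketch of the branch added above, under the assumption that the optional modelscope package is installed; MsDataset.load(...).to_hf_dataset() yields a regular datasets.Dataset, so the streaming and preprocessing code that follows does not need to change. The helper name and the simplified arguments below are illustrative, not part of the actual loader:

    import os
    from datasets import load_dataset

    def load_one_dataset(data_path, data_name, load_from, **hf_kwargs):
        # Illustrative helper mirroring the new branch in get_dataset().
        if int(os.environ.get("USE_MODELSCOPE_HUB", "0")) and load_from == "ms_hub":
            from modelscope import MsDataset  # imported lazily, keeps modelscope optional
            return MsDataset.load(
                dataset_name=data_path,
                subset_name=data_name,
            ).to_hf_dataset()
        # Otherwise fall back to the Hugging Face hub, a loading script, or local files.
        return load_dataset(path=data_path, name=data_name, **hf_kwargs)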

View File

@@ -2,7 +2,9 @@ import os
 import json
 from typing import List, Literal, Optional
 from dataclasses import dataclass, field
+from llmtuner.extras.logging import get_logger
 
+logger = get_logger(__name__)
 
 DATA_CONFIG = "dataset_info.json"
@@ -152,8 +154,17 @@ class DataArguments:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
 
-            if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+            if "hf_hub_url" in dataset_info[name] or 'ms_hub_url' in dataset_info[name]:
+                url_key_name = "hf_hub_url"
+                if int(os.environ.get('USE_MODELSCOPE_HUB', '0')):
+                    if 'ms_hub_url' in dataset_info[name]:
+                        url_key_name = 'ms_hub_url'
+                    else:
+                        logger.warning('You are using ModelScope Hub, but the specified dataset '
+                                       'has no `ms_hub_url` key, so `hf_hub_url` will be used instead.')
+
+                dataset_attr = DatasetAttr(url_key_name[:url_key_name.index('_url')],
+                                           dataset_name=dataset_info[name][url_key_name])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
             else:
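
For reference, this is how the new key resolution plays out for a dataset_info.json entry published on both hubs. The entry below is hypothetical; the key names and the url_key_name[:url_key_name.index('_url')] slicing, which turns "hf_hub_url"/"ms_hub_url" into the load_from values "hf_hub"/"ms_hub", come from the change above:

    import os

    # Hypothetical dataset_info.json entry exposing the same dataset on both hubs.
    dataset_info = {
        "example_sft": {
            "hf_hub_url": "someone/example-sft",
            "ms_hub_url": "someone/example-sft",
        }
    }

    name = "example_sft"
    url_key_name = "hf_hub_url"
    if int(os.environ.get("USE_MODELSCOPE_HUB", "0")) and "ms_hub_url" in dataset_info[name]:
        url_key_name = "ms_hub_url"

    load_from = url_key_name[:url_key_name.index("_url")]  # "hf_hub" or "ms_hub"
    print(load_from, dataset_info[name][url_key_name])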