Merge pull request #1802 from tastelikefeet/feat/support_ms

Support ModelScope Datahub

Former-commit-id: f73f321e765aab9325673218779ff4ee7f281514
This commit is contained in:
hoshi-hiyouga 2023-12-12 17:58:37 +08:00 committed by GitHub
commit 7c9f37c83d
3 changed files with 42 additions and 13 deletions

View File

@@ -25,7 +25,7 @@ def get_dataset(
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from == "hf_hub": if dataset_attr.load_from in ("hf_hub", "ms_hub"):
data_path = dataset_attr.dataset_name data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset data_name = dataset_attr.subset
data_dir = dataset_attr.folder data_dir = dataset_attr.folder
@@ -53,10 +53,24 @@ def get_dataset(
else: else:
raise NotImplementedError raise NotImplementedError
if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and dataset_attr.load_from == "ms_hub":
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
split=data_args.split,
data_files=data_files,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
).to_hf_dataset()
else:
dataset = load_dataset( dataset = load_dataset(
path=data_path, path=data_path,
name=data_name, name=data_name,
data_dir=data_dir,
data_files=data_files, data_files=data_files,
split=data_args.split, split=data_args.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,

View File

@@ -2,7 +2,9 @@ import os
import json import json
from typing import List, Literal, Optional from typing import List, Literal, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
DATA_CONFIG = "dataset_info.json" DATA_CONFIG = "dataset_info.json"
@@ -153,8 +155,17 @@ class DataArguments:
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
if "hf_hub_url" in dataset_info[name]: if "hf_hub_url" in dataset_info[name] or 'ms_hub_url' in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) url_key_name = "hf_hub_url"
if int(os.environ.get('USE_MODELSCOPE_HUB', '0')):
if 'ms_hub_url' in dataset_info[name]:
url_key_name = 'ms_hub_url'
else:
logger.warning('You are using ModelScope Hub, but the specified dataset '
'has no `ms_hub_url` key, so `hf_hub_url` will be used instead.')
dataset_attr = DatasetAttr(url_key_name[:url_key_name.index('_url')],
dataset_name=dataset_info[name][url_key_name])
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else: else:

View File

@@ -59,6 +59,10 @@ class ModelArguments:
default=None, default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."} metadata={"help": "Auth token to log in with Hugging Face Hub."}
) )
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}
)
def __post_init__(self): def __post_init__(self):
self.compute_dtype = None self.compute_dtype = None