diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index d5a7a588..75113cb3 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -25,7 +25,7 @@ def get_dataset(
         logger.info("Loading dataset {}...".format(dataset_attr))

         data_path, data_name, data_dir, data_files = None, None, None, None
-        if dataset_attr.load_from == "hf_hub":
+        if dataset_attr.load_from in ("hf_hub", "ms_hub"):
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
             data_dir = dataset_attr.folder
@@ -53,16 +53,31 @@ def get_dataset(
         else:
             raise NotImplementedError

-        dataset = load_dataset(
-            path=data_path,
-            name=data_name,
-            data_dir=data_dir,
-            data_files=data_files,
-            split=data_args.split,
-            cache_dir=model_args.cache_dir,
-            token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
-        )
+        if int(os.environ.get("USE_MODELSCOPE_HUB", "0")) and dataset_attr.load_from == "ms_hub":
+            from modelscope import MsDataset
+            from modelscope.utils.config_ds import MS_DATASETS_CACHE
+            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+
+            dataset = MsDataset.load(
+                dataset_name=data_path,
+                subset_name=data_name,
+                split=data_args.split,
+                data_files=data_files,
+                cache_dir=cache_dir,
+                token=model_args.ms_hub_token,
+                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            ).to_hf_dataset()  # convert so downstream code handles both hubs uniformly
+        else:
+            dataset = load_dataset(
+                path=data_path,
+                name=data_name,
+                data_dir=data_dir,
+                data_files=data_files,
+                split=data_args.split,
+                cache_dir=model_args.cache_dir,
+                token=model_args.hf_hub_token,
+                streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            )

         if data_args.streaming and (dataset_attr.load_from == "file"):
             dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index da9be11b..b5ed3f99 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -2,7 +2,9 @@ import os
 import json
 from typing import List, Literal, Optional
 from dataclasses import dataclass, field
+from llmtuner.extras.logging import get_logger

+logger = get_logger(__name__)

 DATA_CONFIG = "dataset_info.json"

@@ -153,8 +155,17 @@ class DataArguments:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

-            if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+            if "hf_hub_url" in dataset_info[name] or "ms_hub_url" in dataset_info[name]:
+                url_key_name = "hf_hub_url"
+                if int(os.environ.get("USE_MODELSCOPE_HUB", "0")):
+                    if "ms_hub_url" in dataset_info[name]:
+                        url_key_name = "ms_hub_url"
+                    else:
+                        logger.warning("You are using ModelScope Hub, but the specified dataset "
+                                       "has no `ms_hub_url` key, so `hf_hub_url` will be used instead.")
+
+                dataset_attr = DatasetAttr(url_key_name[:url_key_name.index("_url")],  # "hf_hub" or "ms_hub"
+                                           dataset_name=dataset_info[name][url_key_name])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
             else:
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 07903b37..6ba37431 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -59,6 +59,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
+    ms_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with ModelScope Hub."}
+    )

     def __post_init__(self):
         self.compute_dtype = None
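
Usage sketch: the new code path is gated on the USE_MODELSCOPE_HUB environment variable, which both loader.py and data_args.py read via int(os.environ.get(...)), so it should be set to "1" or "0". Below is a minimal illustration, assuming the modelscope package is installed; the dataset id is hypothetical and the call simply mirrors the MsDataset.load(...) invocation added in get_dataset().

import os
os.environ["USE_MODELSCOPE_HUB"] = "1"  # must parse as int; a value like "true" would raise ValueError

from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE

# Mirrors the ms_hub branch added in get_dataset(); the dataset id is illustrative only.
dataset = MsDataset.load(
    dataset_name="AI-ModelScope/alpaca-gpt4-data-zh",  # hypothetical dataset id
    split="train",
    cache_dir=MS_DATASETS_CACHE,
).to_hf_dataset()  # downstream llmtuner code expects a Hugging Face datasets.Dataset
print(dataset[0])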