mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-16 00:28:10 +08:00

support ms dataset

Former-commit-id: 98638b35dc24045ac17b9b01d08d3a02372acef3

This commit is contained in:
parent a1ec668b70
commit 596f496f19
@@ -24,7 +24,7 @@ def get_dataset(
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))

-        if dataset_attr.load_from == "hf_hub":
+        if dataset_attr.load_from in ("hf_hub", "ms_hub"):
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
             data_files = None
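For context: the next hunk gates the new ModelScope branch on a USE_MODELSCOPE_HUB environment variable, read via int(os.environ.get('USE_MODELSCOPE_HUB', '0')). A minimal sketch of how that flag is interpreted (the example below is illustrative only and not part of this commit):

import os

# Hedged sketch: the flag is parsed as an integer string, so the ModelScope path
# is taken only when the variable parses to a non-zero value (typically "1").
use_modelscope_hub = bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
print(use_modelscope_hub)  # False unless USE_MODELSCOPE_HUB was exported as a non-zero value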
@@ -53,15 +53,22 @@ def get_dataset(
         else:
             raise NotImplementedError

-        dataset = load_dataset(
-            path=data_path,
-            name=data_name,
-            data_files=data_files,
-            split=data_args.split,
-            cache_dir=model_args.cache_dir,
-            token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
-        )
+        if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and dataset_attr.load_from == "ms_hub":
+            from modelscope import MsDataset
+            dataset = MsDataset.load(
+                dataset_name=data_path,
+                subset_name=data_name,
+            ).to_hf_dataset()
+        else:
+            dataset = load_dataset(
+                path=data_path,
+                name=data_name,
+                data_files=data_files,
+                split=data_args.split,
+                cache_dir=model_args.cache_dir,
+                token=model_args.hf_hub_token,
+                streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            )

         if data_args.streaming and (dataset_attr.load_from == "file"):
             dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
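Taken together, the loader hunk above dispatches on both the environment flag and the dataset's origin. A condensed, self-contained sketch of that dispatch, assuming a modelscope version where MsDataset.load(...).to_hf_dataset() is available (the helper name load_from_hub, the split default, and the omission of caching/streaming arguments are simplifications, not part of this commit):

import os

def load_from_hub(data_path, data_name=None, from_ms_hub=False, split="train"):
    # Mirrors the branch added above: use ModelScope only when the env flag is set
    # AND the dataset entry actually points at the ModelScope hub.
    if int(os.environ.get("USE_MODELSCOPE_HUB", "0")) and from_ms_hub:
        from modelscope import MsDataset
        # MsDataset.load fetches from ModelScope; .to_hf_dataset() converts the result
        # into a Hugging Face `datasets` object so downstream code stays unchanged.
        return MsDataset.load(dataset_name=data_path, subset_name=data_name, split=split).to_hf_dataset()
    from datasets import load_dataset
    return load_dataset(path=data_path, name=data_name, split=split)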
@@ -2,7 +2,9 @@ import os
 import json
 from typing import List, Literal, Optional
 from dataclasses import dataclass, field
+from llmtuner.extras.logging import get_logger

+logger = get_logger(__name__)

 DATA_CONFIG = "dataset_info.json"

@@ -152,8 +154,17 @@ class DataArguments:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

-            if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+            if "hf_hub_url" in dataset_info[name] or 'ms_hub_url' in dataset_info[name]:
+                url_key_name = "hf_hub_url"
+                if int(os.environ.get('USE_MODELSCOPE_HUB', '0')):
+                    if 'ms_hub_url' in dataset_info[name]:
+                        url_key_name = 'ms_hub_url'
+                    else:
+                        logger.warning('You are using ModelScope Hub, but the specified dataset '
+                                       'has no `ms_hub_url` key, so `hf_hub_url` will be used instead.')
+
+                dataset_attr = DatasetAttr(url_key_name[:url_key_name.index('_url')],
+                                           dataset_name=dataset_info[name][url_key_name])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
             else:
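The data-arguments hunk above picks which hub URL key to read from dataset_info.json and derives the load_from tag by stripping the "_url" suffix. A small sketch of that resolution with an illustrative entry (the name "example_dataset" and its URLs are assumptions, not taken from this commit):

import os

dataset_info = {
    "example_dataset": {
        "hf_hub_url": "some-org/some-dataset",  # Hugging Face Hub id (placeholder)
        "ms_hub_url": "some-org/some-dataset",  # ModelScope Hub id (placeholder)
    }
}

name = "example_dataset"
url_key_name = "hf_hub_url"
if int(os.environ.get("USE_MODELSCOPE_HUB", "0")) and "ms_hub_url" in dataset_info[name]:
    url_key_name = "ms_hub_url"

load_from = url_key_name[:url_key_name.index("_url")]  # -> "hf_hub" or "ms_hub"
dataset_name = dataset_info[name][url_key_name]
print(load_from, dataset_name)

This is consistent with the loader's check dataset_attr.load_from in ("hf_hub", "ms_hub"), so entries resolved from ms_hub_url flow through the same pipeline as Hugging Face entries.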