From c523613f0ae26f08147f098b32b8639a767af0ed Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Fri, 8 Dec 2023 18:00:57 +0800 Subject: [PATCH] support ms dataset Former-commit-id: 9c2247d700763f480d88a5dd46480cb32cfc174e --- data/dataset_info.json | 33 ++++++++++++++++++++++--------- src/llmtuner/data/loader.py | 27 +++++++++++++++---------- src/llmtuner/hparams/data_args.py | 15 ++++++++++++-- 3 files changed, 54 insertions(+), 21 deletions(-) diff --git a/data/dataset_info.json b/data/dataset_info.json index 2b3f4eb7..60235d47 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -69,22 +69,28 @@ } }, "guanaco": { - "hf_hub_url": "JosephusCheung/GuanacoDataset" + "hf_hub_url": "JosephusCheung/GuanacoDataset", + "ms_hub_url": "wyj123456/GuanacoDataset" }, "belle_2m": { - "hf_hub_url": "BelleGroup/train_2M_CN" + "hf_hub_url": "BelleGroup/train_2M_CN", + "ms_hub_url": "AI-ModelScope/train_2M_CN" }, "belle_1m": { - "hf_hub_url": "BelleGroup/train_1M_CN" + "hf_hub_url": "BelleGroup/train_1M_CN", + "ms_hub_url": "AI-ModelScope/train_1M_CN" }, "belle_0.5m": { - "hf_hub_url": "BelleGroup/train_0.5M_CN" + "hf_hub_url": "BelleGroup/train_0.5M_CN", + "ms_hub_url": "AI-ModelScope/train_0.5M_CN" }, "belle_dialog": { - "hf_hub_url": "BelleGroup/generated_chat_0.4M" + "hf_hub_url": "BelleGroup/generated_chat_0.4M", + "ms_hub_url": "AI-ModelScope/generated_chat_0.4M" }, "belle_math": { - "hf_hub_url": "BelleGroup/school_math_0.25M" + "hf_hub_url": "BelleGroup/school_math_0.25M", + "ms_hub_url": "AI-ModelScope/school_math_0.25M" }, "belle_multiturn": { "script_url": "belle_multiturn", @@ -95,10 +101,12 @@ "formatting": "sharegpt" }, "open_platypus": { - "hf_hub_url": "garage-bAInd/Open-Platypus" + "hf_hub_url": "garage-bAInd/Open-Platypus", + "ms_hub_url": "AI-ModelScope/Open-Platypus" }, "codealpaca": { - "hf_hub_url": "sahil2801/CodeAlpaca-20k" + "hf_hub_url": "sahil2801/CodeAlpaca-20k", + "ms_hub_url": "AI-ModelScope/CodeAlpaca-20k" }, "alpaca_cot": { 
"hf_hub_url": "QingyiSi/Alpaca-CoT" @@ -112,6 +120,7 @@ }, "mathinstruct": { "hf_hub_url": "TIGER-Lab/MathInstruct", + "ms_hub_yrl": "AI-ModelScope/MathInstruct", "columns": { "prompt": "instruction", "response": "output" @@ -126,13 +135,15 @@ }, "webqa": { "hf_hub_url": "suolyer/webqa", + "ms_hub_yrl": "AI-ModelScope/webqa", "columns": { "prompt": "input", "response": "output" } }, "webnovel": { - "hf_hub_url": "zxbsmk/webnovel_cn" + "hf_hub_url": "zxbsmk/webnovel_cn", + "ms_hub_yrl": "AI-ModelScope/webnovel_cn" }, "nectar_sft": { "hf_hub_url": "mlinmg/SFT-Nectar" @@ -146,10 +157,12 @@ }, "sharegpt_hyper": { "hf_hub_url": "totally-not-an-llm/sharegpt-hyperfiltered-3k", + "ms_hub_yrl": "AI-ModelScope/sharegpt-hyperfiltered-3k", "formatting": "sharegpt" }, "sharegpt4": { "hf_hub_url": "shibing624/sharegpt_gpt4", + "ms_hub_yrl": "AI-ModelScope/sharegpt_gpt4", "formatting": "sharegpt" }, "ultrachat_200k": { @@ -176,6 +189,7 @@ }, "evol_instruct": { "hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k", + "ms_hub_yrl": "AI-ModelScope/WizardLM_evol_instruct_V2_196k", "formatting": "sharegpt" }, "hh_rlhf_en": { @@ -251,6 +265,7 @@ }, "wikipedia_zh": { "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered", + "ms_hub_yrl": "AI-ModelScope/wikipedia-cn-20230720-filtered", "columns": { "prompt": "completion" } diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 8e9053ca..41c12422 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -24,7 +24,7 @@ def get_dataset( for dataset_attr in data_args.dataset_list: logger.info("Loading dataset {}...".format(dataset_attr)) - if dataset_attr.load_from == "hf_hub": + if dataset_attr.load_from in ("hf_hub", "ms_hub"): data_path = dataset_attr.dataset_name data_name = dataset_attr.subset data_files = None @@ -53,15 +53,22 @@ def get_dataset( else: raise NotImplementedError - dataset = load_dataset( - path=data_path, - name=data_name, - data_files=data_files, - split=data_args.split, - 
cache_dir=model_args.cache_dir, - token=model_args.hf_hub_token, - streaming=(data_args.streaming and (dataset_attr.load_from != "file")) - ) + if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and dataset_attr.load_from == "ms_hub": + from modelscope import MsDataset + dataset = MsDataset.load( + dataset_name=data_path, + subset_name=data_name, + ).to_hf_dataset() + else: + dataset = load_dataset( + path=data_path, + name=data_name, + data_files=data_files, + split=data_args.split, + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + streaming=(data_args.streaming and (dataset_attr.load_from != "file")) + ) if data_args.streaming and (dataset_attr.load_from == "file"): dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index cea89198..6f8bb738 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -2,7 +2,9 @@ import os import json from typing import List, Literal, Optional from dataclasses import dataclass, field +from llmtuner.extras.logging import get_logger +logger = get_logger(__name__) DATA_CONFIG = "dataset_info.json" @@ -152,8 +154,17 @@ class DataArguments: if name not in dataset_info: raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) - if "hf_hub_url" in dataset_info[name]: - dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + if "hf_hub_url" in dataset_info[name] or 'ms_hub_url' in dataset_info[name]: + url_key_name = "hf_hub_url" + if int(os.environ.get('USE_MODELSCOPE_HUB', '0')): + if 'ms_hub_url' in dataset_info[name]: + url_key_name = 'ms_hub_url' + else: + logger.warning('You are using ModelScope Hub, but the specified dataset ' + 'has no `ms_hub_url` key, so `hf_hub_url` will be used instead.') + + dataset_attr = DatasetAttr(url_key_name[:url_key_name.index('_url')], + dataset_name=dataset_info[name][url_key_name]) elif "script_url" in 
dataset_info[name]: dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) else: