From cefc0b2f03e9ecdb739479fa0e56061262145774 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 12 Dec 2023 18:33:06 +0800
Subject: [PATCH] fix modelscope data hub

Former-commit-id: d5b2c57a356539df9993e4774b856231eca8a6da
---
 data/dataset_info.json            | 19 +++++++++++--------
 src/llmtuner/data/loader.py       | 33 +++++++++++++++++++--------------
 src/llmtuner/hparams/data_args.py | 36 ++++++++++++++++++++--------------
 src/llmtuner/model/adapter.py     |  4 ++--
 src/llmtuner/model/loader.py      |  2 --
 src/llmtuner/model/parser.py      | 12 ++++++------
 6 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/data/dataset_info.json b/data/dataset_info.json
index 02664af7..43cec1e4 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -109,7 +109,8 @@
         "ms_hub_url": "AI-ModelScope/CodeAlpaca-20k"
     },
     "alpaca_cot": {
-        "hf_hub_url": "QingyiSi/Alpaca-CoT"
+        "hf_hub_url": "QingyiSi/Alpaca-CoT",
+        "ms_hub_url": "AI-ModelScope/Alpaca-CoT"
     },
     "openorca": {
         "hf_hub_url": "Open-Orca/OpenOrca",
@@ -170,23 +171,23 @@
         "hf_hub_url": "HuggingFaceH4/ultrachat_200k",
         "ms_hub_url": "AI-ModelScope/ultrachat_200k",
         "columns": {
-            "prompt": "messages",
-            "query": "role",
-            "response": "content"
+            "messages": "messages",
+            "role": "role",
+            "content": "content"
         },
         "formatting": "sharegpt"
     },
     "agent_instruct": {
         "hf_hub_url": "THUDM/AgentInstruct",
+        "ms_hub_url": "ZhipuAI/AgentInstruct",
         "formatting": "sharegpt"
     },
     "lmsys_chat": {
         "hf_hub_url": "lmsys/lmsys-chat-1m",
-        "ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
         "columns": {
-            "prompt": "conversation",
-            "query": "role",
-            "response": "content"
+            "messages": "conversation",
+            "role": "role",
+            "content": "content"
         },
         "formatting": "sharegpt"
     },
@@ -287,12 +288,14 @@
     },
     "the_stack": {
         "hf_hub_url": "bigcode/the-stack",
+        "ms_hub_url": "AI-ModelScope/the-stack",
         "columns": {
             "prompt": "content"
         }
     },
     "starcoder_python": {
         "hf_hub_url": "bigcode/starcoderdata",
+        "ms_hub_url": "AI-ModelScope/starcoderdata",
         "columns": {
             "prompt": "content"
         },
diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index 75113cb3..665a6a02 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -25,7 +25,7 @@ def get_dataset(
         logger.info("Loading dataset {}...".format(dataset_attr))
 
         data_path, data_name, data_dir, data_files = None, None, None, None
-        if dataset_attr.load_from in ("hf_hub", "ms_hub"):
+        if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
             data_dir = dataset_attr.folder
@@ -53,24 +53,29 @@
         else:
            raise NotImplementedError
 
-        if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and dataset_attr.load_from == "ms_hub":
-            from modelscope import MsDataset
-            from modelscope.utils.config_ds import MS_DATASETS_CACHE
-            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+        if dataset_attr.load_from == "ms_hub":
+            try:
+                from modelscope import MsDataset # type: ignore
+                from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
 
-            dataset = MsDataset.load(
-                dataset_name=data_path,
-                subset_name=data_name,
-                split=data_args.split,
-                data_files=data_files,
-                cache_dir=cache_dir,
-                token=model_args.ms_hub_token,
-                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            ).to_hf_dataset()
+                cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+                dataset = MsDataset.load(
+                    dataset_name=data_path,
+                    subset_name=data_name,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    split=data_args.split,
+                    cache_dir=cache_dir,
+                    token=model_args.ms_hub_token,
+                    use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+                ).to_hf_dataset()
+            except ImportError:
+                raise ImportError("Please install modelscope via `pip install modelscope -U`")
         else:
             dataset = load_dataset(
                 path=data_path,
                 name=data_name,
+                data_dir=data_dir,
                 data_files=data_files,
                 split=data_args.split,
                 cache_dir=model_args.cache_dir,
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index b5ed3f99..f74c3cf0 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -2,17 +2,19 @@ import os
 import json
 from typing import List, Literal, Optional
 from dataclasses import dataclass, field
-from llmtuner.extras.logging import get_logger
 
 
-logger = get_logger(__name__)
 
 DATA_CONFIG = "dataset_info.json"
 
 
+def use_modelscope() -> bool:
+    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
+
+
 @dataclass
 class DatasetAttr:
-    load_from: str
+    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
@@ -155,19 +157,25 @@
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
 
-            if "hf_hub_url" in dataset_info[name] or 'ms_hub_url' in dataset_info[name]:
-                url_key_name = "hf_hub_url"
-                if int(os.environ.get('USE_MODELSCOPE_HUB', '0')):
-                    if 'ms_hub_url' in dataset_info[name]:
-                        url_key_name = 'ms_hub_url'
-                    else:
-                        logger.warning('You are using ModelScope Hub, but the specified dataset '
-                                       'has no `ms_hub_url` key, so `hf_hub_url` will be used instead.')
+            has_hf_url = "hf_hub_url" in dataset_info[name]
+            has_ms_url = "ms_hub_url" in dataset_info[name]
 
-                dataset_attr = DatasetAttr(url_key_name[:url_key_name.index('_url')],
-                                           dataset_name=dataset_info[name][url_key_name])
+            if has_hf_url or has_ms_url:
+                if (use_modelscope() and has_ms_url) or (not has_hf_url):
+                    dataset_attr = DatasetAttr(
+                        "ms_hub",
+                        dataset_name=dataset_info[name]["ms_hub_url"]
+                    )
+                else:
+                    dataset_attr = DatasetAttr(
+                        "hf_hub",
+                        dataset_name=dataset_info[name]["hf_hub_url"]
+                    )
             elif "script_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+                dataset_attr = DatasetAttr(
+                    "script",
+                    dataset_name=dataset_info[name]["script_url"]
+                )
             else:
                 dataset_attr = DatasetAttr(
                     "file",
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index 82fa8c7b..72cea444 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -66,8 +66,8 @@ def init_adapter(
 
     if model_args.checkpoint_dir is not None:
         is_mergeable = True
-        if getattr(model, "quantization_method", None) == "gptq":
-            assert len(model_args.checkpoint_dir) == 1, "GPTQ quantized model only accepts a single checkpoint."
+        if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
+            assert len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
             is_mergeable = False
 
     if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable):
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 082ee6aa..f635009c 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,4 +1,3 @@
-import os
 import math
 import torch
 from types import MethodType
@@ -13,7 +12,6 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
-from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/model/parser.py
index d298996e..611a9eaa 100644
--- a/src/llmtuner/model/parser.py
+++ b/src/llmtuner/model/parser.py
@@ -44,12 +44,12 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
     if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Quantization is only compatible with the LoRA method.")
 
-    if (
-        model_args.checkpoint_dir is not None
-        and len(model_args.checkpoint_dir) != 1
-        and finetuning_args.finetuning_type != "lora"
-    ):
-        raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
+    if model_args.checkpoint_dir is not None and len(model_args.checkpoint_dir) != 1:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
+
+        if model_args.quantization_bit is not None:
+            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
 
 def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
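
Usage sketch (not part of the patch above): after this change, data_args.use_modelscope() returns True once the USE_MODELSCOPE_HUB environment variable is set to 1, and any dataset whose dataset_info.json entry carries an ms_hub_url is routed through the new MsDataset branch in src/llmtuner/data/loader.py. The snippet below mirrors that branch in isolation for one of the ms_hub_url mappings visible in the dataset_info.json hunk; the choice of dataset id, the "train" split, and the final print() are illustrative assumptions, not part of the commit.

    import os

    # Flip the switch that data_args.use_modelscope() reads inside LLaMA-Factory.
    os.environ["USE_MODELSCOPE_HUB"] = "1"

    try:
        from modelscope import MsDataset
        from modelscope.utils.config_ds import MS_DATASETS_CACHE
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")

    # Same call shape as the ms_hub branch added in get_dataset(); the loader falls
    # back to MS_DATASETS_CACHE when no cache_dir is given in the model arguments.
    dataset = MsDataset.load(
        dataset_name="AI-ModelScope/CodeAlpaca-20k",  # ms_hub_url entry from dataset_info.json
        split="train",                                # assumed split for this example
        cache_dir=MS_DATASETS_CACHE,
    ).to_hf_dataset()
    print(dataset)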