Update loader.py

Former-commit-id: a5b809516e7de1d6d5f4583089fee3028d0db01d
Author: hoshi-hiyouga, 2024-07-15 00:50:06 +08:00 (committed by GitHub)
parent 140b512426
commit 2e9c9471da


@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
 import os
 import sys
-from typing import TYPE_CHECKING, Literal, Optional, Union, Dict
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union

 import numpy as np
-from datasets import load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
+from transformers.utils.versions import require_version

 from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
@@ -27,7 +27,7 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer, Template
+from .template import get_template_and_fix_tokenizer


 if TYPE_CHECKING:
@@ -35,13 +35,15 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments, ModelArguments
+    from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .template import Template


 logger = get_logger(__name__)


-def load_single_dataset(
+def _load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -81,7 +83,7 @@ def load_single_dataset(
         raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))

     if dataset_attr.load_from == "ms_hub":
-        try:
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
         from modelscope import MsDataset
         from modelscope.utils.config_ds import MS_DATASETS_CACHE

@@ -98,14 +100,7 @@ def load_single_dataset(
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
-        except ImportError:
-            raise ImportError("Please install modelscope via `pip install modelscope -U`")

     else:
-        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
-            kwargs = {"trust_remote_code": True}
-        else:
-            kwargs = {}
-
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -115,7 +110,7 @@ def load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            **kwargs,
+            trust_remote_code=True,
         )

     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
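The hunks above replace the old try/except ImportError guard with an explicit version check, and pass trust_remote_code=True directly, presumably because the project's minimum datasets version now always accepts that keyword. A minimal sketch of the version-guard pattern, assuming only that transformers is installed:

    from transformers.utils.versions import require_version

    try:
        # Raises with the supplied hint if modelscope is missing or older than 1.11.0,
        # failing fast instead of surfacing an ImportError deep inside dataset loading.
        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
        print("modelscope requirement satisfied")
    except Exception as err:  # demo only: show the actionable hint
        print(f"requirement not met: {err}")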
@@ -140,53 +135,41 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args, training_args)


-def load_and_preprocess(
+def _get_merged_dataset(
+    dataset_names: Optional[Sequence[str]],
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    processor: Optional["ProcessorMixin"] = None,
-    is_eval: bool = False
-) -> Union["Dataset", "IterableDataset"]:
-    if not is_eval and data_args.tokenized_path is not None:
-        if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-
-            return dataset
-
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    if is_eval and data_args.eval_tokenized_path is not None:
-        if has_tokenized_data(data_args.eval_tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.eval_tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-
-            return dataset
-
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    with training_args.main_process_first(desc="load dataset"):
-        all_datasets = []
-        for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset):
-            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
-                raise ValueError("The dataset is not applicable in the current training stage.")
-
-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
-
-        dataset = merge_dataset(all_datasets, data_args, training_args)
-
-    with training_args.main_process_first(desc="pre-process dataset"):
-        preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, training_args, stage, template, tokenizer, processor
-        )
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+    if dataset_names is None:
+        return None
+
+    datasets = []
+    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+            raise ValueError("The dataset is not applicable in the current training stage.")
+
+        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
+    return merge_dataset(datasets, data_args, seed=training_args.seed)
+
+
+def _get_preprocessed_dataset(
+    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None,
+    is_eval: bool = False,
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+    if dataset is None:
+        return None
+
+    preprocess_func, print_function = get_preprocess_and_print_func(
+        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
+    )
     column_names = list(next(iter(dataset)).keys())
     kwargs = {}
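The hunk above splits the old load_and_preprocess into two composable helpers that each accept and return an Optional dataset, so an unconfigured eval split flows through as None instead of needing separate is_eval branches. A self-contained toy sketch of that pattern (hypothetical names, not project code):

    from typing import List, Optional, Sequence


    def merge_stage(names: Optional[Sequence[str]]) -> Optional[List[str]]:
        if names is None:  # e.g. no eval_dataset configured
            return None
        return [name.strip() for name in names]


    def preprocess_stage(records: Optional[List[str]]) -> Optional[List[str]]:
        if records is None:  # propagate the absent split untouched
            return None
        return [record.upper() for record in records]


    train = preprocess_stage(merge_stage(["alpaca_en "]))  # ['ALPACA_EN']
    evaluation = preprocess_stage(merge_stage(None))       # None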
@@ -199,23 +182,9 @@ def load_and_preprocess(
     dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

-    if not is_eval and data_args.tokenized_path is not None:
-        if training_args.should_save:
-            dataset.save_to_disk(data_args.tokenized_path)
-            logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-            logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
-
-        sys.exit(0)
-
-    if is_eval and data_args.eval_tokenized_path is not None:
-        if training_args.should_save:
-            dataset.save_to_disk(data_args.eval_tokenized_path)
-            logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path))
-            logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path))
-
-        sys.exit(0)
-
     if training_args.should_log:
         try:
+            print("eval example:" if is_eval else "training example:")
             print_function(next(iter(dataset)))
         except StopIteration:
             if stage == "pt":
@@ -232,16 +201,76 @@ def get_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
-    processor: Optional["ProcessorMixin"] = None
-) -> Dict[str, "Dataset"]:
+    processor: Optional["ProcessorMixin"] = None,
+) -> "DatasetModule":
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

-    train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor)
-    if data_args.eval_dataset or data_args.eval_tokenized_path:
-        eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True)
-        return {"train_dataset": train_dataset, "eval_dataset": eval_dataset}
-    else:
-        return split_dataset(train_dataset, data_args, training_args)
+    # Load tokenized dataset
+    if data_args.tokenized_path is not None:
+        if has_tokenized_data(data_args.tokenized_path):
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+
+            dataset_module: Dict[str, "Dataset"] = {}
+            if "train" in dataset_dict:
+                dataset_module["train_dataset"] = dataset_dict["train"]
+            if "validation" in dataset_dict:
+                dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+            return dataset_module
+
+        if data_args.streaming:
+            raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
+    # Load and preprocess dataset
+    with training_args.main_process_first(desc="load dataset"):
+        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+
+    with training_args.main_process_first(desc="pre-process dataset"):
+        dataset = _get_preprocessed_dataset(
+            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+        )
+        eval_dataset = _get_preprocessed_dataset(
+            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+        )
+
+        if data_args.val_size > 1e-6:
+            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
+        else:
+            dataset_dict = {}
+            if dataset is not None:
+                if data_args.streaming:
+                    dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                dataset_dict["train"] = dataset
+
+            if eval_dataset is not None:
+                if data_args.streaming:
+                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                dataset_dict["validation"] = eval_dataset
+
+            dataset_dict = DatasetDict(dataset_dict)
+
+        if data_args.tokenized_path is not None:
+            if training_args.should_save:
+                dataset_dict.save_to_disk(data_args.tokenized_path)
+                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
+                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))

+            sys.exit(0)
+
+        dataset_module = {}
+        if "train" in dataset_dict:
+            dataset_module["train_dataset"] = dataset_dict["train"]
+        if "validation" in dataset_dict:
+            dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+        return dataset_module
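With this change, get_dataset returns a DatasetModule-style mapping keyed by "train_dataset" and "eval_dataset" instead of a raw (possibly split) Dataset. A small, self-contained sketch of the new return shape using toy data (only the datasets package is assumed):

    from datasets import Dataset, DatasetDict

    dataset_dict = DatasetDict(
        {
            "train": Dataset.from_dict({"text": ["a", "b", "c"]}),
            "validation": Dataset.from_dict({"text": ["d"]}),
        }
    )

    # Mirrors the mapping done at the end of get_dataset: DatasetDict splits
    # become the train_dataset / eval_dataset entries a trainer expects.
    dataset_module = {}
    if "train" in dataset_dict:
        dataset_module["train_dataset"] = dataset_dict["train"]
    if "validation" in dataset_dict:
        dataset_module["eval_dataset"] = dataset_dict["validation"]

    print(sorted(dataset_module))  # ['eval_dataset', 'train_dataset']

Callers can then pull out the splits explicitly, for example train_dataset=dataset_module.get("train_dataset"), when constructing a trainer.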