From f776e738f8906814e75e309b0123758bd9c3f025 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 11 Mar 2024 00:17:18 +0800
Subject: [PATCH] tiny fix

Former-commit-id: 352693e2dcc8fc039b5d574e1a5709563929b0ce
---
 src/llmtuner/data/loader.py | 30 +++---------------------------
 src/llmtuner/data/utils.py  | 30 ++++++++++++++++++++++++++++--
 2 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index 937fdb36..935695ad 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -1,8 +1,8 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, List, Literal, Union
+from typing import TYPE_CHECKING, Literal, Union

-from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk

 from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
@@ -10,7 +10,7 @@ from .aligner import align_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
 from .template import get_template_and_fix_tokenizer
-from .utils import checksum
+from .utils import checksum, merge_dataset


 if TYPE_CHECKING:
@@ -111,30 +111,6 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args)


-def merge_dataset(
-    all_datasets: List[Union["Dataset", "IterableDataset"]],
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-) -> Union["Dataset", "IterableDataset"]:
-    if len(all_datasets) == 1:
-        return all_datasets[0]
-    elif data_args.mix_strategy == "concat":
-        if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
-        return concatenate_datasets(all_datasets)
-    elif data_args.mix_strategy.startswith("interleave"):
-        if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        return interleave_datasets(
-            datasets=all_datasets,
-            probabilities=data_args.interleave_probs,
-            seed=training_args.seed,
-            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
-        )
-    else:
-        raise ValueError("Unknown mixing strategy.")
-
-
 def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py
index 90e3fa81..c0b6d6c2 100644
--- a/src/llmtuner/data/utils.py
+++ b/src/llmtuner/data/utils.py
@@ -2,12 +2,14 @@ import hashlib
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

+from datasets import concatenate_datasets, interleave_datasets
+
 from ..extras.logging import get_logger


 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import TrainingArguments
+    from transformers import Seq2SeqTrainingArguments

     from llmtuner.hparams import DataArguments

@@ -46,8 +48,32 @@ def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label
     return max_source_len, max_target_len


+def merge_dataset(
+    all_datasets: List[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+    if len(all_datasets) == 1:
+        return all_datasets[0]
+    elif data_args.mix_strategy == "concat":
+        if data_args.streaming:
+            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+        return concatenate_datasets(all_datasets)
+    elif data_args.mix_strategy.startswith("interleave"):
+        if not data_args.streaming:
+            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=training_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
+        )
+    else:
+        raise ValueError("Unknown mixing strategy.")
+
+
 def split_dataset(
-    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
+    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
 ) -> Dict[str, "Dataset"]:
     if training_args.do_train:
         if data_args.val_size > 1e-6: # Split the dataset
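
Reviewer note (not part of the patch): the relocated merge_dataset picks between
concatenation and interleaving based on data_args.mix_strategy. The standalone
sketch below shows the behaviors it maps onto in the datasets library; the toy
datasets, probabilities, and seed are illustrative assumptions, not values from
the repo.

    # Illustrative sketch; toy data and parameter values are assumptions.
    from datasets import Dataset, concatenate_datasets, interleave_datasets

    ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
    ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

    # mix_strategy="concat": datasets are appended end to end (no mixing).
    merged = concatenate_datasets([ds_a, ds_b])

    # mix_strategy="interleave_under" -> stopping_strategy="first_exhausted":
    # sampling stops as soon as the smallest dataset runs out.
    under = interleave_datasets(
        datasets=[ds_a, ds_b], probabilities=[0.5, 0.5], seed=42,
        stopping_strategy="first_exhausted",
    )

    # mix_strategy="interleave_over" -> stopping_strategy="all_exhausted":
    # smaller datasets are resampled until every dataset is exhausted.
    over = interleave_datasets(
        datasets=[ds_a, ds_b], probabilities=[0.5, 0.5], seed=42,
        stopping_strategy="all_exhausted",
    )

    print(merged["text"], under["text"], over["text"])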
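
The truncated split_dataset hunk shows only the signature change; per the
inline comment, the val_size branch splits the dataset. A minimal sketch of
that kind of split using the standard train_test_split API (the exact body is
not shown in this hunk; the dataset and val_size value here are made up):

    # Illustrative only; mirrors a val_size-based train/eval split.
    from datasets import Dataset

    dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
    val_size = 0.2  # fraction of samples held out for evaluation

    splits = dataset.train_test_split(test_size=val_size, seed=42)
    print(len(splits["train"]), len(splits["test"]))  # 8 2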