Former-commit-id: 352693e2dcc8fc039b5d574e1a5709563929b0ce
This commit is contained in:
hiyouga 2024-03-11 00:17:18 +08:00
parent 566bfad930
commit f776e738f8
2 changed files with 31 additions and 29 deletions

View File

@ -1,8 +1,8 @@
import inspect
import os
from typing import TYPE_CHECKING, List, Literal, Union
from typing import TYPE_CHECKING, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
@ -10,7 +10,7 @@ from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum
from .utils import checksum, merge_dataset
if TYPE_CHECKING:
@ -111,30 +111,6 @@ def load_single_dataset(
return align_dataset(dataset, dataset_attr, data_args)
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",

View File

@ -2,12 +2,14 @@ import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from datasets import concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from transformers import Seq2SeqTrainingArguments
from llmtuner.hparams import DataArguments
@ -46,8 +48,32 @@ def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label
return max_source_len, max_target_len
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset