mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
tiny fix
Former-commit-id: 352693e2dcc8fc039b5d574e1a5709563929b0ce
This commit is contained in:
parent
566bfad930
commit
f776e738f8
@ -1,8 +1,8 @@
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Literal, Union
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
@ -10,7 +10,7 @@ from .aligner import align_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum
|
||||
from .utils import checksum, merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -111,30 +111,6 @@ def load_single_dataset(
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=training_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
||||
|
||||
def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
|
@ -2,12 +2,14 @@ import hashlib
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import TrainingArguments
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
@ -46,8 +48,32 @@ def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label
|
||||
return max_source_len, max_target_len
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=training_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
||||
) -> Dict[str, "Dataset"]:
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
|
Loading…
x
Reference in New Issue
Block a user