mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
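tiny fix: move merge_dataset into the data utils module and annotate split_dataset's training_args as Seq2SeqTrainingArguments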
Former-commit-id: 352693e2dcc8fc039b5d574e1a5709563929b0ce
This commit is contained in:
parent
566bfad930
commit
f776e738f8
@ -1,8 +1,8 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, List, Literal, Union
|
from typing import TYPE_CHECKING, Literal, Union
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
from datasets import load_dataset, load_from_disk
|
||||||
|
|
||||||
from ..extras.constants import FILEEXT2TYPE
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
@ -10,7 +10,7 @@ from .aligner import align_dataset
|
|||||||
from .parser import get_dataset_list
|
from .parser import get_dataset_list
|
||||||
from .preprocess import get_preprocess_and_print_func
|
from .preprocess import get_preprocess_and_print_func
|
||||||
from .template import get_template_and_fix_tokenizer
|
from .template import get_template_and_fix_tokenizer
|
||||||
from .utils import checksum
|
from .utils import checksum, merge_dataset
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -111,30 +111,6 @@ def load_single_dataset(
|
|||||||
return align_dataset(dataset, dataset_attr, data_args)
|
return align_dataset(dataset, dataset_attr, data_args)
|
||||||
|
|
||||||
|
|
||||||
def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    """
    Merges several datasets into one according to `data_args.mix_strategy`.

    Args:
        all_datasets: the datasets to merge; a single-element list is returned as-is.
        data_args: provides `mix_strategy`, `streaming` and `interleave_probs`.
        training_args: provides the `seed` used when interleaving.

    Returns:
        The only dataset if just one was given, otherwise the concatenated
        or interleaved dataset.

    Raises:
        ValueError: if the strategy is neither "concat" nor prefixed with "interleave".
    """
    # A lone dataset is handed back before the strategy is even looked at.
    if len(all_datasets) == 1:
        return all_datasets[0]

    strategy = data_args.mix_strategy

    if strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)

    if not strategy.startswith("interleave"):
        raise ValueError("Unknown mixing strategy.")

    if not data_args.streaming:
        logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")

    # A strategy name ending in "under" stops at the first exhausted dataset;
    # every other interleave variant runs until all datasets are exhausted.
    return interleave_datasets(
        datasets=all_datasets,
        probabilities=data_args.interleave_probs,
        seed=training_args.seed,
        stopping_strategy="first_exhausted" if strategy.endswith("under") else "all_exhausted",
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
|
@ -2,12 +2,14 @@ import hashlib
|
|||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from datasets import concatenate_datasets, interleave_datasets
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import TrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from llmtuner.hparams import DataArguments
|
from llmtuner.hparams import DataArguments
|
||||||
|
|
||||||
@ -46,8 +48,32 @@ def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label
|
|||||||
return max_source_len, max_target_len
|
return max_source_len, max_target_len
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    """
    Combines the given datasets into a single dataset.

    Args:
        all_datasets: datasets to combine.
        data_args: read for `mix_strategy`, `streaming` and `interleave_probs`.
        training_args: read for `seed` (interleave strategies only).

    Returns:
        `all_datasets[0]` when the list has exactly one element; otherwise the
        result of `concatenate_datasets` ("concat") or `interleave_datasets`
        (any strategy starting with "interleave").

    Raises:
        ValueError: for any other mixing strategy.
    """
    if len(all_datasets) == 1:  # nothing to mix
        return all_datasets[0]

    mode = data_args.mix_strategy
    wants_concat = mode == "concat"

    # Reject unsupported strategies up front; no other attribute is read before this.
    if not wants_concat and not mode.startswith("interleave"):
        raise ValueError("Unknown mixing strategy.")

    if wants_concat:
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)

    if not data_args.streaming:
        logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")

    interleave_kwargs = {
        "datasets": all_datasets,
        "probabilities": data_args.interleave_probs,
        "seed": training_args.seed,
    }
    # The "under" suffix selects stopping at the first exhausted dataset.
    if mode.endswith("under"):
        interleave_kwargs["stopping_strategy"] = "first_exhausted"
    else:
        interleave_kwargs["stopping_strategy"] = "all_exhausted"

    return interleave_datasets(**interleave_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
|
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
||||||
) -> Dict[str, "Dataset"]:
|
) -> Dict[str, "Dataset"]:
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if data_args.val_size > 1e-6: # Split the dataset
|
if data_args.val_size > 1e-6: # Split the dataset
|
||||||
|
Loading…
x
Reference in New Issue
Block a user