optimize data loading logic

Former-commit-id: ec1fe1daa98c61f62c753b22847de028b5c5cded
hiyouga 2023-12-20 16:15:41 +08:00
parent 633624dc3c
commit f64be8ee84
2 changed files with 9 additions and 8 deletions

View File

@@ -1,7 +1,7 @@
 import os
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
-from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
 
 from llmtuner.data.utils import checksum
 from llmtuner.extras.constants import FILEEXT2TYPE
@@ -22,6 +22,13 @@ def get_dataset(
     max_samples = data_args.max_samples
     all_datasets: List[Union["Dataset", "IterableDataset"]] = []  # support multiple datasets
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        dataset = load_from_disk(data_args.cache_path)
+        if data_args.streaming:
+            dataset = dataset.to_iterable_dataset()
+        return dataset
+
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
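For reference, the new branch in get_dataset relies on the save/load round-trip provided by the Hugging Face datasets library: a dataset written with save_to_disk can be reloaded with load_from_disk and, when streaming is requested, converted lazily with to_iterable_dataset(). A minimal, self-contained sketch of that behavior follows; the cache directory name and the toy columns are illustrative, not part of the commit.

from datasets import Dataset, load_from_disk

# Stand-in for an already preprocessed corpus.
dataset = Dataset.from_dict({"prompt": ["hi", "bye"], "response": ["hello", "goodbye"]})
dataset.save_to_disk("cache_dir")  # plays the role of data_args.cache_path

# A later run can skip downloading and parsing and read the cache directly.
reloaded = load_from_disk("cache_dir")

streaming = True  # mirrors data_args.streaming
if streaming:
    # Expose the cached Dataset as an IterableDataset, matching the streaming code path.
    reloaded = reloaded.to_iterable_dataset()

print(next(iter(reloaded)))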

View File

@@ -3,8 +3,6 @@ import tiktoken
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
 
-from datasets import load_from_disk
-
 from llmtuner.data.template import get_template_and_fix_tokenizer
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
@@ -45,11 +43,7 @@ def preprocess_dataset(
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
-        logger.warning("Loading dataset from disk will ignore other data arguments.")
-        dataset = load_from_disk(data_args.cache_path)
-        if data_args.streaming:
-            dataset = dataset.to_iterable_dataset()
-        return dataset
+        return dataset  # already preprocessed
 
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
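Taken together, the cache check now happens once in get_dataset, and preprocess_dataset simply returns the dataset untouched when the cache was hit, on the assumption (not shown in this diff) that the cache is only written after preprocessing has finished. A rough sketch of that write-once, reuse-later pattern follows; the cache path is illustrative and a trivial map stands in for tokenization.

import os

from datasets import Dataset, load_from_disk

cache_path = "sft_cache"  # illustrative stand-in for data_args.cache_path

if os.path.exists(cache_path):
    # Second run: the early-return path, since the cached data is already preprocessed.
    dataset = load_from_disk(cache_path)
else:
    # First run: load and preprocess, then persist so later runs can skip this work.
    dataset = Dataset.from_dict({"text": ["a", "b"]})
    dataset = dataset.map(lambda example: {"length": len(example["text"])})  # stand-in for tokenization
    dataset.save_to_disk(cache_path)

print(dataset[0])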