Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
optimize data loading logic
Former-commit-id: ec1fe1daa98c61f62c753b22847de028b5c5cded
This commit is contained in:
parent 633624dc3c
commit f64be8ee84
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
||||||
|
|
||||||
from llmtuner.data.utils import checksum
|
from llmtuner.data.utils import checksum
|
||||||
from llmtuner.extras.constants import FILEEXT2TYPE
|
from llmtuner.extras.constants import FILEEXT2TYPE
|
||||||
@@ -22,6 +22,13 @@ def get_dataset(
|
|||||||
max_samples = data_args.max_samples
|
max_samples = data_args.max_samples
|
||||||
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
||||||
|
|
||||||
|
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||||
|
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||||
|
dataset = load_from_disk(data_args.cache_path)
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset = dataset.to_iterable_dataset()
|
||||||
|
return dataset
|
||||||
|
|
||||||
for dataset_attr in data_args.dataset_list:
|
for dataset_attr in data_args.dataset_list:
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||||
|
|
||||||
|
@@ -3,8 +3,6 @@ import tiktoken
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
|
||||||
|
|
||||||
from datasets import load_from_disk
|
|
||||||
|
|
||||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
from llmtuner.data.template import get_template_and_fix_tokenizer
|
||||||
from llmtuner.extras.constants import IGNORE_INDEX
|
from llmtuner.extras.constants import IGNORE_INDEX
|
||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
@@ -45,11 +43,7 @@ def preprocess_dataset(
|
|||||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||||
|
|
||||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
return dataset # already preprocessed
|
||||||
dataset = load_from_disk(data_args.cache_path)
|
|
||||||
if data_args.streaming:
|
|
||||||
dataset = dataset.to_iterable_dataset()
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
if data_args.train_on_prompt and template.efficient_eos:
|
if data_args.train_on_prompt and template.efficient_eos:
|
||||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user