mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
18c6e6fea9
commit
11a6c8a9a0
@ -22,12 +22,15 @@ def get_dataset(
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
||||
|
||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.cache_path)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
if data_args.cache_path is not None:
|
||||
if os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.cache_path)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
elif data_args.streaming:
|
||||
raise ValueError("Turn off dataset streaming to save cache files.")
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
@ -127,9 +127,6 @@ class DataArguments:
|
||||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
if self.streaming and self.cache_path:
|
||||
raise ValueError("`cache_path` is incompatible with `streaming`.")
|
||||
|
||||
def init_for_training(self, seed: int): # support mixing multiple datasets
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
|
Loading…
x
Reference in New Issue
Block a user