Merge pull request #6160 from village-way/pr_dataloader

fix:tokenized_path not None and load_from_disk return Dataset Trigger…
Former-commit-id: cf298468309cd923d830dcaf7a1aa837519faf1e
This commit is contained in:
hoshi-hiyouga 2024-12-04 22:18:19 +08:00 committed by GitHub
commit 9bbeba6323

View File

@ -239,15 +239,19 @@ def get_dataset(
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if isinstance(tokenized_data, DatasetDict):
if "train" in tokenized_data:
dataset_module["train_dataset"] = tokenized_data["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
if "validation" in tokenized_data:
dataset_module["eval_dataset"] = tokenized_data["validation"]
else: # Dataset
dataset_module["train_dataset"] = tokenized_data
if data_args.streaming:
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}