Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Merge pull request #6160 from village-way/pr_dataloader

fix: tokenized_path not None and load_from_disk return Dataset Trigger…

Former-commit-id: cf298468309cd923d830dcaf7a1aa837519faf1e
commit 9bbeba6323
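Background for the hunk below, as a minimal sketch that is not part of the commit (the /tmp paths and toy data are assumptions): datasets.load_from_disk returns a Dataset when the saved artifact was a single Dataset, and a DatasetDict when it was saved with named splits. Annotating the result as "DatasetDict" and indexing it by split name therefore breaks on single-Dataset caches, which is what the isinstance branch in this change handles.

from datasets import Dataset, DatasetDict, load_from_disk

# A Dataset saved on its own loads back as a Dataset ...
single = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})
single.save_to_disk("/tmp/tokenized_single")  # assumed scratch path
print(type(load_from_disk("/tmp/tokenized_single")).__name__)  # Dataset

# ... while a DatasetDict with named splits loads back as a DatasetDict.
splits = DatasetDict({"train": single, "validation": single})
splits.save_to_disk("/tmp/tokenized_splits")  # assumed scratch path
print(type(load_from_disk("/tmp/tokenized_splits")).__name__)  # DatasetDict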
@@ -239,15 +239,19 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
-            if "train" in dataset_dict:
-                dataset_module["train_dataset"] = dataset_dict["train"]
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]

-            if "validation" in dataset_dict:
-                dataset_module["eval_dataset"] = dataset_dict["validation"]
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]
+
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data

             if data_args.streaming:
                 dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
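The streaming context lines at the end of the hunk convert each map-style Dataset into an IterableDataset. A small standalone sketch of that step (the dictionary contents and the streaming flag are assumptions standing in for dataset_module and data_args.streaming):

from datasets import Dataset

dataset_module = {"train_dataset": Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})}
streaming = True  # stands in for data_args.streaming

if streaming:
    # yield examples lazily instead of exposing the map-style Arrow-backed table
    dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}

print(next(iter(dataset_module["train_dataset"])))  # {'input_ids': [1, 2]}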