diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index fbfbba49..92c12ac1 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -21,7 +21,7 @@ dataloader_args: - 9 - 10 dataset_args: - dataset_root: ${oc.env:CO3D_DATASET_ROOT}" + dataset_root: ${oc.env:CO3D_DATASET_ROOT} load_point_clouds: false mask_depths: false mask_images: false diff --git a/pytorch3d/implicitron/dataset/dataset_zoo.py b/pytorch3d/implicitron/dataset/dataset_zoo.py index 79a312f8..a32e0a86 100644 --- a/pytorch3d/implicitron/dataset/dataset_zoo.py +++ b/pytorch3d/implicitron/dataset/dataset_zoo.py @@ -146,7 +146,6 @@ def dataset_zoo( "load_point_clouds": load_point_clouds, "mask_images": mask_images, "mask_depths": mask_depths, - "pick_sequence": restrict_sequence_name, "path_manager": path_manager, "frame_annotations_file": frame_file, "sequence_annotations_file": sequence_file, @@ -174,7 +173,10 @@ def dataset_zoo( if not os.path.isfile(batch_indices_path): # The batch indices file does not exist. # Most probably the user has not specified the root folder. - raise ValueError("Please specify a correct dataset_root folder.") + raise ValueError( + f"Looking for batch indices in {batch_indices_path}. " + + "Please specify a correct dataset_root folder." + ) with open(batch_indices_path, "r") as f: eval_batch_index = json.load(f) @@ -208,6 +210,7 @@ def dataset_zoo( train_dataset = ImplicitronDataset( n_frames_per_sequence=n_frames_per_sequence, subsets=set_names_mapping["train"], + pick_sequence=restrict_sequence_name, **common_kwargs, ) if test_on_train: @@ -215,13 +218,15 @@ def dataset_zoo( val_dataset = test_dataset = train_dataset else: val_dataset = ImplicitronDataset( - n_frames_per_sequence=1, + n_frames_per_sequence=-1, subsets=set_names_mapping["val"], + pick_sequence=restrict_sequence_name, **common_kwargs, ) test_dataset = ImplicitronDataset( - n_frames_per_sequence=1, + n_frames_per_sequence=-1, subsets=set_names_mapping["test"], + pick_sequence=restrict_sequence_name, **common_kwargs, ) if len(restrict_sequence_name) > 0: