diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 802d04e3..283ef3dc 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -237,7 +237,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): raise NotImplementedError() def get_frame_numbers_and_timestamps( - self, idxs: Sequence[int] + self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None ) -> List[Tuple[int, float]]: """ If the sequences in the dataset are videos rather than @@ -251,7 +251,9 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): frames. Args: - idx: frame index in self + idxs: frame index in self + subset_filter: If given, an index in idxs is ignored if the + corresponding frame is not in any of the named subsets. Returns: tuple of @@ -291,7 +293,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): return dict(c2seq) def sequence_frames_in_order( - self, seq_name: str + self, seq_name: str, subset_filter: Optional[Sequence[str]] = None ) -> Iterator[Tuple[float, int, int]]: """Returns an iterator over the frame indices in a given sequence. We attempt to first sort by timestamp (if they are available), @@ -308,7 +310,9 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): """ # pyre-ignore[16] seq_frame_indices = self._seq_to_idx[seq_name] - nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) + nos_timestamps = self.get_frame_numbers_and_timestamps( + seq_frame_indices, subset_filter + ) yield from sorted( [ @@ -317,11 +321,13 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): ] ) - def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]: + def sequence_indices_in_order( + self, seq_name: str, subset_filter: Optional[Sequence[str]] = None + ) -> Iterator[int]: """Same as `sequence_frames_in_order` but returns the iterator over only dataset indices. """ - for _, _, idx in self.sequence_frames_in_order(seq_name): + for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter): yield idx # frame_data_type is the actual type of frames returned by the dataset. diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 2fdab768..efa6bfba 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -888,10 +888,16 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): return self.path_manager.get_local_path(path) def get_frame_numbers_and_timestamps( - self, idxs: Sequence[int] + self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None ) -> List[Tuple[int, float]]: out: List[Tuple[int, float]] = [] for idx in idxs: + if ( + subset_filter is not None + and self.frame_annots[idx]["subset"] not in subset_filter + ): + continue + # pyre-ignore[16] frame_annotation = self.frame_annots[idx]["frame_annotation"] out.append( diff --git a/tests/implicitron/test_data_json_index.py b/tests/implicitron/test_data_json_index.py index 0cd77c82..e11e7b44 100644 --- a/tests/implicitron/test_data_json_index.py +++ b/tests/implicitron/test_data_json_index.py @@ -40,3 +40,41 @@ class TestDataJsonIndex(TestCaseMixin, unittest.TestCase): self.assertEqual(len(data_sets.train), 81) self.assertEqual(len(data_sets.val), 102) self.assertEqual(len(data_sets.test), 102) + + def test_visitor_subsets(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider" + dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args + dataset_args.category = "skateboard" + dataset_args.dataset_root = "manifold://co3d/tree/extracted" + dataset_args.test_restrict_sequence_id = 0 + dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1 + + data_source = ImplicitronDataSource(**args) + datasets, _ = data_source.get_datasets_and_dataloaders() + dataset = datasets.test + + sequences = list(dataset.sequence_names()) + self.assertEqual(len(sequences), 1) + i = 0 + for seq in sequences: + last_ts = float("-Inf") + seq_frames = list(dataset.sequence_frames_in_order(seq)) + self.assertEqual(len(seq_frames), 102) + for ts, _, idx in seq_frames: + self.assertEqual(i, idx) + i += 1 + self.assertGreaterEqual(ts, last_ts) + last_ts = ts + + last_ts = float("-Inf") + known_frames = list(dataset.sequence_frames_in_order(seq, "test_known")) + self.assertEqual(len(known_frames), 81) + for ts, _, _ in known_frames: + self.assertGreaterEqual(ts, last_ts) + last_ts = ts + + known_indices = list(dataset.sequence_indices_in_order(seq, "test_known")) + self.assertEqual(len(known_indices), 81) + + break # testing only the first sequence