mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Subsets in dataset iterators
Summary: For the new API, filtering iterators over sequences by subsets is quite helpful. The change is backwards compatible. Reviewed By: bottler Differential Revision: D42739669 fbshipit-source-id: d150a404aeaf42fd04a81304c63a4cba203f897d
This commit is contained in:
		
							parent
							
								
									54eb76d48c
								
							
						
					
					
						commit
						11959e0b24
					
				@ -237,7 +237,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_frame_numbers_and_timestamps(
 | 
			
		||||
        self, idxs: Sequence[int]
 | 
			
		||||
        self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
 | 
			
		||||
    ) -> List[Tuple[int, float]]:
 | 
			
		||||
        """
 | 
			
		||||
        If the sequences in the dataset are videos rather than
 | 
			
		||||
@ -251,7 +251,9 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
 | 
			
		||||
        frames.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            idx: frame index in self
 | 
			
		||||
            idxs: frame index in self
 | 
			
		||||
            subset_filter: If given, an index in idxs is ignored if the
 | 
			
		||||
                corresponding frame is not in any of the named subsets.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tuple of
 | 
			
		||||
@ -291,7 +293,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
 | 
			
		||||
        return dict(c2seq)
 | 
			
		||||
 | 
			
		||||
    def sequence_frames_in_order(
 | 
			
		||||
        self, seq_name: str
 | 
			
		||||
        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
 | 
			
		||||
    ) -> Iterator[Tuple[float, int, int]]:
 | 
			
		||||
        """Returns an iterator over the frame indices in a given sequence.
 | 
			
		||||
        We attempt to first sort by timestamp (if they are available),
 | 
			
		||||
@ -308,7 +310,9 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
 | 
			
		||||
        """
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        seq_frame_indices = self._seq_to_idx[seq_name]
 | 
			
		||||
        nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
 | 
			
		||||
        nos_timestamps = self.get_frame_numbers_and_timestamps(
 | 
			
		||||
            seq_frame_indices, subset_filter
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        yield from sorted(
 | 
			
		||||
            [
 | 
			
		||||
@ -317,11 +321,13 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
 | 
			
		||||
    def sequence_indices_in_order(
 | 
			
		||||
        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
 | 
			
		||||
    ) -> Iterator[int]:
 | 
			
		||||
        """Same as `sequence_frames_in_order` but returns the iterator over
 | 
			
		||||
        only dataset indices.
 | 
			
		||||
        """
 | 
			
		||||
        for _, _, idx in self.sequence_frames_in_order(seq_name):
 | 
			
		||||
        for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
 | 
			
		||||
            yield idx
 | 
			
		||||
 | 
			
		||||
    # frame_data_type is the actual type of frames returned by the dataset.
 | 
			
		||||
 | 
			
		||||
@ -888,10 +888,16 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
        return self.path_manager.get_local_path(path)
 | 
			
		||||
 | 
			
		||||
    def get_frame_numbers_and_timestamps(
 | 
			
		||||
        self, idxs: Sequence[int]
 | 
			
		||||
        self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
 | 
			
		||||
    ) -> List[Tuple[int, float]]:
 | 
			
		||||
        out: List[Tuple[int, float]] = []
 | 
			
		||||
        for idx in idxs:
 | 
			
		||||
            if (
 | 
			
		||||
                subset_filter is not None
 | 
			
		||||
                and self.frame_annots[idx]["subset"] not in subset_filter
 | 
			
		||||
            ):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            # pyre-ignore[16]
 | 
			
		||||
            frame_annotation = self.frame_annots[idx]["frame_annotation"]
 | 
			
		||||
            out.append(
 | 
			
		||||
 | 
			
		||||
@ -40,3 +40,41 @@ class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertEqual(len(data_sets.train), 81)
 | 
			
		||||
        self.assertEqual(len(data_sets.val), 102)
 | 
			
		||||
        self.assertEqual(len(data_sets.test), 102)
 | 
			
		||||
 | 
			
		||||
    def test_visitor_subsets(self):
 | 
			
		||||
        args = get_default_args(ImplicitronDataSource)
 | 
			
		||||
        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
 | 
			
		||||
        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
 | 
			
		||||
        dataset_args.category = "skateboard"
 | 
			
		||||
        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
 | 
			
		||||
        dataset_args.test_restrict_sequence_id = 0
 | 
			
		||||
        dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1
 | 
			
		||||
 | 
			
		||||
        data_source = ImplicitronDataSource(**args)
 | 
			
		||||
        datasets, _ = data_source.get_datasets_and_dataloaders()
 | 
			
		||||
        dataset = datasets.test
 | 
			
		||||
 | 
			
		||||
        sequences = list(dataset.sequence_names())
 | 
			
		||||
        self.assertEqual(len(sequences), 1)
 | 
			
		||||
        i = 0
 | 
			
		||||
        for seq in sequences:
 | 
			
		||||
            last_ts = float("-Inf")
 | 
			
		||||
            seq_frames = list(dataset.sequence_frames_in_order(seq))
 | 
			
		||||
            self.assertEqual(len(seq_frames), 102)
 | 
			
		||||
            for ts, _, idx in seq_frames:
 | 
			
		||||
                self.assertEqual(i, idx)
 | 
			
		||||
                i += 1
 | 
			
		||||
                self.assertGreaterEqual(ts, last_ts)
 | 
			
		||||
                last_ts = ts
 | 
			
		||||
 | 
			
		||||
            last_ts = float("-Inf")
 | 
			
		||||
            known_frames = list(dataset.sequence_frames_in_order(seq, "test_known"))
 | 
			
		||||
            self.assertEqual(len(known_frames), 81)
 | 
			
		||||
            for ts, _, _ in known_frames:
 | 
			
		||||
                self.assertGreaterEqual(ts, last_ts)
 | 
			
		||||
                last_ts = ts
 | 
			
		||||
 | 
			
		||||
            known_indices = list(dataset.sequence_indices_in_order(seq, "test_known"))
 | 
			
		||||
            self.assertEqual(len(known_indices), 81)
 | 
			
		||||
 | 
			
		||||
            break  # testing only the first sequence
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user