From c0bb49b5f6f87516e694ca1aed9acc07beb2366c Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Thu, 24 Mar 2022 05:33:25 -0700 Subject: [PATCH] API for accessing frames in order in Implicitron dataset. Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility. Reviewed By: davnov134 Differential Revision: D35012121 fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1 --- .../dataset/implicitron_dataset.py | 45 +++++++++++++++++-- .../dataset/scene_batch_sampler.py | 45 +++++-------------- tests/implicitron/test_batch_sampler.py | 3 +- 3 files changed, 56 insertions(+), 37 deletions(-) diff --git a/pytorch3d/implicitron/dataset/implicitron_dataset.py b/pytorch3d/implicitron/dataset/implicitron_dataset.py index c397ff6f..47f4e598 100644 --- a/pytorch3d/implicitron/dataset/implicitron_dataset.py +++ b/pytorch3d/implicitron/dataset/implicitron_dataset.py @@ -18,6 +18,8 @@ from pathlib import Path from typing import ( ClassVar, Dict, + Iterable, + Iterator, List, Optional, Sequence, @@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): This means they have a __getitem__ which returns an instance of a FrameData, which will describe one frame in one sequence. - - Members: - seq_to_idx: For each sequence, the indices of its frames. """ + # Maps sequence name to the sequence's global frame indices. + # It is used for the default implementations of some functions in this class. + # Implementations which override them are free to ignore this member. seq_to_idx: Dict[str, List[int]] = field(init=False) def __len__(self) -> int: @@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): def get_eval_batches(self) -> Optional[List[List[int]]]: return None + def sequence_names(self) -> Iterable[str]: + """Returns an iterator over sequence names in the dataset.""" + return self.seq_to_idx.keys() + + def sequence_frames_in_order( + self, seq_name: str + ) -> Iterator[Tuple[float, int, int]]: + """Returns an iterator over the frame indices in a given sequence. + We attempt to first sort by timestamp (if they are available), + then by frame number. + + Args: + seq_name: the name of the sequence. + + Returns: + an iterator over triplets `(timestamp, frame_no, dataset_idx)`, + where `frame_no` is the index within the sequence, and + `dataset_idx` is the index within the dataset. + `None` timestamps are replaced with 0s. + """ + seq_frame_indices = self.seq_to_idx[seq_name] + nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) + + yield from sorted( + [ + (timestamp, frame_no, idx) + for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps) + ] + ) + + def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]: + """Same as `sequence_frames_in_order` but returns the iterator over + only dataset indices. + """ + for _, _, idx in self.sequence_frames_in_order(seq_name): + yield idx + class FrameAnnotsEntry(TypedDict): subset: Optional[str] diff --git a/pytorch3d/implicitron/dataset/scene_batch_sampler.py b/pytorch3d/implicitron/dataset/scene_batch_sampler.py index 29588079..fea6f19d 100644 --- a/pytorch3d/implicitron/dataset/scene_batch_sampler.py +++ b/pytorch3d/implicitron/dataset/scene_batch_sampler.py @@ -7,7 +7,7 @@ import warnings from dataclasses import dataclass, field -from typing import Iterator, List, Sequence, Tuple +from typing import Iterable, Iterator, List, Sequence, Tuple import numpy as np from torch.utils.data.sampler import Sampler @@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]): if len(self.images_per_seq_options) < 1: raise ValueError("n_per_seq_posibilities list cannot be empty") - self.seq_names = list(self.dataset.seq_to_idx.keys()) + self.seq_names = list(self.dataset.sequence_names()) def __len__(self) -> int: return self.num_batches @@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]): if self.sample_consecutive_frames: frame_idx = [] for seq in chosen_seq: - segment_index = self._build_segment_index( - list(self.dataset.seq_to_idx[seq]), n_per_seq - ) + segment_index = self._build_segment_index(seq, n_per_seq) segment, idx = segment_index[np.random.randint(len(segment_index))] if len(segment) <= n_per_seq: @@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]): else: frame_idx = [ _capped_random_choice( - self.dataset.seq_to_idx[seq], n_per_seq, replace=False + list(self.dataset.sequence_indices_in_order(seq)), + n_per_seq, + replace=False, ) for seq in chosen_seq ] @@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]): ) return frame_idx - def _build_segment_index( - self, seq_frame_indices: List[int], size: int - ) -> List[Tuple[List[int], int]]: + def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]: """ Returns a list of (segment, index) tuples, one per eligible frame, where segment is a list of frame indices in the contiguous segment the frame @@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]): self.consecutive_frames_max_gap > 0 or self.consecutive_frames_max_gap_seconds > 0.0 ): - sequence_timestamps = _sort_frames_by_timestamps_then_numbers( - seq_frame_indices, self.dataset + segments = self._split_to_segments( + self.dataset.sequence_frames_in_order(seq) ) - # TODO: use new API to access frame numbers / timestamps - segments = self._split_to_segments(sequence_timestamps) segments = _cull_short_segments(segments, size) if not segments: raise AssertionError("Empty segments after culling") else: - segments = [seq_frame_indices] + segments = [list(self.dataset.sequence_indices_in_order(seq))] # build an index of segment for random selection of a pivot frame segment_index = [ @@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]): return segment_index def _split_to_segments( - self, sequence_timestamps: List[Tuple[float, int, int]] + self, sequence_timestamps: Iterable[Tuple[float, int, int]] ) -> List[List[int]]: if ( self.consecutive_frames_max_gap <= 0 @@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]): for ts, no, idx in sequence_timestamps: if ts <= 0.0 and no <= last_no: raise AssertionError( - "Frames are not ordered in seq_to_idx while timestamps are not given" + "Sequence frames are not ordered while timestamps are not given" ) if ( @@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]): return segments -def _sort_frames_by_timestamps_then_numbers( - seq_frame_indices: List[int], dataset: ImplicitronDatasetBase -) -> List[Tuple[float, int, int]]: - """Build the list of triplets (timestamp, frame_no, dataset_idx). - We attempt to first sort by timestamp, then by frame number. - Timestamps are coalesced with 0s. - """ - nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices) - - return sorted( - [ - (timestamp, frame_no, idx) - for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps) - ] - ) - - def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]: lengths = [(len(segment), segment) for segment in segments] max_len, longest_segment = max(lengths) diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index a2ae074a..cdfc6cad 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -9,6 +9,7 @@ import unittest from collections import defaultdict from dataclasses import dataclass +from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler @@ -18,7 +19,7 @@ class MockFrameAnnotation: frame_timestamp: float = 0.0 -class MockDataset: +class MockDataset(ImplicitronDatasetBase): def __init__(self, num_seq, max_frame_gap=1): """ Makes a gap of max_frame_gap frame numbers in the middle of each sequence