API for accessing frames in order in Implicitron dataset.

Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility.

Reviewed By: davnov134

Differential Revision: D35012121

fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
This commit is contained in:
Roman Shapovalov 2022-03-24 05:33:25 -07:00 committed by Facebook GitHub Bot
parent 05f656c01f
commit c0bb49b5f6
3 changed files with 56 additions and 37 deletions

View File

@ -18,6 +18,8 @@ from pathlib import Path
from typing import (
ClassVar,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
This means they have a __getitem__ which returns an instance of a FrameData,
which will describe one frame in one sequence.
Members:
seq_to_idx: For each sequence, the indices of its frames.
"""
# Maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member.
seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int:
@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
def get_eval_batches(self) -> Optional[List[List[int]]]:
return None
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
return self.seq_to_idx.keys()
def sequence_frames_in_order(
self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
"""Returns an iterator over the frame indices in a given sequence.
We attempt to first sort by timestamp (if they are available),
then by frame number.
Args:
seq_name: the name of the sequence.
Returns:
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
where `frame_no` is the index within the sequence, and
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
seq_frame_indices = self.seq_to_idx[seq_name]
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
yield from sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
"""Same as `sequence_frames_in_order` but returns the iterator over
only dataset indices.
"""
for _, _, idx in self.sequence_frames_in_order(seq_name):
yield idx
class FrameAnnotsEntry(TypedDict):
subset: Optional[str]

View File

@ -7,7 +7,7 @@
import warnings
from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple
from typing import Iterable, Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys())
self.seq_names = list(self.dataset.sequence_names())
def __len__(self) -> int:
return self.num_batches
@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if self.sample_consecutive_frames:
frame_idx = []
for seq in chosen_seq:
segment_index = self._build_segment_index(
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment_index = self._build_segment_index(seq, n_per_seq)
segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq:
@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]):
else:
frame_idx = [
_capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
list(self.dataset.sequence_indices_in_order(seq)),
n_per_seq,
replace=False,
)
for seq in chosen_seq
]
@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]):
)
return frame_idx
def _build_segment_index(
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
"""
Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame
@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]):
self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0
):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
seq_frame_indices, self.dataset
segments = self._split_to_segments(
self.dataset.sequence_frames_in_order(seq)
)
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size)
if not segments:
raise AssertionError("Empty segments after culling")
else:
segments = [seq_frame_indices]
segments = [list(self.dataset.sequence_indices_in_order(seq))]
# build an index of segment for random selection of a pivot frame
segment_index = [
@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]):
return segment_index
def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]]
self, sequence_timestamps: Iterable[Tuple[float, int, int]]
) -> List[List[int]]:
if (
self.consecutive_frames_max_gap <= 0
@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]):
for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no:
raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given"
"Sequence frames are not ordered while timestamps are not given"
)
if (
@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]):
return segments
def _sort_frames_by_timestamps_then_numbers(
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
We attempt to first sort by timestamp, then by frame number.
Timestamps are coalesced with 0s.
"""
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
return sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths)

View File

@ -9,6 +9,7 @@ import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@ -18,7 +19,7 @@ class MockFrameAnnotation:
frame_timestamp: float = 0.0
class MockDataset:
class MockDataset(ImplicitronDatasetBase):
def __init__(self, num_seq, max_frame_gap=1):
"""
Makes a gap of max_frame_gap frame numbers in the middle of each sequence