API for accessing frames in order in Implicitron dataset.

Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility.

Reviewed By: davnov134

Differential Revision: D35012121

fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
This commit is contained in:
Roman Shapovalov 2022-03-24 05:33:25 -07:00 committed by Facebook GitHub Bot
parent 05f656c01f
commit c0bb49b5f6
3 changed files with 56 additions and 37 deletions

View File

@ -18,6 +18,8 @@ from pathlib import Path
from typing import ( from typing import (
ClassVar, ClassVar,
Dict, Dict,
Iterable,
Iterator,
List, List,
Optional, Optional,
Sequence, Sequence,
@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
This means they have a __getitem__ which returns an instance of a FrameData, This means they have a __getitem__ which returns an instance of a FrameData,
which will describe one frame in one sequence. which will describe one frame in one sequence.
Members:
seq_to_idx: For each sequence, the indices of its frames.
""" """
# Maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member.
seq_to_idx: Dict[str, List[int]] = field(init=False) seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int: def __len__(self) -> int:
@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
def get_eval_batches(self) -> Optional[List[List[int]]]: def get_eval_batches(self) -> Optional[List[List[int]]]:
return None return None
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
return self.seq_to_idx.keys()
def sequence_frames_in_order(
self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
"""Returns an iterator over the frame indices in a given sequence.
We attempt to first sort by timestamp (if they are available),
then by frame number.
Args:
seq_name: the name of the sequence.
Returns:
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
where `frame_no` is the index within the sequence, and
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
seq_frame_indices = self.seq_to_idx[seq_name]
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
yield from sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
"""Same as `sequence_frames_in_order` but returns the iterator over
only dataset indices.
"""
for _, _, idx in self.sequence_frames_in_order(seq_name):
yield idx
class FrameAnnotsEntry(TypedDict): class FrameAnnotsEntry(TypedDict):
subset: Optional[str] subset: Optional[str]

View File

@ -7,7 +7,7 @@
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple from typing import Iterable, Iterator, List, Sequence, Tuple
import numpy as np import numpy as np
from torch.utils.data.sampler import Sampler from torch.utils.data.sampler import Sampler
@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if len(self.images_per_seq_options) < 1: if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty") raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys()) self.seq_names = list(self.dataset.sequence_names())
def __len__(self) -> int: def __len__(self) -> int:
return self.num_batches return self.num_batches
@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]):
if self.sample_consecutive_frames: if self.sample_consecutive_frames:
frame_idx = [] frame_idx = []
for seq in chosen_seq: for seq in chosen_seq:
segment_index = self._build_segment_index( segment_index = self._build_segment_index(seq, n_per_seq)
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment, idx = segment_index[np.random.randint(len(segment_index))] segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq: if len(segment) <= n_per_seq:
@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]):
else: else:
frame_idx = [ frame_idx = [
_capped_random_choice( _capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False list(self.dataset.sequence_indices_in_order(seq)),
n_per_seq,
replace=False,
) )
for seq in chosen_seq for seq in chosen_seq
] ]
@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]):
) )
return frame_idx return frame_idx
def _build_segment_index( def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
""" """
Returns a list of (segment, index) tuples, one per eligible frame, where Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame segment is a list of frame indices in the contiguous segment the frame
@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]):
self.consecutive_frames_max_gap > 0 self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0 or self.consecutive_frames_max_gap_seconds > 0.0
): ):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers( segments = self._split_to_segments(
seq_frame_indices, self.dataset self.dataset.sequence_frames_in_order(seq)
) )
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size) segments = _cull_short_segments(segments, size)
if not segments: if not segments:
raise AssertionError("Empty segments after culling") raise AssertionError("Empty segments after culling")
else: else:
segments = [seq_frame_indices] segments = [list(self.dataset.sequence_indices_in_order(seq))]
# build an index of segment for random selection of a pivot frame # build an index of segment for random selection of a pivot frame
segment_index = [ segment_index = [
@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]):
return segment_index return segment_index
def _split_to_segments( def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]] self, sequence_timestamps: Iterable[Tuple[float, int, int]]
) -> List[List[int]]: ) -> List[List[int]]:
if ( if (
self.consecutive_frames_max_gap <= 0 self.consecutive_frames_max_gap <= 0
@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]):
for ts, no, idx in sequence_timestamps: for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no: if ts <= 0.0 and no <= last_no:
raise AssertionError( raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given" "Sequence frames are not ordered while timestamps are not given"
) )
if ( if (
@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]):
return segments return segments
def _sort_frames_by_timestamps_then_numbers(
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
We attempt to first sort by timestamp, then by frame number.
Timestamps are coalesced with 0s.
"""
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
return sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]: def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments] lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths) max_len, longest_segment = max(lengths)

View File

@ -9,6 +9,7 @@ import unittest
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@ -18,7 +19,7 @@ class MockFrameAnnotation:
frame_timestamp: float = 0.0 frame_timestamp: float = 0.0
class MockDataset: class MockDataset(ImplicitronDatasetBase):
def __init__(self, num_seq, max_frame_gap=1): def __init__(self, num_seq, max_frame_gap=1):
""" """
Makes a gap of max_frame_gap frame numbers in the middle of each sequence Makes a gap of max_frame_gap frame numbers in the middle of each sequence