mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
API for accessing frames in order in Implicitron dataset.
Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility. Reviewed By: davnov134 Differential Revision: D35012121 fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
This commit is contained in:
parent
05f656c01f
commit
c0bb49b5f6
@ -18,6 +18,8 @@ from pathlib import Path
|
||||
from typing import (
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
|
||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||
which will describe one frame in one sequence.
|
||||
|
||||
Members:
|
||||
seq_to_idx: For each sequence, the indices of its frames.
|
||||
"""
|
||||
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||
return None
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
return self.seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
) -> Iterator[Tuple[float, int, int]]:
|
||||
"""Returns an iterator over the frame indices in a given sequence.
|
||||
We attempt to first sort by timestamp (if they are available),
|
||||
then by frame number.
|
||||
|
||||
Args:
|
||||
seq_name: the name of the sequence.
|
||||
|
||||
Returns:
|
||||
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||
where `frame_no` is the index within the sequence, and
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
seq_frame_indices = self.seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
yield from sorted(
|
||||
[
|
||||
(timestamp, frame_no, idx)
|
||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||
]
|
||||
)
|
||||
|
||||
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||
only dataset indices.
|
||||
"""
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||
yield idx
|
||||
|
||||
|
||||
class FrameAnnotsEntry(TypedDict):
|
||||
subset: Optional[str]
|
||||
|
@ -7,7 +7,7 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator, List, Sequence, Tuple
|
||||
from typing import Iterable, Iterator, List, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
if len(self.images_per_seq_options) < 1:
|
||||
raise ValueError("n_per_seq_posibilities list cannot be empty")
|
||||
|
||||
self.seq_names = list(self.dataset.seq_to_idx.keys())
|
||||
self.seq_names = list(self.dataset.sequence_names())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_batches
|
||||
@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
if self.sample_consecutive_frames:
|
||||
frame_idx = []
|
||||
for seq in chosen_seq:
|
||||
segment_index = self._build_segment_index(
|
||||
list(self.dataset.seq_to_idx[seq]), n_per_seq
|
||||
)
|
||||
segment_index = self._build_segment_index(seq, n_per_seq)
|
||||
|
||||
segment, idx = segment_index[np.random.randint(len(segment_index))]
|
||||
if len(segment) <= n_per_seq:
|
||||
@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
else:
|
||||
frame_idx = [
|
||||
_capped_random_choice(
|
||||
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
|
||||
list(self.dataset.sequence_indices_in_order(seq)),
|
||||
n_per_seq,
|
||||
replace=False,
|
||||
)
|
||||
for seq in chosen_seq
|
||||
]
|
||||
@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
)
|
||||
return frame_idx
|
||||
|
||||
def _build_segment_index(
|
||||
self, seq_frame_indices: List[int], size: int
|
||||
) -> List[Tuple[List[int], int]]:
|
||||
def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
|
||||
"""
|
||||
Returns a list of (segment, index) tuples, one per eligible frame, where
|
||||
segment is a list of frame indices in the contiguous segment the frame
|
||||
@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
self.consecutive_frames_max_gap > 0
|
||||
or self.consecutive_frames_max_gap_seconds > 0.0
|
||||
):
|
||||
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
|
||||
seq_frame_indices, self.dataset
|
||||
segments = self._split_to_segments(
|
||||
self.dataset.sequence_frames_in_order(seq)
|
||||
)
|
||||
# TODO: use new API to access frame numbers / timestamps
|
||||
segments = self._split_to_segments(sequence_timestamps)
|
||||
segments = _cull_short_segments(segments, size)
|
||||
if not segments:
|
||||
raise AssertionError("Empty segments after culling")
|
||||
else:
|
||||
segments = [seq_frame_indices]
|
||||
segments = [list(self.dataset.sequence_indices_in_order(seq))]
|
||||
|
||||
# build an index of segment for random selection of a pivot frame
|
||||
segment_index = [
|
||||
@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
return segment_index
|
||||
|
||||
def _split_to_segments(
|
||||
self, sequence_timestamps: List[Tuple[float, int, int]]
|
||||
self, sequence_timestamps: Iterable[Tuple[float, int, int]]
|
||||
) -> List[List[int]]:
|
||||
if (
|
||||
self.consecutive_frames_max_gap <= 0
|
||||
@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
for ts, no, idx in sequence_timestamps:
|
||||
if ts <= 0.0 and no <= last_no:
|
||||
raise AssertionError(
|
||||
"Frames are not ordered in seq_to_idx while timestamps are not given"
|
||||
"Sequence frames are not ordered while timestamps are not given"
|
||||
)
|
||||
|
||||
if (
|
||||
@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
return segments
|
||||
|
||||
|
||||
def _sort_frames_by_timestamps_then_numbers(
|
||||
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
|
||||
) -> List[Tuple[float, int, int]]:
|
||||
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
|
||||
We attempt to first sort by timestamp, then by frame number.
|
||||
Timestamps are coalesced with 0s.
|
||||
"""
|
||||
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
return sorted(
|
||||
[
|
||||
(timestamp, frame_no, idx)
|
||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
|
||||
lengths = [(len(segment), segment) for segment in segments]
|
||||
max_len, longest_segment = max(lengths)
|
||||
|
@ -9,6 +9,7 @@ import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@ -18,7 +19,7 @@ class MockFrameAnnotation:
|
||||
frame_timestamp: float = 0.0
|
||||
|
||||
|
||||
class MockDataset:
|
||||
class MockDataset(ImplicitronDatasetBase):
|
||||
def __init__(self, num_seq, max_frame_gap=1):
|
||||
"""
|
||||
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
||||
|
Loading…
x
Reference in New Issue
Block a user