mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
API for accessing frames in order in Implicitron dataset.
Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility. Reviewed By: davnov134 Differential Revision: D35012121 fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
This commit is contained in:
parent
05f656c01f
commit
c0bb49b5f6
@ -18,6 +18,8 @@ from pathlib import Path
|
|||||||
from typing import (
|
from typing import (
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
|||||||
|
|
||||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||||
which will describe one frame in one sequence.
|
which will describe one frame in one sequence.
|
||||||
|
|
||||||
Members:
|
|
||||||
seq_to_idx: For each sequence, the indices of its frames.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Maps sequence name to the sequence's global frame indices.
|
||||||
|
# It is used for the default implementations of some functions in this class.
|
||||||
|
# Implementations which override them are free to ignore this member.
|
||||||
seq_to_idx: Dict[str, List[int]] = field(init=False)
|
seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@ -240,6 +242,43 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
|||||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def sequence_names(self) -> Iterable[str]:
|
||||||
|
"""Returns an iterator over sequence names in the dataset."""
|
||||||
|
return self.seq_to_idx.keys()
|
||||||
|
|
||||||
|
def sequence_frames_in_order(
|
||||||
|
self, seq_name: str
|
||||||
|
) -> Iterator[Tuple[float, int, int]]:
|
||||||
|
"""Returns an iterator over the frame indices in a given sequence.
|
||||||
|
We attempt to first sort by timestamp (if they are available),
|
||||||
|
then by frame number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_name: the name of the sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||||
|
where `frame_no` is the index within the sequence, and
|
||||||
|
`dataset_idx` is the index within the dataset.
|
||||||
|
`None` timestamps are replaced with 0s.
|
||||||
|
"""
|
||||||
|
seq_frame_indices = self.seq_to_idx[seq_name]
|
||||||
|
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||||
|
|
||||||
|
yield from sorted(
|
||||||
|
[
|
||||||
|
(timestamp, frame_no, idx)
|
||||||
|
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||||
|
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||||
|
only dataset indices.
|
||||||
|
"""
|
||||||
|
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||||
|
yield idx
|
||||||
|
|
||||||
|
|
||||||
class FrameAnnotsEntry(TypedDict):
|
class FrameAnnotsEntry(TypedDict):
|
||||||
subset: Optional[str]
|
subset: Optional[str]
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Iterator, List, Sequence, Tuple
|
from typing import Iterable, Iterator, List, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
@ -54,7 +54,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
if len(self.images_per_seq_options) < 1:
|
if len(self.images_per_seq_options) < 1:
|
||||||
raise ValueError("n_per_seq_posibilities list cannot be empty")
|
raise ValueError("n_per_seq_posibilities list cannot be empty")
|
||||||
|
|
||||||
self.seq_names = list(self.dataset.seq_to_idx.keys())
|
self.seq_names = list(self.dataset.sequence_names())
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.num_batches
|
return self.num_batches
|
||||||
@ -72,9 +72,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
if self.sample_consecutive_frames:
|
if self.sample_consecutive_frames:
|
||||||
frame_idx = []
|
frame_idx = []
|
||||||
for seq in chosen_seq:
|
for seq in chosen_seq:
|
||||||
segment_index = self._build_segment_index(
|
segment_index = self._build_segment_index(seq, n_per_seq)
|
||||||
list(self.dataset.seq_to_idx[seq]), n_per_seq
|
|
||||||
)
|
|
||||||
|
|
||||||
segment, idx = segment_index[np.random.randint(len(segment_index))]
|
segment, idx = segment_index[np.random.randint(len(segment_index))]
|
||||||
if len(segment) <= n_per_seq:
|
if len(segment) <= n_per_seq:
|
||||||
@ -86,7 +84,9 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
else:
|
else:
|
||||||
frame_idx = [
|
frame_idx = [
|
||||||
_capped_random_choice(
|
_capped_random_choice(
|
||||||
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
|
list(self.dataset.sequence_indices_in_order(seq)),
|
||||||
|
n_per_seq,
|
||||||
|
replace=False,
|
||||||
)
|
)
|
||||||
for seq in chosen_seq
|
for seq in chosen_seq
|
||||||
]
|
]
|
||||||
@ -98,9 +98,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
)
|
)
|
||||||
return frame_idx
|
return frame_idx
|
||||||
|
|
||||||
def _build_segment_index(
|
def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
|
||||||
self, seq_frame_indices: List[int], size: int
|
|
||||||
) -> List[Tuple[List[int], int]]:
|
|
||||||
"""
|
"""
|
||||||
Returns a list of (segment, index) tuples, one per eligible frame, where
|
Returns a list of (segment, index) tuples, one per eligible frame, where
|
||||||
segment is a list of frame indices in the contiguous segment the frame
|
segment is a list of frame indices in the contiguous segment the frame
|
||||||
@ -111,16 +109,14 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
self.consecutive_frames_max_gap > 0
|
self.consecutive_frames_max_gap > 0
|
||||||
or self.consecutive_frames_max_gap_seconds > 0.0
|
or self.consecutive_frames_max_gap_seconds > 0.0
|
||||||
):
|
):
|
||||||
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
|
segments = self._split_to_segments(
|
||||||
seq_frame_indices, self.dataset
|
self.dataset.sequence_frames_in_order(seq)
|
||||||
)
|
)
|
||||||
# TODO: use new API to access frame numbers / timestamps
|
|
||||||
segments = self._split_to_segments(sequence_timestamps)
|
|
||||||
segments = _cull_short_segments(segments, size)
|
segments = _cull_short_segments(segments, size)
|
||||||
if not segments:
|
if not segments:
|
||||||
raise AssertionError("Empty segments after culling")
|
raise AssertionError("Empty segments after culling")
|
||||||
else:
|
else:
|
||||||
segments = [seq_frame_indices]
|
segments = [list(self.dataset.sequence_indices_in_order(seq))]
|
||||||
|
|
||||||
# build an index of segment for random selection of a pivot frame
|
# build an index of segment for random selection of a pivot frame
|
||||||
segment_index = [
|
segment_index = [
|
||||||
@ -130,7 +126,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
return segment_index
|
return segment_index
|
||||||
|
|
||||||
def _split_to_segments(
|
def _split_to_segments(
|
||||||
self, sequence_timestamps: List[Tuple[float, int, int]]
|
self, sequence_timestamps: Iterable[Tuple[float, int, int]]
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
if (
|
if (
|
||||||
self.consecutive_frames_max_gap <= 0
|
self.consecutive_frames_max_gap <= 0
|
||||||
@ -144,7 +140,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
for ts, no, idx in sequence_timestamps:
|
for ts, no, idx in sequence_timestamps:
|
||||||
if ts <= 0.0 and no <= last_no:
|
if ts <= 0.0 and no <= last_no:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"Frames are not ordered in seq_to_idx while timestamps are not given"
|
"Sequence frames are not ordered while timestamps are not given"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -161,23 +157,6 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
return segments
|
return segments
|
||||||
|
|
||||||
|
|
||||||
def _sort_frames_by_timestamps_then_numbers(
|
|
||||||
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
|
|
||||||
) -> List[Tuple[float, int, int]]:
|
|
||||||
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
|
|
||||||
We attempt to first sort by timestamp, then by frame number.
|
|
||||||
Timestamps are coalesced with 0s.
|
|
||||||
"""
|
|
||||||
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
|
|
||||||
|
|
||||||
return sorted(
|
|
||||||
[
|
|
||||||
(timestamp, frame_no, idx)
|
|
||||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
|
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
|
||||||
lengths = [(len(segment), segment) for segment in segments]
|
lengths = [(len(segment), segment) for segment in segments]
|
||||||
max_len, longest_segment = max(lengths)
|
max_len, longest_segment = max(lengths)
|
||||||
|
@ -9,6 +9,7 @@ import unittest
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
|
||||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||||
|
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ class MockFrameAnnotation:
|
|||||||
frame_timestamp: float = 0.0
|
frame_timestamp: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class MockDataset:
|
class MockDataset(ImplicitronDatasetBase):
|
||||||
def __init__(self, num_seq, max_frame_gap=1):
|
def __init__(self, num_seq, max_frame_gap=1):
|
||||||
"""
|
"""
|
||||||
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
||||||
|
Loading…
x
Reference in New Issue
Block a user