Using the new dataset idx API everywhere.

Summary: Using the API from D35012121 everywhere.

Reviewed By: bottler

Differential Revision: D35045870

fbshipit-source-id: dab112b5e04160334859bbe8fa2366344b6e0f70
This commit is contained in:
Roman Shapovalov 2022-03-24 05:33:25 -07:00 committed by Facebook GitHub Bot
parent c0bb49b5f6
commit e2622d79c0
5 changed files with 17 additions and 15 deletions

View File

@ -67,7 +67,7 @@ def render_sequence(
if seed is None: if seed is None:
seed = hash(sequence_name) seed = hash(sequence_name)
print(f"Loading all data of sequence '{sequence_name}'.") print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = dataset.seq_to_idx[sequence_name] seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers) train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name) assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test" sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
@ -345,7 +345,7 @@ def export_scenes(
dataset = dataset_zoo(**config.dataset_args)[split] dataset = dataset_zoo(**config.dataset_args)[split]
# iterate over the sequences in the dataset # iterate over the sequences in the dataset
for sequence_name in dataset.seq_to_idx.keys(): for sequence_name in dataset.sequence_names():
with torch.no_grad(): with torch.no_grad():
render_sequence( render_sequence(
dataset, dataset,

View File

@ -210,7 +210,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
# Maps sequence name to the sequence's global frame indices. # Maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class. # It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member. # Implementations which override them are free to ignore this member.
seq_to_idx: Dict[str, List[int]] = field(init=False) _seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int: def __len__(self) -> int:
raise NotImplementedError raise NotImplementedError
@ -223,7 +223,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
unordered views, then the dataset should override this method to unordered views, then the dataset should override this method to
return the index and timestamp in their videos of the frames whose return the index and timestamp in their videos of the frames whose
indices are given in `idxs`. In addition, indices are given in `idxs`. In addition,
the values in seq_to_idx should be in ascending order. the values in _seq_to_idx should be in ascending order.
If timestamps are absent, they should be replaced with a constant. If timestamps are absent, they should be replaced with a constant.
This is used for letting SceneBatchSampler identify consecutive This is used for letting SceneBatchSampler identify consecutive
@ -244,7 +244,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
def sequence_names(self) -> Iterable[str]: def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset.""" """Returns an iterator over sequence names in the dataset."""
return self.seq_to_idx.keys() return self._seq_to_idx.keys()
def sequence_frames_in_order( def sequence_frames_in_order(
self, seq_name: str self, seq_name: str
@ -262,7 +262,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
`dataset_idx` is the index within the dataset. `dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s. `None` timestamps are replaced with 0s.
""" """
seq_frame_indices = self.seq_to_idx[seq_name] seq_frame_indices = self._seq_to_idx[seq_name]
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
yield from sorted( yield from sorted(
@ -411,7 +411,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self.frame_annots[idx]["frame_annotation"].frame_number: idx self.frame_annots[idx]["frame_annotation"].frame_number: idx
for idx in seq_idx for idx in seq_idx
} }
for seq, seq_idx in self.seq_to_idx.items() for seq, seq_idx in self._seq_to_idx.items()
} }
def _get_batch_idx(seq_name, frame_no, path=None) -> int: def _get_batch_idx(seq_name, frame_no, path=None) -> int:
@ -804,7 +804,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
if self.n_frames_per_sequence > 0: if self.n_frames_per_sequence > 0:
print(f"Taking max {self.n_frames_per_sequence} per sequence.") print(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = [] keep_idx = []
for seq, seq_indices in self.seq_to_idx.items(): for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible # infer the seed from the sequence name, this is reproducible
# and makes the selection differ for different sequences # and makes the selection differ for different sequences
seed = _seq_name_to_seed(seq) + self.seed seed = _seq_name_to_seed(seq) + self.seed
@ -826,20 +826,20 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._invalidate_indexes(filter_seq_annots=True) self._invalidate_indexes(filter_seq_annots=True)
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
# update seq_to_idx and filter seq_meta according to frame_annots change # update _seq_to_idx and filter seq_meta according to frame_annots change
# if filter_seq_annots, also updates seq_annots based on the changed seq_to_idx # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
self._invalidate_seq_to_idx() self._invalidate_seq_to_idx()
if filter_seq_annots: if filter_seq_annots:
self.seq_annots = { self.seq_annots = {
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
} }
def _invalidate_seq_to_idx(self) -> None: def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list) seq_to_idx = defaultdict(list)
for idx, entry in enumerate(self.frame_annots): for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
self.seq_to_idx = seq_to_idx self._seq_to_idx = seq_to_idx
def _resize_image( def _resize_image(
self, image, mode="bilinear" self, image, mode="bilinear"

View File

@ -198,7 +198,7 @@ def _get_all_source_cameras(
""" """
# load all source cameras of the sequence # load all source cameras of the sequence
seq_idx = dataset.seq_to_idx[sequence_name] seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx) dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
(all_frame_data,) = torch.utils.data.DataLoader( (all_frame_data,) = torch.utils.data.DataLoader(
dataset_for_loader, dataset_for_loader,

View File

@ -25,7 +25,7 @@ class MockDataset(ImplicitronDatasetBase):
Makes a gap of max_frame_gap frame numbers in the middle of each sequence Makes a gap of max_frame_gap frame numbers in the middle of each sequence
""" """
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)} self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
self.seq_to_idx = { self._seq_to_idx = {
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq) f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
} }

View File

@ -8,6 +8,7 @@
import contextlib import contextlib
import copy import copy
import dataclasses import dataclasses
import itertools
import math import math
import os import os
import unittest import unittest
@ -285,6 +286,7 @@ class TestEvaluation(unittest.TestCase):
def test_full_eval(self, n_sequences=5): def test_full_eval(self, n_sequences=5):
"""Test evaluation.""" """Test evaluation."""
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]: for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
idx = list(self.dataset.sequence_indices_in_order(seq))
seq_dataset = torch.utils.data.Subset(self.dataset, idx) seq_dataset = torch.utils.data.Subset(self.dataset, idx)
self._one_sequence_test(seq_dataset) self._one_sequence_test(seq_dataset)