Using the new dataset idx API everywhere.

Summary: Using the API from D35012121 everywhere.

Reviewed By: bottler

Differential Revision: D35045870

fbshipit-source-id: dab112b5e04160334859bbe8fa2366344b6e0f70
This commit is contained in:
Roman Shapovalov 2022-03-24 05:33:25 -07:00 committed by Facebook GitHub Bot
parent c0bb49b5f6
commit e2622d79c0
5 changed files with 17 additions and 15 deletions

View File

@@ -67,7 +67,7 @@ def render_sequence(
if seed is None:
seed = hash(sequence_name)
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = dataset.seq_to_idx[sequence_name]
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
@@ -345,7 +345,7 @@ def export_scenes(
dataset = dataset_zoo(**config.dataset_args)[split]
# iterate over the sequences in the dataset
for sequence_name in dataset.seq_to_idx.keys():
for sequence_name in dataset.sequence_names():
with torch.no_grad():
render_sequence(
dataset,

View File

@@ -210,7 +210,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
# Maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member.
seq_to_idx: Dict[str, List[int]] = field(init=False)
_seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int:
raise NotImplementedError
@@ -223,7 +223,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
unordered views, then the dataset should override this method to
return the index and timestamp in their videos of the frames whose
indices are given in `idxs`. In addition,
the values in seq_to_idx should be in ascending order.
the values in _seq_to_idx should be in ascending order.
If timestamps are absent, they should be replaced with a constant.
This is used for letting SceneBatchSampler identify consecutive
@@ -244,7 +244,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
return self.seq_to_idx.keys()
return self._seq_to_idx.keys()
def sequence_frames_in_order(
self, seq_name: str
@@ -262,7 +262,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
seq_frame_indices = self.seq_to_idx[seq_name]
seq_frame_indices = self._seq_to_idx[seq_name]
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
yield from sorted(
@@ -411,7 +411,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self.frame_annots[idx]["frame_annotation"].frame_number: idx
for idx in seq_idx
}
for seq, seq_idx in self.seq_to_idx.items()
for seq, seq_idx in self._seq_to_idx.items()
}
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
@@ -804,7 +804,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
if self.n_frames_per_sequence > 0:
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
for seq, seq_indices in self.seq_to_idx.items():
for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible
# and makes the selection differ for different sequences
seed = _seq_name_to_seed(seq) + self.seed
@@ -826,20 +826,20 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._invalidate_indexes(filter_seq_annots=True)
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
# update seq_to_idx and filter seq_meta according to frame_annots change
# if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx
# update _seq_to_idx and filter seq_meta according to frame_annots change
# if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
self._invalidate_seq_to_idx()
if filter_seq_annots:
self.seq_annots = {
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
}
def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list)
for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
self.seq_to_idx = seq_to_idx
self._seq_to_idx = seq_to_idx
def _resize_image(
self, image, mode="bilinear"

View File

@@ -198,7 +198,7 @@ def _get_all_source_cameras(
"""
# load all source cameras of the sequence
seq_idx = dataset.seq_to_idx[sequence_name]
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
(all_frame_data,) = torch.utils.data.DataLoader(
dataset_for_loader,

View File

@@ -25,7 +25,7 @@ class MockDataset(ImplicitronDatasetBase):
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
"""
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
self.seq_to_idx = {
self._seq_to_idx = {
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
}

View File

@@ -8,6 +8,7 @@
import contextlib
import copy
import dataclasses
import itertools
import math
import os
import unittest
@@ -285,6 +286,7 @@ class TestEvaluation(unittest.TestCase):
def test_full_eval(self, n_sequences=5):
"""Test evaluation."""
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
idx = list(self.dataset.sequence_indices_in_order(seq))
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
self._one_sequence_test(seq_dataset)