From e2622d79c0f53d2596ecdf46405a5429170d8091 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Thu, 24 Mar 2022 05:33:25 -0700 Subject: [PATCH] Using the new dataset idx API everywhere. Summary: Using the API from D35012121 everywhere. Reviewed By: bottler Differential Revision: D35045870 fbshipit-source-id: dab112b5e04160334859bbe8fa2366344b6e0f70 --- .../visualize_reconstruction.py | 4 ++-- .../dataset/implicitron_dataset.py | 20 +++++++++---------- pytorch3d/implicitron/eval_demo.py | 2 +- tests/implicitron/test_batch_sampler.py | 2 +- tests/implicitron/test_evaluation.py | 4 +++- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index 28fa9727..51be395f 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -67,7 +67,7 @@ def render_sequence( if seed is None: seed = hash(sequence_name) print(f"Loading all data of sequence '{sequence_name}'.") - seq_idx = dataset.seq_to_idx[sequence_name] + seq_idx = list(dataset.sequence_indices_in_order(sequence_name)) train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers) assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name) sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test" @@ -345,7 +345,7 @@ def export_scenes( dataset = dataset_zoo(**config.dataset_args)[split] # iterate over the sequences in the dataset - for sequence_name in dataset.seq_to_idx.keys(): + for sequence_name in dataset.sequence_names(): with torch.no_grad(): render_sequence( dataset, diff --git a/pytorch3d/implicitron/dataset/implicitron_dataset.py b/pytorch3d/implicitron/dataset/implicitron_dataset.py index 47f4e598..88142b4e 100644 --- a/pytorch3d/implicitron/dataset/implicitron_dataset.py +++ b/pytorch3d/implicitron/dataset/implicitron_dataset.py @@ -210,7 +210,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): # Maps sequence name to the sequence's global frame indices. # It is used for the default implementations of some functions in this class. # Implementations which override them are free to ignore this member. - seq_to_idx: Dict[str, List[int]] = field(init=False) + _seq_to_idx: Dict[str, List[int]] = field(init=False) def __len__(self) -> int: raise NotImplementedError @@ -223,7 +223,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): unordered views, then the dataset should override this method to return the index and timestamp in their videos of the frames whose indices are given in `idxs`. In addition, - the values in seq_to_idx should be in ascending order. + the values in _seq_to_idx should be in ascending order. If timestamps are absent, they should be replaced with a constant. This is used for letting SceneBatchSampler identify consecutive @@ -244,7 +244,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): def sequence_names(self) -> Iterable[str]: """Returns an iterator over sequence names in the dataset.""" - return self.seq_to_idx.keys() + return self._seq_to_idx.keys() def sequence_frames_in_order( self, seq_name: str @@ -262,7 +262,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): `dataset_idx` is the index within the dataset. `None` timestamps are replaced with 0s. """ - seq_frame_indices = self.seq_to_idx[seq_name] + seq_frame_indices = self._seq_to_idx[seq_name] nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) yield from sorted( @@ -411,7 +411,7 @@ class ImplicitronDataset(ImplicitronDatasetBase): self.frame_annots[idx]["frame_annotation"].frame_number: idx for idx in seq_idx } - for seq, seq_idx in self.seq_to_idx.items() + for seq, seq_idx in self._seq_to_idx.items() } def _get_batch_idx(seq_name, frame_no, path=None) -> int: @@ -804,7 +804,7 @@ class ImplicitronDataset(ImplicitronDatasetBase): if self.n_frames_per_sequence > 0: print(f"Taking max {self.n_frames_per_sequence} per sequence.") keep_idx = [] - for seq, seq_indices in self.seq_to_idx.items(): + for seq, seq_indices in self._seq_to_idx.items(): # infer the seed from the sequence name, this is reproducible # and makes the selection differ for different sequences seed = _seq_name_to_seed(seq) + self.seed @@ -826,20 +826,20 @@ class ImplicitronDataset(ImplicitronDatasetBase): self._invalidate_indexes(filter_seq_annots=True) def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: - # update seq_to_idx and filter seq_meta according to frame_annots change - # if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx + # update _seq_to_idx and filter seq_meta according to frame_annots change + # if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx self._invalidate_seq_to_idx() if filter_seq_annots: self.seq_annots = { - k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx + k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx } def _invalidate_seq_to_idx(self) -> None: seq_to_idx = defaultdict(list) for idx, entry in enumerate(self.frame_annots): seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) - self.seq_to_idx = seq_to_idx + self._seq_to_idx = seq_to_idx def _resize_image( self, image, mode="bilinear" diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index d85158c8..150f5a0c 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -198,7 +198,7 @@ def _get_all_source_cameras( """ # load all source cameras of the sequence - seq_idx = dataset.seq_to_idx[sequence_name] + seq_idx = list(dataset.sequence_indices_in_order(sequence_name)) dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx) (all_frame_data,) = torch.utils.data.DataLoader( dataset_for_loader, diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index cdfc6cad..a7025038 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -25,7 +25,7 @@ class MockDataset(ImplicitronDatasetBase): Makes a gap of max_frame_gap frame numbers in the middle of each sequence """ self.seq_annots = {f"seq_{i}": None for i in range(num_seq)} - self.seq_to_idx = { + self._seq_to_idx = { f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq) } diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py index 9d50aff8..e866e4af 100644 --- a/tests/implicitron/test_evaluation.py +++ b/tests/implicitron/test_evaluation.py @@ -8,6 +8,7 @@ import contextlib import copy import dataclasses +import itertools import math import os import unittest @@ -285,6 +286,7 @@ class TestEvaluation(unittest.TestCase): def test_full_eval(self, n_sequences=5): """Test evaluation.""" - for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]: + for seq in itertools.islice(self.dataset.sequence_names(), n_sequences): + idx = list(self.dataset.sequence_indices_in_order(seq)) seq_dataset = torch.utils.data.Subset(self.dataset, idx) self._one_sequence_test(seq_dataset)