mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Using the new dataset idx API everywhere.
Summary: Using the API from D35012121 everywhere. Reviewed By: bottler Differential Revision: D35045870 fbshipit-source-id: dab112b5e04160334859bbe8fa2366344b6e0f70
This commit is contained in:
parent
c0bb49b5f6
commit
e2622d79c0
@ -67,7 +67,7 @@ def render_sequence(
|
||||
if seed is None:
|
||||
seed = hash(sequence_name)
|
||||
print(f"Loading all data of sequence '{sequence_name}'.")
|
||||
seq_idx = dataset.seq_to_idx[sequence_name]
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
|
||||
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
|
||||
@ -345,7 +345,7 @@ def export_scenes(
|
||||
dataset = dataset_zoo(**config.dataset_args)[split]
|
||||
|
||||
# iterate over the sequences in the dataset
|
||||
for sequence_name in dataset.seq_to_idx.keys():
|
||||
for sequence_name in dataset.sequence_names():
|
||||
with torch.no_grad():
|
||||
render_sequence(
|
||||
dataset,
|
||||
|
@ -210,7 +210,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
@ -223,7 +223,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
unordered views, then the dataset should override this method to
|
||||
return the index and timestamp in their videos of the frames whose
|
||||
indices are given in `idxs`. In addition,
|
||||
the values in seq_to_idx should be in ascending order.
|
||||
the values in _seq_to_idx should be in ascending order.
|
||||
If timestamps are absent, they should be replaced with a constant.
|
||||
|
||||
This is used for letting SceneBatchSampler identify consecutive
|
||||
@ -244,7 +244,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
return self.seq_to_idx.keys()
|
||||
return self._seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
@ -262,7 +262,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
seq_frame_indices = self.seq_to_idx[seq_name]
|
||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
yield from sorted(
|
||||
@ -411,7 +411,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
||||
for idx in seq_idx
|
||||
}
|
||||
for seq, seq_idx in self.seq_to_idx.items()
|
||||
for seq, seq_idx in self._seq_to_idx.items()
|
||||
}
|
||||
|
||||
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
|
||||
@ -804,7 +804,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
if self.n_frames_per_sequence > 0:
|
||||
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||
keep_idx = []
|
||||
for seq, seq_indices in self.seq_to_idx.items():
|
||||
for seq, seq_indices in self._seq_to_idx.items():
|
||||
# infer the seed from the sequence name, this is reproducible
|
||||
# and makes the selection differ for different sequences
|
||||
seed = _seq_name_to_seed(seq) + self.seed
|
||||
@ -826,20 +826,20 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._invalidate_indexes(filter_seq_annots=True)
|
||||
|
||||
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
|
||||
# update seq_to_idx and filter seq_meta according to frame_annots change
|
||||
# if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx
|
||||
# update _seq_to_idx and filter seq_meta according to frame_annots change
|
||||
# if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx
|
||||
self._invalidate_seq_to_idx()
|
||||
|
||||
if filter_seq_annots:
|
||||
self.seq_annots = {
|
||||
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
|
||||
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
|
||||
}
|
||||
|
||||
def _invalidate_seq_to_idx(self) -> None:
|
||||
seq_to_idx = defaultdict(list)
|
||||
for idx, entry in enumerate(self.frame_annots):
|
||||
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
||||
self.seq_to_idx = seq_to_idx
|
||||
self._seq_to_idx = seq_to_idx
|
||||
|
||||
def _resize_image(
|
||||
self, image, mode="bilinear"
|
||||
|
@ -198,7 +198,7 @@ def _get_all_source_cameras(
|
||||
"""
|
||||
|
||||
# load all source cameras of the sequence
|
||||
seq_idx = dataset.seq_to_idx[sequence_name]
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
|
||||
(all_frame_data,) = torch.utils.data.DataLoader(
|
||||
dataset_for_loader,
|
||||
|
@ -25,7 +25,7 @@ class MockDataset(ImplicitronDatasetBase):
|
||||
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
||||
"""
|
||||
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
|
||||
self.seq_to_idx = {
|
||||
self._seq_to_idx = {
|
||||
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
@ -285,6 +286,7 @@ class TestEvaluation(unittest.TestCase):
|
||||
|
||||
def test_full_eval(self, n_sequences=5):
|
||||
"""Test evaluation."""
|
||||
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
|
||||
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
|
||||
idx = list(self.dataset.sequence_indices_in_order(seq))
|
||||
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
|
||||
self._one_sequence_test(seq_dataset)
|
||||
|
Loading…
x
Reference in New Issue
Block a user