mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
Using the new dataset idx API everywhere.
Summary: Using the API from D35012121 everywhere. Reviewed By: bottler Differential Revision: D35045870 fbshipit-source-id: dab112b5e04160334859bbe8fa2366344b6e0f70
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c0bb49b5f6
commit
e2622d79c0
@@ -210,7 +210,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
@@ -223,7 +223,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
unordered views, then the dataset should override this method to
|
||||
return the index and timestamp in their videos of the frames whose
|
||||
indices are given in `idxs`. In addition,
|
||||
the values in seq_to_idx should be in ascending order.
|
||||
the values in _seq_to_idx should be in ascending order.
|
||||
If timestamps are absent, they should be replaced with a constant.
|
||||
|
||||
This is used for letting SceneBatchSampler identify consecutive
|
||||
@@ -244,7 +244,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
return self.seq_to_idx.keys()
|
||||
return self._seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
@@ -262,7 +262,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
seq_frame_indices = self.seq_to_idx[seq_name]
|
||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
yield from sorted(
|
||||
@@ -411,7 +411,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
||||
for idx in seq_idx
|
||||
}
|
||||
for seq, seq_idx in self.seq_to_idx.items()
|
||||
for seq, seq_idx in self._seq_to_idx.items()
|
||||
}
|
||||
|
||||
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
|
||||
@@ -804,7 +804,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
if self.n_frames_per_sequence > 0:
|
||||
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||
keep_idx = []
|
||||
for seq, seq_indices in self.seq_to_idx.items():
|
||||
for seq, seq_indices in self._seq_to_idx.items():
|
||||
# infer the seed from the sequence name, this is reproducible
|
||||
# and makes the selection differ for different sequences
|
||||
seed = _seq_name_to_seed(seq) + self.seed
|
||||
@@ -826,20 +826,20 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._invalidate_indexes(filter_seq_annots=True)
|
||||
|
||||
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
|
||||
# update seq_to_idx and filter seq_meta according to frame_annots change
|
||||
# if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx
|
||||
# update _seq_to_idx and filter seq_meta according to frame_annots change
|
||||
# if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx
|
||||
self._invalidate_seq_to_idx()
|
||||
|
||||
if filter_seq_annots:
|
||||
self.seq_annots = {
|
||||
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
|
||||
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
|
||||
}
|
||||
|
||||
def _invalidate_seq_to_idx(self) -> None:
|
||||
seq_to_idx = defaultdict(list)
|
||||
for idx, entry in enumerate(self.frame_annots):
|
||||
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
||||
self.seq_to_idx = seq_to_idx
|
||||
self._seq_to_idx = seq_to_idx
|
||||
|
||||
def _resize_image(
|
||||
self, image, mode="bilinear"
|
||||
|
||||
@@ -198,7 +198,7 @@ def _get_all_source_cameras(
|
||||
"""
|
||||
|
||||
# load all source cameras of the sequence
|
||||
seq_idx = dataset.seq_to_idx[sequence_name]
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
|
||||
(all_frame_data,) = torch.utils.data.DataLoader(
|
||||
dataset_for_loader,
|
||||
|
||||
Reference in New Issue
Block a user