diff --git a/pytorch3d/implicitron/dataset/single_sequence_dataset.py b/pytorch3d/implicitron/dataset/single_sequence_dataset.py index 16972c6c..1090faa1 100644 --- a/pytorch3d/implicitron/dataset/single_sequence_dataset.py +++ b/pytorch3d/implicitron/dataset/single_sequence_dataset.py @@ -9,7 +9,7 @@ # provide data for a single scene. from dataclasses import field -from typing import Iterable, Iterator, List, Optional, Tuple +from typing import Iterable, Iterator, List, Optional, Sequence, Tuple import numpy as np import torch @@ -47,13 +47,12 @@ class SingleSceneDataset(DatasetBase, Configurable): def __len__(self) -> int: return len(self.poses) - # pyre-fixme[14]: `sequence_frames_in_order` overrides method defined in - # `DatasetBase` inconsistently. def sequence_frames_in_order( - self, seq_name: str + self, seq_name: str, subset_filter: Optional[Sequence[str]] = None ) -> Iterator[Tuple[float, int, int]]: for i in range(len(self)): - yield (0.0, i, i) + if subset_filter is None or self.frame_types[i] in subset_filter: + yield 0.0, i, i def __getitem__(self, index) -> FrameData: if index >= len(self):