Fix: subset filter in DatasetBase implementation

Summary: In D42739669, I forgot to update the API of existing implementations of DatasetBase to take `subset_filter`. Looks like only one was missing.

Reviewed By: bottler

Differential Revision: D46724488

fbshipit-source-id: 13ab7a457f853278cf06955aad0cc2bab5fbcce6
This commit is contained in:
Roman Shapovalov 2023-06-14 08:48:14 -07:00 committed by Facebook GitHub Bot
parent 5592d25f68
commit 3d886c32d5

View File

@ -9,7 +9,7 @@
# provide data for a single scene. # provide data for a single scene.
from dataclasses import field from dataclasses import field
from typing import Iterable, Iterator, List, Optional, Tuple from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch
@ -47,13 +47,12 @@ class SingleSceneDataset(DatasetBase, Configurable):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.poses) return len(self.poses)
# pyre-fixme[14]: `sequence_frames_in_order` overrides method defined in
# `DatasetBase` inconsistently.
def sequence_frames_in_order( def sequence_frames_in_order(
self, seq_name: str self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[Tuple[float, int, int]]: ) -> Iterator[Tuple[float, int, int]]:
for i in range(len(self)): for i in range(len(self)):
yield (0.0, i, i) if subset_filter is None or self.frame_types[i] in subset_filter:
yield 0.0, i, i
def __getitem__(self, index) -> FrameData: def __getitem__(self, index) -> FrameData:
if index >= len(self): if index >= len(self):