Files
pytorch3d/pytorch3d/implicitron/dataset/single_sequence_dataset.py
Jeremy Reizenstein b6a77ad7aa [pytorch3d[ Remove LlffDatasetMapProvider and BlenderDatasetMapProvider
Summary:
No one is using these.

(The minify part has been broken for a couple of years, too)

Reviewed By: patricklabatut

Differential Revision: D96977684

fbshipit-source-id: 4708dfd37b14d1930f1370677eb126a61a0d9d3c
2026-03-18 10:09:59 -07:00

174 lines
6.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This file defines SingleSceneDataset and a base class for dataset map
# providers which provide data for a single scene.
from dataclasses import field
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import numpy as np
import torch
from pytorch3d.implicitron.tools.config import (
Configurable,
expand_args_fields,
run_auto_creation,
)
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .frame_data import FrameData
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
# Name of the only sequence in a SingleSceneDataset: it is the sole entry of
# sequence_names() and the sequence_name attached to every returned frame.
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
@expand_args_fields
class SingleSceneDataset(DatasetBase, Configurable):
    """
    A dataset from images from a single scene.

    All frames belong to one sequence named `_SINGLE_SEQUENCE_NAME`. Frame
    `i` is element `i` of `images`, `poses`, `frame_types` and, when present,
    `fg_probabilities`.

    Members:
        images: per-frame RGB images, indexable by frame. Each element is
            used as `image_rgb` and its `shape[1:]` as the image size, so
            elements are presumably (3, H, W) tensors.
            SingleSceneDatasetMapProviderBase passes a tensor slice here.
        fg_probabilities: optional per-frame foreground probabilities,
            indexable like `images`.
        poses: one camera per frame; its length defines the dataset length.
        object_name: label reported as every frame's `sequence_category`.
        frame_types: one frame-type string per frame, used by
            `sequence_frames_in_order` for subset filtering.
        eval_batches: optional evaluation batches (lists of frame indices),
            returned unchanged by `get_eval_batches`.
    """

    images: List[torch.Tensor] = field()
    fg_probabilities: Optional[List[torch.Tensor]] = field()
    poses: List[PerspectiveCameras] = field()
    object_name: str = field()
    frame_types: List[str] = field()
    eval_batches: Optional[List[List[int]]] = field()

    def sequence_names(self) -> Iterable[str]:
        """Return the sequence names: always just the single sequence."""
        return [_SINGLE_SEQUENCE_NAME]

    def __len__(self) -> int:
        """Return the number of frames, i.e. the number of poses."""
        return len(self.poses)

    def sequence_frames_in_order(
        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
    ) -> Iterator[Tuple[float, int, int]]:
        """
        Yield `(0.0, i, i)` for each frame index `i`, in index order.

        The tuple is presumably `(timestamp, frame_number, dataset_idx)` as
        defined by `DatasetBase` (confirm there). Every frame gets timestamp
        0.0, and its frame number equals its dataset index.

        Args:
            seq_name: ignored, since there is only one sequence.
            subset_filter: if given, only frames whose frame type is in it
                are yielded.
        """
        for i in range(len(self)):
            if subset_filter is None or self.frame_types[i] in subset_filter:
                yield 0.0, i, i

    def __getitem__(self, index) -> FrameData:
        """
        Return the frame at position `index` as a FrameData.

        Negative indices count from the end, as for a list. They are
        normalized first, so the returned `frame_number` is always the
        non-negative position reported by `sequence_frames_in_order`.

        Raises:
            IndexError: if `index` is outside `[-len(self), len(self))`.
        """
        n_frames = len(self)
        # Normalize negative indices so frame_number matches the position.
        position = index + n_frames if index < 0 else index
        if not 0 <= position < n_frames:
            raise IndexError(f"index {index} out of range {n_frames}")
        image = self.images[position]
        fg_probability = (
            None if self.fg_probabilities is None else self.fg_probabilities[position]
        )
        return FrameData(
            frame_number=position,
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=self.poses[position],
            # pyre-ignore
            image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long),
            image_rgb=image,
            fg_probability=fg_probability,
            frame_type=self.frame_types[position],
        )

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        """Return the evaluation batches given at construction, or None."""
        return self.eval_batches
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
    """
    Base for provider of data for one scene.

    `__post_init__` runs auto-creation of the configurable members and then
    calls `_load_data`, which each subclass must implement. The train, val
    and test datasets are then built from the attributes that `_load_data`
    sets.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    # pyre-fixme[13]: Attribute `base_dir` is never initialized.
    base_dir: str
    # pyre-fixme[13]: Attribute `object_name` is never initialized.
    object_name: str
    # pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"
    n_known_frames_for_test: Optional[int] = None

    def __post_init__(self) -> None:
        # Create configurable members (e.g. path_manager_factory) first, since
        # a subclass's _load_data may rely on them.
        run_auto_creation(self)
        self._load_data()

    def _load_data(self) -> None:
        # This must be defined by each subclass,
        # and should set the following on self.
        # - poses: a list of length-1 camera objects, one per frame
        # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
        # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
        # - i_split: three index sequences (train, val, test) into the
        #   frames above; i_split[0], i_split[1] and i_split[2] are read.
        raise NotImplementedError()

    def _get_dataset(
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
    ) -> SingleSceneDataset:
        """
        Build the dataset for one split.

        Args:
            split_idx: index into `self.i_split`; 0 is the training split.
            frame_type: frame type assigned to every frame of the split.
            set_eval_batches: whether to attach evaluation batches to the
                returned dataset.

        Returns:
            A SingleSceneDataset over the split's frames. For a non-training
            split with `n_known_frames_for_test` set, all training frames are
            appended after the split's own frames with type DATASET_TYPE_KNOWN.
            In that case each evaluation batch (one per split frame) is also
            extended with up to `n_known_frames_for_test` of those training
            frames, drawn by a generator seeded with 0.
        """
        # pyre-ignore[16]
        split = self.i_split[split_idx]
        frame_types = [frame_type] * len(split)
        # One singleton batch per frame of the split itself.
        eval_batches = [[i] for i in range(len(split))]
        if split_idx != 0 and self.n_known_frames_for_test is not None:
            train_split = self.i_split[0]
            if set_eval_batches:
                generator = np.random.default_rng(seed=0)
                for batch in eval_batches:
                    # using permutation so that changes to n_known_frames_for_test
                    # result in consistent batches.
                    to_add = generator.permutation(len(train_split))[
                        : self.n_known_frames_for_test
                    ]
                    # Training frames are stored after the split's own frames.
                    batch.extend((to_add + len(split)).tolist())
            split = np.concatenate([split, train_split])
            frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))
        # Index with the final split (after any training frames were
        # appended) so fg_probabilities stays aligned with images and poses.
        fg_probabilities = (
            None
            # pyre-ignore[16]
            if self.fg_probabilities is None
            else self.fg_probabilities[split]
        )
        # pyre-ignore[28]
        return SingleSceneDataset(
            object_name=self.object_name,
            # pyre-ignore[16]
            images=self.images[split],
            fg_probabilities=fg_probabilities,
            # pyre-ignore[16]
            poses=[self.poses[i] for i in split],
            frame_types=frame_types,
            eval_batches=eval_batches if set_eval_batches else None,
        )

    def get_dataset_map(self) -> DatasetMap:
        """
        Return the train, val and test datasets built from i_split[0],
        i_split[1] and i_split[2]. Only the test dataset has eval batches.
        """
        return DatasetMap(
            train=self._get_dataset(0, DATASET_TYPE_KNOWN),
            val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
        )

    def get_all_train_cameras(self) -> Optional[CamerasBase]:
        """Return the cameras of all training frames joined into one batch."""
        # pyre-ignore[16]
        cameras = [self.poses[i] for i in self.i_split[0]]
        return join_cameras_as_batch(cameras)