mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-19 20:35:59 +08:00
Summary: No one is using these. (The minify part has been broken for a couple of years, too) Reviewed By: patricklabatut Differential Revision: D96977684 fbshipit-source-id: 4708dfd37b14d1930f1370677eb126a61a0d9d3c
174 lines
6.3 KiB
Python
174 lines
6.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
# pyre-unsafe
|
|
|
|
|
|
# This file defines a base class for dataset map providers which
|
|
# provide data for a single scene.
|
|
|
|
from dataclasses import field
|
|
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
from pytorch3d.implicitron.tools.config import (
|
|
Configurable,
|
|
expand_args_fields,
|
|
run_auto_creation,
|
|
)
|
|
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
|
|
|
from .dataset_base import DatasetBase
|
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
|
from .frame_data import FrameData
|
|
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
|
|
|
# Name given to the one and only sequence of a single-scene dataset; it is
# what `sequence_names()` returns and what every FrameData carries as its
# `sequence_name`.
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
|
|
|
|
|
@expand_args_fields
class SingleSceneDataset(DatasetBase, Configurable):
    """
    A dataset from images from a single scene.

    All frames belong to a single sequence whose name is
    `_SINGLE_SEQUENCE_NAME`. The per-frame members below are indexed
    in parallel: frame `i` is described by `images[i]`, `poses[i]`,
    `frame_types[i]` and (if present) `fg_probabilities[i]`.

    Members:
        images: Per-frame rgb images.
        fg_probabilities: Optional per-frame foreground probabilities,
            or None if the scene has none.
        poses: Per-frame cameras. The length of this member defines the
            length of the dataset.
        object_name: Label of the scene; reported as `sequence_category`
            of every returned frame.
        frame_types: Per-frame type strings, used by the subset filter of
            `sequence_frames_in_order`.
        eval_batches: Optional list of batches (lists of frame indices)
            returned unchanged by `get_eval_batches`.
    """

    images: List[torch.Tensor] = field()
    fg_probabilities: Optional[List[torch.Tensor]] = field()
    poses: List[PerspectiveCameras] = field()
    object_name: str = field()
    frame_types: List[str] = field()
    eval_batches: Optional[List[List[int]]] = field()

    def sequence_names(self) -> Iterable[str]:
        """Return the names of the sequences: always the single fixed name."""
        return [_SINGLE_SEQUENCE_NAME]

    def __len__(self) -> int:
        """Return the number of frames, i.e. the number of camera poses."""
        return len(self.poses)

    def sequence_frames_in_order(
        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
    ) -> Iterator[Tuple[float, int, int]]:
        """
        Iterate over the frames of the scene in index order.

        Args:
            seq_name: Not used by this implementation; there is only one
                sequence.
            subset_filter: If given, only frames whose frame type is
                contained in it are yielded.

        Yields:
            Tuples `(0.0, i, i)` for every selected frame index `i`; the
            first element is always the constant 0.0.
        """
        for i in range(len(self)):
            if subset_filter is None or self.frame_types[i] in subset_filter:
                yield 0.0, i, i

    def __getitem__(self, index) -> FrameData:
        """
        Build the FrameData for one frame.

        Args:
            index: Position of the frame. Negative values count from the
                end, as with ordinary Python sequences.

        Returns:
            FrameData holding the image, camera, frame type and optional
            foreground probability of the frame.

        Raises:
            IndexError: if `index` is outside `[-len(self), len(self))`.
        """
        n_frames = len(self)
        if index >= n_frames or index < -n_frames:
            raise IndexError(f"index {index} out of range {n_frames}")
        if index < 0:
            # Normalise so that `frame_number` below is the real position of
            # the frame, consistent with `sequence_frames_in_order`, rather
            # than a negative alias of it.
            index += n_frames

        image = self.images[index]
        pose = self.poses[index]
        frame_type = self.frame_types[index]
        fg_probability = (
            None if self.fg_probabilities is None else self.fg_probabilities[index]
        )

        frame_data = FrameData(
            frame_number=index,
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=pose,
            # The size is the image shape without its leading dimension.
            # pyre-ignore
            image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long),
            image_rgb=image,
            fg_probability=fg_probability,
            frame_type=frame_type,
        )
        return frame_data

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        """Return the evaluation batches given at construction (may be None)."""
        return self.eval_batches
|
|
|
|
|
|
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
    """
    Base for provider of data for one scene.

    Subclasses implement `_load_data`, which is called once from
    `__post_init__`; the train/val/test datasets are then built from the
    loaded data by `get_dataset_map`.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    # pyre-fixme[13]: Attribute `base_dir` is never initialized.
    base_dir: str
    # pyre-fixme[13]: Attribute `object_name` is never initialized.
    object_name: str
    # pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"
    n_known_frames_for_test: Optional[int] = None

    def __post_init__(self) -> None:
        """Create the configured sub-objects, then load the scene data."""
        run_auto_creation(self)
        self._load_data()

    def _load_data(self) -> None:
        """
        Load the scene. Must be overridden; this base version always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # This must be defined by each subclass,
        # and should set the following on self.
        # - poses: a list of length-1 camera objects
        # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
        # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
        # - i_split: List[List[int]] of indices for train/val/test subsets
        #   (read as i_split[0], i_split[1], i_split[2] respectively).
        raise NotImplementedError()

    def _get_dataset(
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
    ) -> SingleSceneDataset:
        """
        Build the dataset for one of the subsets of the scene.

        Args:
            split_idx: Index into `self.i_split` of the subset to use
                (0 is the training subset).
            frame_type: Frame type assigned to every frame of the subset.
            set_eval_batches: Whether the returned dataset gets eval batches.
                If False, its `eval_batches` is None.

        Returns:
            A SingleSceneDataset with the frames of the subset. If
            `n_known_frames_for_test` is set and `split_idx` is not 0, the
            training frames are appended after them with type
            DATASET_TYPE_KNOWN.
        """
        # pyre-ignore[16]
        split = self.i_split[split_idx]
        frame_types = [frame_type] * len(split)
        # One batch per frame of the subset, referring to it by its position
        # in the dataset being built.
        eval_batches = [[i] for i in range(len(split))]
        if split_idx != 0 and self.n_known_frames_for_test is not None:
            train_split = self.i_split[0]
            if set_eval_batches:
                generator = np.random.default_rng(seed=0)
                for batch in eval_batches:
                    # using permutation so that changes to n_known_frames_for_test
                    # result in consistent batches.
                    to_add = generator.permutation(len(train_split))[
                        : self.n_known_frames_for_test
                    ]
                    # Training frames are appended after the subset's own
                    # frames, hence the offset by the current len(split).
                    batch.extend((to_add + len(split)).tolist())
            split = np.concatenate([split, train_split])
            frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))

        # Select the foreground probabilities only now that `split` is final,
        # so that they stay aligned one-to-one with the images and poses
        # selected below, including any appended training frames.
        fg_probabilities = (
            None
            # pyre-ignore[16]
            if self.fg_probabilities is None
            else self.fg_probabilities[split]
        )

        # pyre-ignore[28]
        return SingleSceneDataset(
            object_name=self.object_name,
            # pyre-ignore[16]
            images=self.images[split],
            fg_probabilities=fg_probabilities,
            # pyre-ignore[16]
            poses=[self.poses[i] for i in split],
            frame_types=frame_types,
            eval_batches=eval_batches if set_eval_batches else None,
        )

    def get_dataset_map(self) -> DatasetMap:
        """
        Return the train, val and test datasets of the scene.

        They are built from subsets 0, 1 and 2 of `self.i_split`
        respectively; only the test dataset is given eval batches.
        """
        return DatasetMap(
            train=self._get_dataset(0, DATASET_TYPE_KNOWN),
            val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
        )

    def get_all_train_cameras(self) -> Optional[CamerasBase]:
        """Return the cameras of all training frames joined into one batch."""
        # pyre-ignore[16]
        cameras = [self.poses[i] for i in self.i_split[0]]
        return join_cameras_as_batch(cameras)
|