mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
New file for ImplicitronDatasetBase
Summary: Separate ImplicitronDatasetBase and FrameData (to be used by all data sources) from ImplicitronDataset (which is specific). Reviewed By: shapovalov Differential Revision: D36413111 fbshipit-source-id: 3725744cde2e08baa11aff4048237ba10c7efbc6
This commit is contained in:
parent
73dc109dba
commit
69c6d06ed8
@ -66,11 +66,9 @@ from packaging import version
|
|||||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||||
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
|
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
|
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||||
FrameData,
|
|
||||||
ImplicitronDataset,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||||
|
@ -24,12 +24,9 @@ import torch.nn.functional as Fu
|
|||||||
from experiment import init_model
|
from experiment import init_model
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||||
FrameData,
|
|
||||||
ImplicitronDataset,
|
|
||||||
ImplicitronDatasetBase,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||||
from pytorch3d.implicitron.tools.configurable import get_default_args
|
from pytorch3d.implicitron.tools.configurable import get_default_args
|
||||||
|
@ -10,8 +10,8 @@ from typing import Optional, Sequence
|
|||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||||
|
|
||||||
|
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||||
from .dataset_zoo import Datasets
|
from .dataset_zoo import Datasets
|
||||||
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
|
|
||||||
from .scene_batch_sampler import SceneBatchSampler
|
from .scene_batch_sampler import SceneBatchSampler
|
||||||
|
|
||||||
|
|
||||||
|
268
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
268
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field, fields
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||||
|
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FrameData(Mapping[str, Any]):
|
||||||
|
"""
|
||||||
|
A type of the elements returned by indexing the dataset object.
|
||||||
|
It can represent both individual frames and batches of thereof;
|
||||||
|
in this documentation, the sizes of tensors refer to single frames;
|
||||||
|
add the first batch dimension for the collation result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame_number: The number of the frame within its sequence.
|
||||||
|
0-based continuous integers.
|
||||||
|
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||||
|
sequence_name: The unique name of the frame's sequence.
|
||||||
|
sequence_category: The object category of the sequence.
|
||||||
|
image_size_hw: The size of the image in pixels; (height, width) tuple.
|
||||||
|
image_path: The qualified path to the loaded image (with dataset_root).
|
||||||
|
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||||
|
of the frame; elements are floats in [0, 1].
|
||||||
|
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||||
|
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||||
|
are a result of zero-padding of the image after cropping around
|
||||||
|
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||||
|
depth_path: The qualified path to the frame's depth map.
|
||||||
|
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||||
|
of the frame; values correspond to distances from the camera;
|
||||||
|
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||||
|
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||||
|
depth map that are valid for evaluation, they have been checked for
|
||||||
|
consistency across views; elements are floats in {0.0, 1.0}.
|
||||||
|
mask_path: A qualified path to the foreground probability mask.
|
||||||
|
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||||
|
pixels belonging to the captured object; elements are floats
|
||||||
|
in [0, 1].
|
||||||
|
bbox_xywh: The bounding box capturing the object in the
|
||||||
|
format (x0, y0, width, height).
|
||||||
|
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
||||||
|
corrected for cropping if it happened.
|
||||||
|
camera_quality_score: The score proportional to the confidence of the
|
||||||
|
frame's camera estimation (the higher the more accurate).
|
||||||
|
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||||
|
frame's sequence point cloud (the higher the more accurate).
|
||||||
|
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||||
|
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||||
|
point cloud corresponding to the frame's sequence. When the object
|
||||||
|
represents a batch of frames, point clouds may be deduplicated;
|
||||||
|
see `sequence_point_cloud_idx`.
|
||||||
|
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||||
|
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||||
|
corresponding point cloud to `image_rgb[i]`, use
|
||||||
|
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||||
|
frame_type: The type of the loaded frame specified in
|
||||||
|
`subset_lists_file`, if provided.
|
||||||
|
meta: A dict for storing additional frame information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
frame_number: Optional[torch.LongTensor]
|
||||||
|
frame_timestamp: Optional[torch.Tensor]
|
||||||
|
sequence_name: Union[str, List[str]]
|
||||||
|
sequence_category: Union[str, List[str]]
|
||||||
|
image_size_hw: Optional[torch.Tensor] = None
|
||||||
|
image_path: Union[str, List[str], None] = None
|
||||||
|
image_rgb: Optional[torch.Tensor] = None
|
||||||
|
# masks out padding added due to cropping the square bit
|
||||||
|
mask_crop: Optional[torch.Tensor] = None
|
||||||
|
depth_path: Union[str, List[str], None] = None
|
||||||
|
depth_map: Optional[torch.Tensor] = None
|
||||||
|
depth_mask: Optional[torch.Tensor] = None
|
||||||
|
mask_path: Union[str, List[str], None] = None
|
||||||
|
fg_probability: Optional[torch.Tensor] = None
|
||||||
|
bbox_xywh: Optional[torch.Tensor] = None
|
||||||
|
camera: Optional[PerspectiveCameras] = None
|
||||||
|
camera_quality_score: Optional[torch.Tensor] = None
|
||||||
|
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||||
|
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||||
|
sequence_point_cloud: Optional[Pointclouds] = None
|
||||||
|
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||||
|
frame_type: Union[str, List[str], None] = None # seen | unseen
|
||||||
|
meta: dict = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
new_params = {}
|
||||||
|
for f in fields(self):
|
||||||
|
value = getattr(self, f.name)
|
||||||
|
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||||
|
new_params[f.name] = value.to(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
new_params[f.name] = value
|
||||||
|
return type(self)(**new_params)
|
||||||
|
|
||||||
|
def cpu(self):
|
||||||
|
return self.to(device=torch.device("cpu"))
|
||||||
|
|
||||||
|
def cuda(self):
|
||||||
|
return self.to(device=torch.device("cuda"))
|
||||||
|
|
||||||
|
# the following functions make sure **frame_data can be passed to functions
|
||||||
|
def __iter__(self):
|
||||||
|
for f in fields(self):
|
||||||
|
yield f.name
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(fields(self))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def collate(cls, batch):
|
||||||
|
"""
|
||||||
|
Given a list objects `batch` of class `cls`, collates them into a batched
|
||||||
|
representation suitable for processing with deep networks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
elem = batch[0]
|
||||||
|
|
||||||
|
if isinstance(elem, cls):
|
||||||
|
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||||
|
id_to_idx = defaultdict(list)
|
||||||
|
for i, pc_id in enumerate(pointcloud_ids):
|
||||||
|
id_to_idx[pc_id].append(i)
|
||||||
|
|
||||||
|
sequence_point_cloud = []
|
||||||
|
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||||
|
for i, ind in enumerate(id_to_idx.values()):
|
||||||
|
sequence_point_cloud_idx[ind] = i
|
||||||
|
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||||
|
assert (sequence_point_cloud_idx >= 0).all()
|
||||||
|
|
||||||
|
override_fields = {
|
||||||
|
"sequence_point_cloud": sequence_point_cloud,
|
||||||
|
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||||
|
}
|
||||||
|
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||||
|
|
||||||
|
collated = {}
|
||||||
|
for f in fields(elem):
|
||||||
|
list_values = override_fields.get(
|
||||||
|
f.name, [getattr(d, f.name) for d in batch]
|
||||||
|
)
|
||||||
|
collated[f.name] = (
|
||||||
|
cls.collate(list_values)
|
||||||
|
if all(list_value is not None for list_value in list_values)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return cls(**collated)
|
||||||
|
|
||||||
|
elif isinstance(elem, Pointclouds):
|
||||||
|
return join_pointclouds_as_batch(batch)
|
||||||
|
|
||||||
|
elif isinstance(elem, CamerasBase):
|
||||||
|
# TODO: don't store K; enforce working in NDC space
|
||||||
|
return join_cameras_as_batch(batch)
|
||||||
|
else:
|
||||||
|
return torch.utils.data._utils.collate.default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||||
|
"""
|
||||||
|
Base class to describe a dataset to be used with Implicitron.
|
||||||
|
|
||||||
|
The dataset is made up of frames, and the frames are grouped into sequences.
|
||||||
|
Each sequence has a name (a string).
|
||||||
|
(A sequence could be a video, or a set of images of one scene.)
|
||||||
|
|
||||||
|
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||||
|
which will describe one frame in one sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Maps sequence name to the sequence's global frame indices.
|
||||||
|
# It is used for the default implementations of some functions in this class.
|
||||||
|
# Implementations which override them are free to ignore this member.
|
||||||
|
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_frame_numbers_and_timestamps(
|
||||||
|
self, idxs: Sequence[int]
|
||||||
|
) -> List[Tuple[int, float]]:
|
||||||
|
"""
|
||||||
|
If the sequences in the dataset are videos rather than
|
||||||
|
unordered views, then the dataset should override this method to
|
||||||
|
return the index and timestamp in their videos of the frames whose
|
||||||
|
indices are given in `idxs`. In addition,
|
||||||
|
the values in _seq_to_idx should be in ascending order.
|
||||||
|
If timestamps are absent, they should be replaced with a constant.
|
||||||
|
|
||||||
|
This is used for letting SceneBatchSampler identify consecutive
|
||||||
|
frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: frame index in self
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of
|
||||||
|
- frame index in video
|
||||||
|
- timestamp of frame in video
|
||||||
|
"""
|
||||||
|
raise ValueError("This dataset does not contain videos.")
|
||||||
|
|
||||||
|
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def sequence_names(self) -> Iterable[str]:
|
||||||
|
"""Returns an iterator over sequence names in the dataset."""
|
||||||
|
return self._seq_to_idx.keys()
|
||||||
|
|
||||||
|
def sequence_frames_in_order(
|
||||||
|
self, seq_name: str
|
||||||
|
) -> Iterator[Tuple[float, int, int]]:
|
||||||
|
"""Returns an iterator over the frame indices in a given sequence.
|
||||||
|
We attempt to first sort by timestamp (if they are available),
|
||||||
|
then by frame number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_name: the name of the sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||||
|
where `frame_no` is the index within the sequence, and
|
||||||
|
`dataset_idx` is the index within the dataset.
|
||||||
|
`None` timestamps are replaced with 0s.
|
||||||
|
"""
|
||||||
|
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||||
|
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||||
|
|
||||||
|
yield from sorted(
|
||||||
|
[
|
||||||
|
(timestamp, frame_no, idx)
|
||||||
|
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||||
|
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||||
|
only dataset indices.
|
||||||
|
"""
|
||||||
|
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||||
|
yield idx
|
@ -13,7 +13,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence
|
|||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||||
|
|
||||||
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
|
from .dataset_base import ImplicitronDatasetBase
|
||||||
|
from .implicitron_dataset import ImplicitronDataset
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DATASET_TYPE_KNOWN,
|
DATASET_TYPE_KNOWN,
|
||||||
DATASET_TYPE_TEST,
|
DATASET_TYPE_TEST,
|
||||||
|
@ -13,17 +13,13 @@ import os
|
|||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
List,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -37,258 +33,16 @@ import torch
|
|||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.io import IO
|
from pytorch3d.io import IO
|
||||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
|
||||||
|
|
||||||
from . import types
|
from . import types
|
||||||
|
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FrameData(Mapping[str, Any]):
|
|
||||||
"""
|
|
||||||
A type of the elements returned by indexing the dataset object.
|
|
||||||
It can represent both individual frames and batches of thereof;
|
|
||||||
in this documentation, the sizes of tensors refer to single frames;
|
|
||||||
add the first batch dimension for the collation result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frame_number: The number of the frame within its sequence.
|
|
||||||
0-based continuous integers.
|
|
||||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
|
||||||
sequence_name: The unique name of the frame's sequence.
|
|
||||||
sequence_category: The object category of the sequence.
|
|
||||||
image_size_hw: The size of the image in pixels; (height, width) tuple.
|
|
||||||
image_path: The qualified path to the loaded image (with dataset_root).
|
|
||||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
|
||||||
of the frame; elements are floats in [0, 1].
|
|
||||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
|
||||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
|
||||||
are a result of zero-padding of the image after cropping around
|
|
||||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
|
||||||
depth_path: The qualified path to the frame's depth map.
|
|
||||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
|
||||||
of the frame; values correspond to distances from the camera;
|
|
||||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
|
||||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
|
||||||
depth map that are valid for evaluation, they have been checked for
|
|
||||||
consistency across views; elements are floats in {0.0, 1.0}.
|
|
||||||
mask_path: A qualified path to the foreground probability mask.
|
|
||||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
|
||||||
pixels belonging to the captured object; elements are floats
|
|
||||||
in [0, 1].
|
|
||||||
bbox_xywh: The bounding box capturing the object in the
|
|
||||||
format (x0, y0, width, height).
|
|
||||||
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
|
||||||
corrected for cropping if it happened.
|
|
||||||
camera_quality_score: The score proportional to the confidence of the
|
|
||||||
frame's camera estimation (the higher the more accurate).
|
|
||||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
|
||||||
frame's sequence point cloud (the higher the more accurate).
|
|
||||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
|
||||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
|
||||||
point cloud corresponding to the frame's sequence. When the object
|
|
||||||
represents a batch of frames, point clouds may be deduplicated;
|
|
||||||
see `sequence_point_cloud_idx`.
|
|
||||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
|
||||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
|
||||||
corresponding point cloud to `image_rgb[i]`, use
|
|
||||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
|
||||||
frame_type: The type of the loaded frame specified in
|
|
||||||
`subset_lists_file`, if provided.
|
|
||||||
meta: A dict for storing additional frame information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
frame_number: Optional[torch.LongTensor]
|
|
||||||
frame_timestamp: Optional[torch.Tensor]
|
|
||||||
sequence_name: Union[str, List[str]]
|
|
||||||
sequence_category: Union[str, List[str]]
|
|
||||||
image_size_hw: Optional[torch.Tensor] = None
|
|
||||||
image_path: Union[str, List[str], None] = None
|
|
||||||
image_rgb: Optional[torch.Tensor] = None
|
|
||||||
# masks out padding added due to cropping the square bit
|
|
||||||
mask_crop: Optional[torch.Tensor] = None
|
|
||||||
depth_path: Union[str, List[str], None] = None
|
|
||||||
depth_map: Optional[torch.Tensor] = None
|
|
||||||
depth_mask: Optional[torch.Tensor] = None
|
|
||||||
mask_path: Union[str, List[str], None] = None
|
|
||||||
fg_probability: Optional[torch.Tensor] = None
|
|
||||||
bbox_xywh: Optional[torch.Tensor] = None
|
|
||||||
camera: Optional[PerspectiveCameras] = None
|
|
||||||
camera_quality_score: Optional[torch.Tensor] = None
|
|
||||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
|
||||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
|
||||||
sequence_point_cloud: Optional[Pointclouds] = None
|
|
||||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
|
||||||
frame_type: Union[str, List[str], None] = None # seen | unseen
|
|
||||||
meta: dict = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
new_params = {}
|
|
||||||
for f in fields(self):
|
|
||||||
value = getattr(self, f.name)
|
|
||||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
|
||||||
new_params[f.name] = value.to(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
new_params[f.name] = value
|
|
||||||
return type(self)(**new_params)
|
|
||||||
|
|
||||||
def cpu(self):
|
|
||||||
return self.to(device=torch.device("cpu"))
|
|
||||||
|
|
||||||
def cuda(self):
|
|
||||||
return self.to(device=torch.device("cuda"))
|
|
||||||
|
|
||||||
# the following functions make sure **frame_data can be passed to functions
|
|
||||||
def __iter__(self):
|
|
||||||
for f in fields(self):
|
|
||||||
yield f.name
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(fields(self))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def collate(cls, batch):
|
|
||||||
"""
|
|
||||||
Given a list objects `batch` of class `cls`, collates them into a batched
|
|
||||||
representation suitable for processing with deep networks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
elem = batch[0]
|
|
||||||
|
|
||||||
if isinstance(elem, cls):
|
|
||||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
|
||||||
id_to_idx = defaultdict(list)
|
|
||||||
for i, pc_id in enumerate(pointcloud_ids):
|
|
||||||
id_to_idx[pc_id].append(i)
|
|
||||||
|
|
||||||
sequence_point_cloud = []
|
|
||||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
|
||||||
for i, ind in enumerate(id_to_idx.values()):
|
|
||||||
sequence_point_cloud_idx[ind] = i
|
|
||||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
|
||||||
assert (sequence_point_cloud_idx >= 0).all()
|
|
||||||
|
|
||||||
override_fields = {
|
|
||||||
"sequence_point_cloud": sequence_point_cloud,
|
|
||||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
|
||||||
}
|
|
||||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
|
||||||
|
|
||||||
collated = {}
|
|
||||||
for f in fields(elem):
|
|
||||||
list_values = override_fields.get(
|
|
||||||
f.name, [getattr(d, f.name) for d in batch]
|
|
||||||
)
|
|
||||||
collated[f.name] = (
|
|
||||||
cls.collate(list_values)
|
|
||||||
if all(list_value is not None for list_value in list_values)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return cls(**collated)
|
|
||||||
|
|
||||||
elif isinstance(elem, Pointclouds):
|
|
||||||
return join_pointclouds_as_batch(batch)
|
|
||||||
|
|
||||||
elif isinstance(elem, CamerasBase):
|
|
||||||
# TODO: don't store K; enforce working in NDC space
|
|
||||||
return join_cameras_as_batch(batch)
|
|
||||||
else:
|
|
||||||
return torch.utils.data._utils.collate.default_collate(batch)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
|
||||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
|
||||||
"""
|
|
||||||
Base class to describe a dataset to be used with Implicitron.
|
|
||||||
|
|
||||||
The dataset is made up of frames, and the frames are grouped into sequences.
|
|
||||||
Each sequence has a name (a string).
|
|
||||||
(A sequence could be a video, or a set of images of one scene.)
|
|
||||||
|
|
||||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
|
||||||
which will describe one frame in one sequence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Maps sequence name to the sequence's global frame indices.
|
|
||||||
# It is used for the default implementations of some functions in this class.
|
|
||||||
# Implementations which override them are free to ignore this member.
|
|
||||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_frame_numbers_and_timestamps(
|
|
||||||
self, idxs: Sequence[int]
|
|
||||||
) -> List[Tuple[int, float]]:
|
|
||||||
"""
|
|
||||||
If the sequences in the dataset are videos rather than
|
|
||||||
unordered views, then the dataset should override this method to
|
|
||||||
return the index and timestamp in their videos of the frames whose
|
|
||||||
indices are given in `idxs`. In addition,
|
|
||||||
the values in _seq_to_idx should be in ascending order.
|
|
||||||
If timestamps are absent, they should be replaced with a constant.
|
|
||||||
|
|
||||||
This is used for letting SceneBatchSampler identify consecutive
|
|
||||||
frames.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
idx: frame index in self
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple of
|
|
||||||
- frame index in video
|
|
||||||
- timestamp of frame in video
|
|
||||||
"""
|
|
||||||
raise ValueError("This dataset does not contain videos.")
|
|
||||||
|
|
||||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def sequence_names(self) -> Iterable[str]:
|
|
||||||
"""Returns an iterator over sequence names in the dataset."""
|
|
||||||
return self._seq_to_idx.keys()
|
|
||||||
|
|
||||||
def sequence_frames_in_order(
|
|
||||||
self, seq_name: str
|
|
||||||
) -> Iterator[Tuple[float, int, int]]:
|
|
||||||
"""Returns an iterator over the frame indices in a given sequence.
|
|
||||||
We attempt to first sort by timestamp (if they are available),
|
|
||||||
then by frame number.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_name: the name of the sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
|
||||||
where `frame_no` is the index within the sequence, and
|
|
||||||
`dataset_idx` is the index within the dataset.
|
|
||||||
`None` timestamps are replaced with 0s.
|
|
||||||
"""
|
|
||||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
|
||||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
|
||||||
|
|
||||||
yield from sorted(
|
|
||||||
[
|
|
||||||
(timestamp, frame_no, idx)
|
|
||||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
|
||||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
|
||||||
only dataset indices.
|
|
||||||
"""
|
|
||||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
|
||||||
yield idx
|
|
||||||
|
|
||||||
|
|
||||||
class FrameAnnotsEntry(TypedDict):
|
class FrameAnnotsEntry(TypedDict):
|
||||||
subset: Optional[str]
|
subset: Optional[str]
|
||||||
frame_annotation: types.FrameAnnotation
|
frame_annotation: types.FrameAnnotation
|
||||||
|
@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
from .implicitron_dataset import ImplicitronDatasetBase
|
from .dataset_base import ImplicitronDatasetBase
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
||||||
|
@ -10,7 +10,8 @@ import torch
|
|||||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
||||||
from pytorch3d.structures import Pointclouds
|
from pytorch3d.structures import Pointclouds
|
||||||
|
|
||||||
from .implicitron_dataset import FrameData, ImplicitronDataset
|
from .dataset_base import FrameData
|
||||||
|
from .implicitron_dataset import ImplicitronDataset
|
||||||
|
|
||||||
|
|
||||||
def get_implicitron_sequence_pointcloud(
|
def get_implicitron_sequence_pointcloud(
|
||||||
|
@ -14,12 +14,9 @@ import torch
|
|||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.implicitron.dataset.data_source import Task
|
from pytorch3d.implicitron.dataset.data_source import Task
|
||||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||||
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
|
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||||
FrameData,
|
|
||||||
ImplicitronDataset,
|
|
||||||
ImplicitronDatasetBase,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||||
aggregate_nvs_results,
|
aggregate_nvs_results,
|
||||||
|
@ -15,7 +15,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.implicitron.dataset.data_source import Task
|
from pytorch3d.implicitron.dataset.data_source import Task
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
||||||
from pytorch3d.implicitron.tools import vis_utils
|
from pytorch3d.implicitron.tools import vis_utils
|
||||||
|
@ -9,7 +9,7 @@ import unittest
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase
|
||||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,10 +14,8 @@ import unittest
|
|||||||
|
|
||||||
import lpips
|
import lpips
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
FrameData,
|
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||||
ImplicitronDataset,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
from pytorch3d.implicitron.models.generic_model import GenericModel # noqa
|
from pytorch3d.implicitron.models.generic_model import GenericModel # noqa
|
||||||
|
Loading…
x
Reference in New Issue
Block a user