mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
New file for ImplicitronDatasetBase
Summary: Separate ImplicitronDatasetBase and FrameData (to be used by all data sources) from ImplicitronDataset (which is specific). Reviewed By: shapovalov Differential Revision: D36413111 fbshipit-source-id: 3725744cde2e08baa11aff4048237ba10c7efbc6
This commit is contained in:
parent
73dc109dba
commit
69c6d06ed8
@ -66,11 +66,9 @@ from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
|
@ -24,12 +24,9 @@ import torch.nn.functional as Fu
|
||||
from experiment import init_model
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
ImplicitronDatasetBase,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
from pytorch3d.implicitron.tools.configurable import get_default_args
|
||||
|
@ -10,8 +10,8 @@ from typing import Optional, Sequence
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from .dataset_zoo import Datasets
|
||||
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
|
268
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
268
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
@ -0,0 +1,268 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
"""
|
||||
A type of the elements returned by indexing the dataset object.
|
||||
It can represent both individual frames and batches of thereof;
|
||||
in this documentation, the sizes of tensors refer to single frames;
|
||||
add the first batch dimension for the collation result.
|
||||
|
||||
Args:
|
||||
frame_number: The number of the frame within its sequence.
|
||||
0-based continuous integers.
|
||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||
sequence_name: The unique name of the frame's sequence.
|
||||
sequence_category: The object category of the sequence.
|
||||
image_size_hw: The size of the image in pixels; (height, width) tuple.
|
||||
image_path: The qualified path to the loaded image (with dataset_root).
|
||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||
of the frame; elements are floats in [0, 1].
|
||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||
are a result of zero-padding of the image after cropping around
|
||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||
depth_path: The qualified path to the frame's depth map.
|
||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||
of the frame; values correspond to distances from the camera;
|
||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||
depth map that are valid for evaluation, they have been checked for
|
||||
consistency across views; elements are floats in {0.0, 1.0}.
|
||||
mask_path: A qualified path to the foreground probability mask.
|
||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||
pixels belonging to the captured object; elements are floats
|
||||
in [0, 1].
|
||||
bbox_xywh: The bounding box capturing the object in the
|
||||
format (x0, y0, width, height).
|
||||
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
||||
corrected for cropping if it happened.
|
||||
camera_quality_score: The score proportional to the confidence of the
|
||||
frame's camera estimation (the higher the more accurate).
|
||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||
frame's sequence point cloud (the higher the more accurate).
|
||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||
point cloud corresponding to the frame's sequence. When the object
|
||||
represents a batch of frames, point clouds may be deduplicated;
|
||||
see `sequence_point_cloud_idx`.
|
||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||
corresponding point cloud to `image_rgb[i]`, use
|
||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||
frame_type: The type of the loaded frame specified in
|
||||
`subset_lists_file`, if provided.
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
frame_timestamp: Optional[torch.Tensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
image_size_hw: Optional[torch.Tensor] = None
|
||||
image_path: Union[str, List[str], None] = None
|
||||
image_rgb: Optional[torch.Tensor] = None
|
||||
# masks out padding added due to cropping the square bit
|
||||
mask_crop: Optional[torch.Tensor] = None
|
||||
depth_path: Union[str, List[str], None] = None
|
||||
depth_map: Optional[torch.Tensor] = None
|
||||
depth_mask: Optional[torch.Tensor] = None
|
||||
mask_path: Union[str, List[str], None] = None
|
||||
fg_probability: Optional[torch.Tensor] = None
|
||||
bbox_xywh: Optional[torch.Tensor] = None
|
||||
camera: Optional[PerspectiveCameras] = None
|
||||
camera_quality_score: Optional[torch.Tensor] = None
|
||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||
sequence_point_cloud: Optional[Pointclouds] = None
|
||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||
frame_type: Union[str, List[str], None] = None # seen | unseen
|
||||
meta: dict = field(default_factory=lambda: {})
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new_params = {}
|
||||
for f in fields(self):
|
||||
value = getattr(self, f.name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
new_params[f.name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[f.name] = value
|
||||
return type(self)(**new_params)
|
||||
|
||||
def cpu(self):
|
||||
return self.to(device=torch.device("cpu"))
|
||||
|
||||
def cuda(self):
|
||||
return self.to(device=torch.device("cuda"))
|
||||
|
||||
# the following functions make sure **frame_data can be passed to functions
|
||||
def __iter__(self):
|
||||
for f in fields(self):
|
||||
yield f.name
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __len__(self):
|
||||
return len(fields(self))
|
||||
|
||||
@classmethod
|
||||
def collate(cls, batch):
|
||||
"""
|
||||
Given a list objects `batch` of class `cls`, collates them into a batched
|
||||
representation suitable for processing with deep networks.
|
||||
"""
|
||||
|
||||
elem = batch[0]
|
||||
|
||||
if isinstance(elem, cls):
|
||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||
id_to_idx = defaultdict(list)
|
||||
for i, pc_id in enumerate(pointcloud_ids):
|
||||
id_to_idx[pc_id].append(i)
|
||||
|
||||
sequence_point_cloud = []
|
||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||
for i, ind in enumerate(id_to_idx.values()):
|
||||
sequence_point_cloud_idx[ind] = i
|
||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||
assert (sequence_point_cloud_idx >= 0).all()
|
||||
|
||||
override_fields = {
|
||||
"sequence_point_cloud": sequence_point_cloud,
|
||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||
}
|
||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||
|
||||
collated = {}
|
||||
for f in fields(elem):
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
collated[f.name] = (
|
||||
cls.collate(list_values)
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return cls(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data._utils.collate.default_collate(batch)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
Base class to describe a dataset to be used with Implicitron.
|
||||
|
||||
The dataset is made up of frames, and the frames are grouped into sequences.
|
||||
Each sequence has a name (a string).
|
||||
(A sequence could be a video, or a set of images of one scene.)
|
||||
|
||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||
which will describe one frame in one sequence.
|
||||
"""
|
||||
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_frame_numbers_and_timestamps(
|
||||
self, idxs: Sequence[int]
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""
|
||||
If the sequences in the dataset are videos rather than
|
||||
unordered views, then the dataset should override this method to
|
||||
return the index and timestamp in their videos of the frames whose
|
||||
indices are given in `idxs`. In addition,
|
||||
the values in _seq_to_idx should be in ascending order.
|
||||
If timestamps are absent, they should be replaced with a constant.
|
||||
|
||||
This is used for letting SceneBatchSampler identify consecutive
|
||||
frames.
|
||||
|
||||
Args:
|
||||
idx: frame index in self
|
||||
|
||||
Returns:
|
||||
tuple of
|
||||
- frame index in video
|
||||
- timestamp of frame in video
|
||||
"""
|
||||
raise ValueError("This dataset does not contain videos.")
|
||||
|
||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||
return None
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
return self._seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
) -> Iterator[Tuple[float, int, int]]:
|
||||
"""Returns an iterator over the frame indices in a given sequence.
|
||||
We attempt to first sort by timestamp (if they are available),
|
||||
then by frame number.
|
||||
|
||||
Args:
|
||||
seq_name: the name of the sequence.
|
||||
|
||||
Returns:
|
||||
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||
where `frame_no` is the index within the sequence, and
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
yield from sorted(
|
||||
[
|
||||
(timestamp, frame_no, idx)
|
||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||
]
|
||||
)
|
||||
|
||||
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||
only dataset indices.
|
||||
"""
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||
yield idx
|
@ -13,7 +13,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
|
||||
from .dataset_base import ImplicitronDatasetBase
|
||||
from .implicitron_dataset import ImplicitronDataset
|
||||
from .utils import (
|
||||
DATASET_TYPE_KNOWN,
|
||||
DATASET_TYPE_TEST,
|
||||
|
@ -13,17 +13,13 @@ import os
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@ -37,258 +33,16 @@ import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from . import types
|
||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
"""
|
||||
A type of the elements returned by indexing the dataset object.
|
||||
It can represent both individual frames and batches of thereof;
|
||||
in this documentation, the sizes of tensors refer to single frames;
|
||||
add the first batch dimension for the collation result.
|
||||
|
||||
Args:
|
||||
frame_number: The number of the frame within its sequence.
|
||||
0-based continuous integers.
|
||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||
sequence_name: The unique name of the frame's sequence.
|
||||
sequence_category: The object category of the sequence.
|
||||
image_size_hw: The size of the image in pixels; (height, width) tuple.
|
||||
image_path: The qualified path to the loaded image (with dataset_root).
|
||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||
of the frame; elements are floats in [0, 1].
|
||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||
are a result of zero-padding of the image after cropping around
|
||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||
depth_path: The qualified path to the frame's depth map.
|
||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||
of the frame; values correspond to distances from the camera;
|
||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||
depth map that are valid for evaluation, they have been checked for
|
||||
consistency across views; elements are floats in {0.0, 1.0}.
|
||||
mask_path: A qualified path to the foreground probability mask.
|
||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||
pixels belonging to the captured object; elements are floats
|
||||
in [0, 1].
|
||||
bbox_xywh: The bounding box capturing the object in the
|
||||
format (x0, y0, width, height).
|
||||
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
||||
corrected for cropping if it happened.
|
||||
camera_quality_score: The score proportional to the confidence of the
|
||||
frame's camera estimation (the higher the more accurate).
|
||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||
frame's sequence point cloud (the higher the more accurate).
|
||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||
point cloud corresponding to the frame's sequence. When the object
|
||||
represents a batch of frames, point clouds may be deduplicated;
|
||||
see `sequence_point_cloud_idx`.
|
||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||
corresponding point cloud to `image_rgb[i]`, use
|
||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||
frame_type: The type of the loaded frame specified in
|
||||
`subset_lists_file`, if provided.
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
frame_timestamp: Optional[torch.Tensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
image_size_hw: Optional[torch.Tensor] = None
|
||||
image_path: Union[str, List[str], None] = None
|
||||
image_rgb: Optional[torch.Tensor] = None
|
||||
# masks out padding added due to cropping the square bit
|
||||
mask_crop: Optional[torch.Tensor] = None
|
||||
depth_path: Union[str, List[str], None] = None
|
||||
depth_map: Optional[torch.Tensor] = None
|
||||
depth_mask: Optional[torch.Tensor] = None
|
||||
mask_path: Union[str, List[str], None] = None
|
||||
fg_probability: Optional[torch.Tensor] = None
|
||||
bbox_xywh: Optional[torch.Tensor] = None
|
||||
camera: Optional[PerspectiveCameras] = None
|
||||
camera_quality_score: Optional[torch.Tensor] = None
|
||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||
sequence_point_cloud: Optional[Pointclouds] = None
|
||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||
frame_type: Union[str, List[str], None] = None # seen | unseen
|
||||
meta: dict = field(default_factory=lambda: {})
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new_params = {}
|
||||
for f in fields(self):
|
||||
value = getattr(self, f.name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
new_params[f.name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[f.name] = value
|
||||
return type(self)(**new_params)
|
||||
|
||||
def cpu(self):
|
||||
return self.to(device=torch.device("cpu"))
|
||||
|
||||
def cuda(self):
|
||||
return self.to(device=torch.device("cuda"))
|
||||
|
||||
# the following functions make sure **frame_data can be passed to functions
|
||||
def __iter__(self):
|
||||
for f in fields(self):
|
||||
yield f.name
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __len__(self):
|
||||
return len(fields(self))
|
||||
|
||||
@classmethod
|
||||
def collate(cls, batch):
|
||||
"""
|
||||
Given a list objects `batch` of class `cls`, collates them into a batched
|
||||
representation suitable for processing with deep networks.
|
||||
"""
|
||||
|
||||
elem = batch[0]
|
||||
|
||||
if isinstance(elem, cls):
|
||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||
id_to_idx = defaultdict(list)
|
||||
for i, pc_id in enumerate(pointcloud_ids):
|
||||
id_to_idx[pc_id].append(i)
|
||||
|
||||
sequence_point_cloud = []
|
||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||
for i, ind in enumerate(id_to_idx.values()):
|
||||
sequence_point_cloud_idx[ind] = i
|
||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||
assert (sequence_point_cloud_idx >= 0).all()
|
||||
|
||||
override_fields = {
|
||||
"sequence_point_cloud": sequence_point_cloud,
|
||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||
}
|
||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||
|
||||
collated = {}
|
||||
for f in fields(elem):
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
collated[f.name] = (
|
||||
cls.collate(list_values)
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return cls(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data._utils.collate.default_collate(batch)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
Base class to describe a dataset to be used with Implicitron.
|
||||
|
||||
The dataset is made up of frames, and the frames are grouped into sequences.
|
||||
Each sequence has a name (a string).
|
||||
(A sequence could be a video, or a set of images of one scene.)
|
||||
|
||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||
which will describe one frame in one sequence.
|
||||
"""
|
||||
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_frame_numbers_and_timestamps(
|
||||
self, idxs: Sequence[int]
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""
|
||||
If the sequences in the dataset are videos rather than
|
||||
unordered views, then the dataset should override this method to
|
||||
return the index and timestamp in their videos of the frames whose
|
||||
indices are given in `idxs`. In addition,
|
||||
the values in _seq_to_idx should be in ascending order.
|
||||
If timestamps are absent, they should be replaced with a constant.
|
||||
|
||||
This is used for letting SceneBatchSampler identify consecutive
|
||||
frames.
|
||||
|
||||
Args:
|
||||
idx: frame index in self
|
||||
|
||||
Returns:
|
||||
tuple of
|
||||
- frame index in video
|
||||
- timestamp of frame in video
|
||||
"""
|
||||
raise ValueError("This dataset does not contain videos.")
|
||||
|
||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||
return None
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
return self._seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
) -> Iterator[Tuple[float, int, int]]:
|
||||
"""Returns an iterator over the frame indices in a given sequence.
|
||||
We attempt to first sort by timestamp (if they are available),
|
||||
then by frame number.
|
||||
|
||||
Args:
|
||||
seq_name: the name of the sequence.
|
||||
|
||||
Returns:
|
||||
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||
where `frame_no` is the index within the sequence, and
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
yield from sorted(
|
||||
[
|
||||
(timestamp, frame_no, idx)
|
||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||
]
|
||||
)
|
||||
|
||||
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||
only dataset indices.
|
||||
"""
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||
yield idx
|
||||
|
||||
|
||||
class FrameAnnotsEntry(TypedDict):
|
||||
subset: Optional[str]
|
||||
frame_annotation: types.FrameAnnotation
|
||||
|
@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from .implicitron_dataset import ImplicitronDatasetBase
|
||||
from .dataset_base import ImplicitronDatasetBase
|
||||
|
||||
|
||||
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
||||
|
@ -10,7 +10,8 @@ import torch
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
||||
from pytorch3d.structures import Pointclouds
|
||||
|
||||
from .implicitron_dataset import FrameData, ImplicitronDataset
|
||||
from .dataset_base import FrameData
|
||||
from .implicitron_dataset import ImplicitronDataset
|
||||
|
||||
|
||||
def get_implicitron_sequence_pointcloud(
|
||||
|
@ -14,12 +14,9 @@ import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
ImplicitronDatasetBase,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
aggregate_nvs_results,
|
||||
|
@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
||||
from pytorch3d.implicitron.tools import vis_utils
|
||||
|
@ -9,7 +9,7 @@ import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
|
@ -14,10 +14,8 @@ import unittest
|
||||
|
||||
import lpips
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel # noqa
|
||||
|
Loading…
x
Reference in New Issue
Block a user