Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
Summary: We currently support caching bounding boxes in MaskAnnotation. If present, they are not re-computed from the mask. However, the masks needed to be loaded for the bbox to be set. This diff fixes that: even if load_masks / load_blobs are unset, the bounding box can be picked up from the metadata.

Reviewed By: bottler

Differential Revision: D45144918

fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd
726 lines · 30 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import os
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import (
    Any,
    ClassVar,
    Generic,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import torch

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.utils import (
    adjust_camera_to_bbox_crop_,
    adjust_camera_to_image_scale_,
    bbox_xyxy_to_xywh,
    clamp_box_to_image_bounds_and_round,
    crop_around_box,
    GenericWorkaround,
    get_bbox_from_mask,
    get_clamp_bbox,
    load_depth,
    load_depth_mask,
    load_image,
    load_mask,
    load_pointcloud,
    rescale_bbox,
    resize_image,
    safe_as_tensor,
)
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds


@dataclass
class FrameData(Mapping[str, Any]):
    """
    A type of the elements returned by indexing the dataset object.
    It can represent both individual frames and batches thereof;
    in this documentation, the sizes of tensors refer to single frames;
    add the first batch dimension for the collation result.

    Args:
        frame_number: The number of the frame within its sequence.
            0-based continuous integers.
        sequence_name: The unique name of the frame's sequence.
        sequence_category: The object category of the sequence.
        frame_timestamp: The time elapsed since the start of a sequence in sec.
        image_size_hw: The size of the original image in pixels; (height, width)
            tensor of shape (2,). Note that it is optional, e.g. it can be `None`
            if the frame annotation has no size and image_rgb has not [yet] been
            loaded. Image-less FrameData is valid but mutators like crop/resize
            may fail if the original image size cannot be deduced.
        effective_image_size_hw: The size of the image after mutations such as
            crop/resize in pixels; (height, width). If the image has not been
            mutated, it is equal to `image_size_hw`. Note that it is also
            optional, for the same reason as `image_size_hw`.
        image_path: The qualified path to the loaded image (with dataset_root).
        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
            of the frame; elements are floats in [0, 1].
        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
            are a result of zero-padding of the image after cropping around
            the object bounding box; elements are floats in {0.0, 1.0}.
        depth_path: The qualified path to the frame's depth map.
        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
            of the frame; values correspond to distances from the camera;
            use `depth_mask` and `mask_crop` to filter for valid pixels.
        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
            depth map that are valid for evaluation; they have been checked for
            consistency across views; elements are floats in {0.0, 1.0}.
        mask_path: A qualified path to the foreground probability mask.
        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
            pixels belonging to the captured object; elements are floats
            in [0, 1].
        bbox_xywh: The bounding box tightly enclosing the foreground object in
            the format (x0, y0, width, height). The convention assumes that
            `x0+width` and `y0+height` includes the boundary of the box.
            I.e., to slice out the corresponding crop from an image tensor `I`
            we execute `crop = I[..., y0:y0+height, x0:x0+width]`.
        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
            in the original image coordinates in the format
            (x0, y0, width, height). The convention is the same as for
            `bbox_xywh`. `crop_bbox_xywh` differs from `bbox_xywh` due to
            padding (which can happen e.g. due to setting
            `JsonIndexDataset.box_crop_context > 0`).
        camera: A PyTorch3D camera object corresponding to the frame's
            viewpoint, corrected for cropping if it happened.
        camera_quality_score: The score proportional to the confidence of the
            frame's camera estimation (the higher the more accurate).
        point_cloud_quality_score: The score proportional to the accuracy of
            the frame's sequence point cloud (the higher the more accurate).
        sequence_point_cloud_path: The path to the sequence's point cloud.
        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
            point cloud corresponding to the frame's sequence. When the object
            represents a batch of frames, point clouds may be deduplicated;
            see `sequence_point_cloud_idx`.
        sequence_point_cloud_idx: Integer indices mapping frame indices to the
            corresponding point clouds in `sequence_point_cloud`; to get the
            corresponding point cloud to `image_rgb[i]`, use
            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
        frame_type: The type of the loaded frame specified in
            `subset_lists_file`, if provided.
        meta: A dict for storing additional frame information.
    """

    frame_number: Optional[torch.LongTensor]
    sequence_name: Union[str, List[str]]
    sequence_category: Union[str, List[str]]
    frame_timestamp: Optional[torch.Tensor] = None
    image_size_hw: Optional[torch.LongTensor] = None
    effective_image_size_hw: Optional[torch.LongTensor] = None
    image_path: Union[str, List[str], None] = None
    image_rgb: Optional[torch.Tensor] = None
    # masks out padding added due to cropping the square bit
    mask_crop: Optional[torch.Tensor] = None
    depth_path: Union[str, List[str], None] = None
    depth_map: Optional[torch.Tensor] = None
    depth_mask: Optional[torch.Tensor] = None
    mask_path: Union[str, List[str], None] = None
    fg_probability: Optional[torch.Tensor] = None
    bbox_xywh: Optional[torch.Tensor] = None
    crop_bbox_xywh: Optional[torch.Tensor] = None
    camera: Optional[PerspectiveCameras] = None
    camera_quality_score: Optional[torch.Tensor] = None
    point_cloud_quality_score: Optional[torch.Tensor] = None
    sequence_point_cloud_path: Union[str, List[str], None] = None
    sequence_point_cloud: Optional[Pointclouds] = None
    sequence_point_cloud_idx: Optional[torch.Tensor] = None
    frame_type: Union[str, List[str], None] = None  # known | unseen
    meta: dict = field(default_factory=lambda: {})

    # NOTE that batching resets this attribute
    _uncropped: bool = field(init=False, default=True)

    def to(self, *args, **kwargs):
        new_params = {}
        for field_name in iter(self):
            value = getattr(self, field_name)
            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
                new_params[field_name] = value.to(*args, **kwargs)
            else:
                new_params[field_name] = value
        frame_data = type(self)(**new_params)
        frame_data._uncropped = self._uncropped
        return frame_data

    def cpu(self):
        return self.to(device=torch.device("cpu"))

    def cuda(self):
        return self.to(device=torch.device("cuda"))

    # the following functions make sure **frame_data can be passed to functions
    def __iter__(self):
        for f in fields(self):
            if f.name.startswith("_"):
                continue

            yield f.name

    def __getitem__(self, key):
        return getattr(self, key)

    def __len__(self):
        return sum(1 for f in iter(self))
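
    # Because FrameData implements the Mapping protocol above, an instance can
    # be unpacked with `**` into keyword arguments. A minimal usage sketch
    # (illustrative only; `dataset` and `render_frame` are hypothetical names,
    # not part of this module):
    #
    #     frame_data = dataset[0]           # FrameData for a single frame
    #     frame_data = frame_data.cuda()    # move all tensor fields to GPU
    #     rgb = frame_data["image_rgb"]     # dict-style access to a field
    #     render_frame(**frame_data)        # pass all public fields as kwargs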

    def crop_by_metadata_bbox_(
        self,
        box_crop_context: float,
    ) -> None:
        """Crops the frame data in-place by the (possibly expanded) bounding box.
        The bounding box is taken from the object state (usually taken from
        the frame annotation or estimated from the foreground mask).
        If the expanded bounding box does not fit the image, it is clamped,
        i.e. the image is *not* padded.

        Args:
            box_crop_context: rate of expansion for bbox; 0 means no expansion.

        Raises:
            ValueError: If the object does not contain a bounding box (usually
                when no mask annotation is provided).
            ValueError: If the frame data have been cropped or resized, and thus
                the intrinsic bounding box is not valid for the current image
                size.
            ValueError: If the frame does not have an image size (usually a
                corner case when no image has been loaded).
        """
        if self.bbox_xywh is None:
            raise ValueError("Attempted cropping by metadata with empty bounding box")

        if not self._uncropped:
            raise ValueError(
                "Trying to apply the metadata bounding box to already cropped "
                "or resized image; coordinates have changed."
            )

        self._crop_by_bbox_(
            box_crop_context,
            self.bbox_xywh,
        )

    def crop_by_given_bbox_(
        self,
        box_crop_context: float,
        bbox_xywh: torch.Tensor,
    ) -> None:
        """Crops the frame data in-place by the (possibly expanded) bounding box.
        If the expanded bounding box does not fit the image, it is clamped,
        i.e. the image is *not* padded.

        Args:
            box_crop_context: rate of expansion for bbox; 0 means no expansion.
            bbox_xywh: bounding box in [x0, y0, width, height] format. If float
                tensor, values are floored (after converting to [x0, y0, x1, y1]).

        Raises:
            ValueError: If the frame does not have an image size (usually a
                corner case when no image has been loaded).
        """
        self._crop_by_bbox_(
            box_crop_context,
            bbox_xywh,
        )
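
    # Cropping usage sketch (illustrative; assumes `frame_data` was built with
    # an uncropped image and the annotation provides a bounding box):
    #
    #     # crop around the metadata box, expanded by box_crop_context
    #     frame_data.crop_by_metadata_bbox_(box_crop_context=0.3)
    #
    #     # or crop around an arbitrary user-supplied [x0, y0, w, h] box
    #     box = torch.tensor([10, 20, 100, 80])
    #     frame_data.crop_by_given_bbox_(box_crop_context=0.0, bbox_xywh=box)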

    def _crop_by_bbox_(
        self,
        box_crop_context: float,
        bbox_xywh: torch.Tensor,
    ) -> None:
        """Crops the frame data in-place by the (possibly expanded) bounding box.
        If the expanded bounding box does not fit the image, it is clamped,
        i.e. the image is *not* padded.

        Args:
            box_crop_context: rate of expansion for bbox; 0 means no expansion.
            bbox_xywh: bounding box in [x0, y0, width, height] format. If float
                tensor, values are floored (after converting to [x0, y0, x1, y1]).

        Raises:
            ValueError: If the frame does not have an image size (usually a
                corner case when no image has been loaded).
        """
        effective_image_size_hw = self.effective_image_size_hw
        if effective_image_size_hw is None:
            raise ValueError("Calling crop on image-less FrameData")

        bbox_xyxy = get_clamp_bbox(
            bbox_xywh,
            image_path=self.image_path,  # pyre-ignore
            box_crop_context=box_crop_context,
        )
        clamp_bbox_xyxy = clamp_box_to_image_bounds_and_round(
            bbox_xyxy,
            image_size_hw=tuple(self.effective_image_size_hw),  # pyre-ignore
        )
        crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy)

        if self.fg_probability is not None:
            self.fg_probability = crop_around_box(
                self.fg_probability,
                clamp_bbox_xyxy,
                self.mask_path,  # pyre-ignore
            )
        if self.image_rgb is not None:
            self.image_rgb = crop_around_box(
                self.image_rgb,
                clamp_bbox_xyxy,
                self.image_path,  # pyre-ignore
            )

        depth_map = self.depth_map
        if depth_map is not None:
            clamp_bbox_xyxy_depth = rescale_bbox(
                clamp_bbox_xyxy, tuple(depth_map.shape[-2:]), effective_image_size_hw
            ).long()
            self.depth_map = crop_around_box(
                depth_map,
                clamp_bbox_xyxy_depth,
                self.depth_path,  # pyre-ignore
            )

        depth_mask = self.depth_mask
        if depth_mask is not None:
            clamp_bbox_xyxy_depth = rescale_bbox(
                clamp_bbox_xyxy, tuple(depth_mask.shape[-2:]), effective_image_size_hw
            ).long()
            self.depth_mask = crop_around_box(
                depth_mask,
                clamp_bbox_xyxy_depth,
                self.mask_path,  # pyre-ignore
            )

        # changing principal_point according to bbox_crop
        if self.camera is not None:
            adjust_camera_to_bbox_crop_(
                camera=self.camera,
                image_size_wh=effective_image_size_hw.flip(dims=[-1]),
                clamp_bbox_xywh=crop_bbox_xywh,
            )

        # pyre-ignore
        self.effective_image_size_hw = crop_bbox_xywh[..., 2:].flip(dims=[-1])
        self._uncropped = False
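
    # A worked example of the expansion step above (numbers illustrative,
    # assuming get_clamp_bbox expands the box symmetrically by
    # box_crop_context of its size): bbox_xywh = [10, 20, 100, 80] with
    # box_crop_context = 0.3 grows to width 100 * 1.3 = 130 and height
    # 80 * 1.3 = 104, shifting x0/y0 left/up by half the added size (15 px
    # and 12 px); clamp_box_to_image_bounds_and_round then clips the result
    # to the image bounds, so the crop is never padded.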

    def resize_frame_(self, new_size_hw: torch.LongTensor) -> None:
        """Resizes frame data in-place according to given dimensions.

        Args:
            new_size_hw: target image size [height, width], a LongTensor of shape (2,)

        Raises:
            ValueError: If the frame does not have an image size (usually a
                corner case when no image has been loaded)
        """

        effective_image_size_hw = self.effective_image_size_hw
        if effective_image_size_hw is None:
            raise ValueError("Calling resize on image-less FrameData")

        image_height, image_width = new_size_hw.tolist()

        if self.fg_probability is not None:
            self.fg_probability, _, _ = resize_image(
                self.fg_probability,
                image_height=image_height,
                image_width=image_width,
                mode="nearest",
            )

        if self.image_rgb is not None:
            self.image_rgb, _, self.mask_crop = resize_image(
                self.image_rgb, image_height=image_height, image_width=image_width
            )

        if self.depth_map is not None:
            self.depth_map, _, _ = resize_image(
                self.depth_map,
                image_height=image_height,
                image_width=image_width,
                mode="nearest",
            )

        if self.depth_mask is not None:
            self.depth_mask, _, _ = resize_image(
                self.depth_mask,
                image_height=image_height,
                image_width=image_width,
                mode="nearest",
            )

        if self.camera is not None:
            if self.image_size_hw is None:
                raise ValueError(
                    "image_size_hw has to be defined for resizing FrameData with cameras."
                )
            adjust_camera_to_image_scale_(
                camera=self.camera,
                original_size_wh=effective_image_size_hw.flip(dims=[-1]),
                new_size_wh=new_size_hw.flip(dims=[-1]),  # pyre-ignore
            )

        self.effective_image_size_hw = new_size_hw
        self._uncropped = False
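
    # Typical crop-then-resize pipeline on a single frame (illustrative;
    # mirrors what GenericFrameDataBuilder.build does below when box_crop
    # and image_height/image_width are set):
    #
    #     frame_data.crop_by_metadata_bbox_(box_crop_context=0.3)
    #     frame_data.resize_frame_(torch.tensor([800, 800], dtype=torch.long))
    #     # effective_image_size_hw is now [800, 800]; mask_crop marks the
    #     # valid (non-padded) region of the resized image.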

    @classmethod
    def collate(cls, batch):
        """
        Given a list `batch` of objects of class `cls`, collates them into a
        batched representation suitable for processing with deep networks.
        """

        elem = batch[0]

        if isinstance(elem, cls):
            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
            id_to_idx = defaultdict(list)
            for i, pc_id in enumerate(pointcloud_ids):
                id_to_idx[pc_id].append(i)

            sequence_point_cloud = []
            sequence_point_cloud_idx = -np.ones((len(batch),))
            for i, ind in enumerate(id_to_idx.values()):
                sequence_point_cloud_idx[ind] = i
                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
            assert (sequence_point_cloud_idx >= 0).all()

            override_fields = {
                "sequence_point_cloud": sequence_point_cloud,
                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
            }
            # note that the pre-collate value of sequence_point_cloud_idx is unused

            collated = {}
            for f in fields(elem):
                if not f.init:
                    continue

                list_values = override_fields.get(
                    f.name, [getattr(d, f.name) for d in batch]
                )
                collated[f.name] = (
                    cls.collate(list_values)
                    if all(list_value is not None for list_value in list_values)
                    else None
                )
            return cls(**collated)

        elif isinstance(elem, Pointclouds):
            return join_pointclouds_as_batch(batch)

        elif isinstance(elem, CamerasBase):
            # TODO: don't store K; enforce working in NDC space
            return join_cameras_as_batch(batch)
        else:
            return torch.utils.data._utils.collate.default_collate(batch)
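

# Collation usage sketch (illustrative; `dataset` is a hypothetical dataset
# yielding FrameData). Passing `FrameData.collate` as the collate_fn keeps
# cameras and point clouds as PyTorch3D batch objects instead of raw tensors:
#
#     loader = torch.utils.data.DataLoader(
#         dataset,
#         batch_size=4,
#         collate_fn=FrameData.collate,
#     )
#     batch = next(iter(loader))
#     # batch.image_rgb: (4, 3, H, W); batch.camera: batched PerspectiveCameras;
#     # batch.sequence_point_cloud is deduplicated across frames of the same
#     # sequence, indexed via batch.sequence_point_cloud_idx.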


FrameDataSubtype = TypeVar("FrameDataSubtype", bound=FrameData)


class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
    """A base class for FrameDataBuilders that build a FrameData object, load and
    process the binary data (crop and resize). Implementations should parametrize
    the class with a subtype of FrameData and set the `frame_data_type` class
    variable to that type. They also have to implement the `build` method.
    """

    # To be initialised to FrameDataSubtype
    frame_data_type: ClassVar[Type[FrameDataSubtype]]

    @abstractmethod
    def build(
        self,
        frame_annotation: types.FrameAnnotation,
        sequence_annotation: types.SequenceAnnotation,
    ) -> FrameDataSubtype:
        """An abstract method to build the frame data based on raw frame/sequence
        annotations, load the binary data and adjust them according to the metadata.
        """
        raise NotImplementedError()
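

# A minimal sketch of the parametrization the docstring above describes, using
# hypothetical names that are not part of this module (GenericFrameDataBuilder
# is defined below):
#
#     @dataclass
#     class MyFrameData(FrameData):
#         my_extra_field: Optional[torch.Tensor] = None
#
#     @registry.register
#     class MyFrameDataBuilder(
#         GenericWorkaround, GenericFrameDataBuilder[MyFrameData]
#     ):
#         frame_data_type: ClassVar[Type[MyFrameData]] = MyFrameData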


class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
    """
    A class to build a FrameData object, load and process the binary data (crop and
    resize). This is an abstract class for extending to build FrameData subtypes.
    Most users should use the concrete `FrameDataBuilder` class instead.
    Beware that modifications of frame data are done in-place.

    Args:
        dataset_root: The root folder of the dataset; all the paths in jsons are
            specified relative to this root (but not json paths themselves).
        load_images: Enable loading the frame RGB data.
        load_depths: Enable loading the frame depth maps.
        load_depth_masks: Enable loading the frame depth map masks denoting the
            depth values used for evaluation (the points consistent across views).
        load_masks: Enable loading frame foreground masks.
        load_point_clouds: Enable loading sequence-level point clouds.
        max_points: Cap on the number of loaded points in the point cloud;
            if reached, they are randomly sampled without replacement.
        mask_images: Whether to mask the images with the loaded foreground masks;
            0 value is used for background.
        mask_depths: Whether to mask the depth maps with the loaded foreground
            masks; 0 value is used for background.
        image_height: The height of the returned images, masks, and depth maps;
            aspect ratio is preserved during cropping/resizing.
        image_width: The width of the returned images, masks, and depth maps;
            aspect ratio is preserved during cropping/resizing.
        box_crop: Enable cropping of the image around the bounding box inferred
            from the foreground region of the loaded segmentation mask; masks
            and depth maps are cropped accordingly; cameras are corrected.
        box_crop_mask_thr: The threshold used to separate pixels into foreground
            and background based on the foreground_probability mask; if no value
            is greater than this threshold, the loader lowers it and repeats.
        box_crop_context: The amount of additional padding added to each
            dimension of the cropping bounding box, relative to box size.
        path_manager: Optionally a PathManager for interpreting paths in a special way.
    """

    dataset_root: str = ""
    load_images: bool = True
    load_depths: bool = True
    load_depth_masks: bool = True
    load_masks: bool = True
    load_point_clouds: bool = False
    max_points: int = 0
    mask_images: bool = False
    mask_depths: bool = False
    image_height: Optional[int] = 800
    image_width: Optional[int] = 800
    box_crop: bool = True
    box_crop_mask_thr: float = 0.4
    box_crop_context: float = 0.3
    path_manager: Any = None

    def build(
        self,
        frame_annotation: types.FrameAnnotation,
        sequence_annotation: types.SequenceAnnotation,
        load_blobs: bool = True,
    ) -> FrameDataSubtype:
        """Builds the frame data based on raw frame/sequence annotations, loads the
        binary data and adjusts them according to the metadata. The processing
        includes:
        * if box_crop is set, the image/mask/depth are cropped with the bounding
            box provided or estimated from MaskAnnotation,
        * if image_height/image_width are set, the image/mask/depth are resized to
            fit that resolution. Note that the aspect ratio is preserved, and the
            (possibly cropped) image is pasted into the top-left corner. In the
            resulting frame_data, the mask_crop field corresponds to the mask of
            the pasted image.

        Args:
            frame_annotation: frame annotation
            sequence_annotation: sequence annotation
            load_blobs: whether the function should attempt loading the image,
                the depth map and its mask, and the foreground mask

        Returns:
            The constructed FrameData object.
        """

        point_cloud = sequence_annotation.point_cloud

        frame_data = self.frame_data_type(
            frame_number=safe_as_tensor(frame_annotation.frame_number, torch.long),
            frame_timestamp=safe_as_tensor(
                frame_annotation.frame_timestamp, torch.float
            ),
            sequence_name=frame_annotation.sequence_name,
            sequence_category=sequence_annotation.category,
            camera_quality_score=safe_as_tensor(
                sequence_annotation.viewpoint_quality_score, torch.float
            ),
            point_cloud_quality_score=safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )

        mask_annotation = frame_annotation.mask
        if mask_annotation is not None:
            fg_mask_np: Optional[np.ndarray] = None
            if load_blobs and self.load_masks:
                fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
                frame_data.mask_path = mask_path
                frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)

            bbox_xywh = mask_annotation.bounding_box_xywh
            if bbox_xywh is None and fg_mask_np is not None:
                bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)

            frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)

        if frame_annotation.image is not None:
            image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
            frame_data.image_size_hw = image_size_hw  # original image size
            # image size after crop/resize
            frame_data.effective_image_size_hw = image_size_hw

            if load_blobs and self.load_images:
                (
                    frame_data.image_rgb,
                    frame_data.image_path,
                ) = self._load_images(frame_annotation, frame_data.fg_probability)

        if load_blobs and self.load_depths and frame_annotation.depth is not None:
            (
                frame_data.depth_map,
                frame_data.depth_path,
                frame_data.depth_mask,
            ) = self._load_mask_depth(frame_annotation, frame_data.fg_probability)

        if load_blobs and self.load_point_clouds and point_cloud is not None:
            pcl_path = self._fix_point_cloud_path(point_cloud.path)
            frame_data.sequence_point_cloud = load_pointcloud(
                self._local_path(pcl_path), max_points=self.max_points
            )
            frame_data.sequence_point_cloud_path = pcl_path

        if frame_annotation.viewpoint is not None:
            frame_data.camera = self._get_pytorch3d_camera(frame_annotation)

        if self.box_crop:
            frame_data.crop_by_metadata_bbox_(self.box_crop_context)

        if self.image_height is not None and self.image_width is not None:
            new_size = (self.image_height, self.image_width)
            frame_data.resize_frame_(
                new_size_hw=torch.tensor(new_size, dtype=torch.long),  # pyre-ignore
            )

        return frame_data
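
    # Builder usage sketch (illustrative; `frame_ann` and `seq_ann` are
    # hypothetical annotation objects, e.g. parsed from CO3D json files, and
    # `builder` is a constructed FrameDataBuilder, defined below). Note the key
    # property of the bbox caching above: even with load_blobs=False (or
    # load_masks=False), bbox_xywh is still picked up from
    # mask_annotation.bounding_box_xywh without loading the mask image:
    #
    #     frame_data = builder.build(frame_ann, seq_ann, load_blobs=False)
    #     print(frame_data.bbox_xywh)  # set from metadata; no mask was loaded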

    def _load_fg_probability(
        self, entry: types.FrameAnnotation
    ) -> Tuple[np.ndarray, str]:
        full_path = os.path.join(self.dataset_root, entry.mask.path)  # pyre-ignore
        fg_probability = load_mask(self._local_path(full_path))
        if fg_probability.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
            )

        return fg_probability, full_path

    def _load_images(
        self,
        entry: types.FrameAnnotation,
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str]:
        assert self.dataset_root is not None and entry.image is not None
        path = os.path.join(self.dataset_root, entry.image.path)
        image_rgb = load_image(self._local_path(path))

        if image_rgb.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
            )

        if self.mask_images:
            assert fg_probability is not None
            image_rgb *= fg_probability

        return image_rgb, path

    def _load_mask_depth(
        self,
        entry: types.FrameAnnotation,
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        entry_depth = entry.depth
        assert entry_depth is not None
        path = os.path.join(self.dataset_root, entry_depth.path)
        depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.mask_depths:
            assert fg_probability is not None
            depth_map *= fg_probability

        if self.load_depth_masks:
            assert entry_depth.mask_path is not None
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
            depth_mask = load_depth_mask(self._local_path(mask_path))
        else:
            depth_mask = torch.ones_like(depth_map)

        return torch.tensor(depth_map), path, torch.tensor(depth_mask)

    def _get_pytorch3d_camera(
        self,
        entry: types.FrameAnnotation,
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        format = entry_viewpoint.intrinsics_format
        if entry_viewpoint.intrinsics_format == "ndc_norm_image_bounds":
            # legacy PyTorch3D NDC format
            # convert to pixels unequally and convert to ndc equally
            image_size_as_list = list(reversed(entry.image.size))
            image_size_wh = torch.tensor(image_size_as_list, dtype=torch.float)
            per_axis_scale = image_size_wh / image_size_wh.min()
            focal_length = focal_length * per_axis_scale
            principal_point = principal_point * per_axis_scale
        elif entry_viewpoint.intrinsics_format != "ndc_isotropic":
            raise ValueError(f"Unknown intrinsics format: {format}")

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )
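
    # Worked example of the intrinsics conversion above (values illustrative):
    # for an 800x600 (HxW) image, image_size_wh = [600, 800], so
    # per_axis_scale = [600/600, 800/600] = [1.0, 1.333]. Multiplying the
    # focal length and principal point by this factor rescales intrinsics from
    # the legacy "ndc_norm_image_bounds" convention (each axis normalized to
    # [-1, 1] independently) to the isotropic NDC convention, where units are
    # shared across axes.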

    def _fix_point_cloud_path(self, path: str) -> str:
        """
        Fix up a point cloud path from the dataset.
        Some files in Co3Dv2 have an accidental absolute path stored.
        """
        unwanted_prefix = (
            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
        )
        if path.startswith(unwanted_prefix):
            path = path[len(unwanted_prefix) :]
        return os.path.join(self.dataset_root, path)

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path
        return self.path_manager.get_local_path(path)


@registry.register
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
    """
    A concrete class to build a FrameData object, load and process the binary data
    (crop and resize). Beware that modifications of frame data are done in-place.
    Please see the documentation for `GenericFrameDataBuilder` for the description
    of parameters and methods.
    """

    frame_data_type: ClassVar[Type[FrameData]] = FrameData
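

# Instantiation sketch (illustrative). FrameDataBuilder is a replaceable
# implicitron configurable, so it is assumed here that it must be processed
# with expand_args_fields before keyword construction; the dataset path is a
# placeholder:
#
#     from pytorch3d.implicitron.tools.config import expand_args_fields
#
#     expand_args_fields(FrameDataBuilder)
#     builder = FrameDataBuilder(
#         dataset_root="/path/to/co3d",
#         image_height=800,
#         image_width=800,
#         box_crop=True,
#         box_crop_context=0.3,
#     )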