Extract BlobLoader class from JsonIndexDataset and moving crop_by_bbox to FrameData

Summary:
extracted blob loader
added documentation for blob_loader
did some refactoring on fields
for detailed steps and discussions see:
https://github.com/facebookresearch/pytorch3d/pull/1463
https://github.com/fairinternal/pixar_replay/pull/160

Reviewed By: bottler

Differential Revision: D44061728

fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51
This commit is contained in:
Ildar Salakhiev 2023-04-04 07:17:43 -07:00 committed by Facebook GitHub Bot
parent c759fc560f
commit ebdbfde0ce
15 changed files with 1421 additions and 694 deletions

View File

@ -18,8 +18,9 @@ from torch.utils.data import (
Sampler,
)
from .dataset_base import DatasetBase, FrameData
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap
from .frame_data import FrameData
from .scene_batch_sampler import SceneBatchSampler
from .utils import is_known_frame_scalar

View File

@ -5,217 +5,27 @@
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from dataclasses import dataclass, field, fields
from dataclasses import dataclass
from typing import (
Any,
ClassVar,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import numpy as np
import torch
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
@dataclass
class FrameData(Mapping[str, Any]):
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
frame_timestamp: The time elapsed since the start of a sequence in sec.
image_size_hw: The size of the image in pixels; (height, width) tensor
of shape (2,).
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box tightly enclosing the foreground object in the
format (x0, y0, width, height). The convention assumes that
`x0+width` and `y0+height` includes the boundary of the box.
I.e., to slice out the corresponding crop from an image tensor `I`
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
in the original image coordinates in the format (x0, y0, width, height).
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
from `bbox_xywh` due to padding (which can happen e.g. due to
setting `JsonIndexDataset.box_crop_context > 0`)
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
# masks out padding added due to cropping the square bit
mask_crop: Optional[torch.Tensor] = None
depth_path: Union[str, List[str], None] = None
depth_map: Optional[torch.Tensor] = None
depth_mask: Optional[torch.Tensor] = None
mask_path: Union[str, List[str], None] = None
fg_probability: Optional[torch.Tensor] = None
bbox_xywh: Optional[torch.Tensor] = None
crop_bbox_xywh: Optional[torch.Tensor] = None
camera: Optional[PerspectiveCameras] = None
camera_quality_score: Optional[torch.Tensor] = None
point_cloud_quality_score: Optional[torch.Tensor] = None
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # known | unseen
meta: dict = field(default_factory=lambda: {})
def to(self, *args, **kwargs):
new_params = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[f.name] = value.to(*args, **kwargs)
else:
new_params[f.name] = value
return type(self)(**new_params)
def cpu(self):
return self.to(device=torch.device("cpu"))
def cuda(self):
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def __iter__(self):
for f in fields(self):
yield f.name
def __getitem__(self, key):
return getattr(self, key)
def __len__(self):
return len(fields(self))
@classmethod
def collate(cls, batch):
"""
Given a list objects `batch` of class `cls`, collates them into a batched
representation suitable for processing with deep networks.
"""
elem = batch[0]
if isinstance(elem, cls):
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
id_to_idx = defaultdict(list)
for i, pc_id in enumerate(pointcloud_ids):
id_to_idx[pc_id].append(i)
sequence_point_cloud = []
sequence_point_cloud_idx = -np.ones((len(batch),))
for i, ind in enumerate(id_to_idx.values()):
sequence_point_cloud_idx[ind] = i
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
assert (sequence_point_cloud_idx >= 0).all()
override_fields = {
"sequence_point_cloud": sequence_point_cloud,
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
}
# note that the pre-collate value of sequence_point_cloud_idx is unused
collated = {}
for f in fields(elem):
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
collated[f.name] = (
cls.collate(list_values)
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
else:
return torch.utils.data._utils.collate.default_collate(batch)
class _GenericWorkaround:
"""
OmegaConf.structured has a weirdness when you try to apply
it to a dataclass whose first base class is a Generic which is not
Dict. The issue is with a function called get_dict_key_value_types
in omegaconf/_utils.py.
For example this fails:
@dataclass(eq=False)
class D(torch.utils.data.Dataset[int]):
a: int = 3
OmegaConf.structured(D)
We avoid the problem by adding this class as an extra base class.
"""
pass
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
@dataclass(eq=False)
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
"""
Base class to describe a dataset to be used with Implicitron.

View File

@ -0,0 +1,728 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import (
Any,
ClassVar,
Generic,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
import torch
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.utils import (
adjust_camera_to_bbox_crop_,
adjust_camera_to_image_scale_,
bbox_xyxy_to_xywh,
clamp_box_to_image_bounds_and_round,
crop_around_box,
GenericWorkaround,
get_bbox_from_mask,
get_clamp_bbox,
load_depth,
load_depth_mask,
load_image,
load_mask,
load_pointcloud,
rescale_bbox,
resize_image,
safe_as_tensor,
)
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
@dataclass
class FrameData(Mapping[str, Any]):
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
frame_timestamp: The time elapsed since the start of a sequence in sec.
image_size_hw: The size of the original image in pixels; (height, width)
tensor of shape (2,). Note that it is optional, e.g. it can be `None`
if the frame annotation has no size ans image_rgb has not [yet] been
loaded. Image-less FrameData is valid but mutators like crop/resize
may fail if the original image size cannot be deduced.
effective_image_size_hw: The size of the image after mutations such as
crop/resize in pixels; (height, width). if the image has not been mutated,
it is equal to `image_size_hw`. Note that it is also optional, for the
same reason as `image_size_hw`.
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box tightly enclosing the foreground object in the
format (x0, y0, width, height). The convention assumes that
`x0+width` and `y0+height` includes the boundary of the box.
I.e., to slice out the corresponding crop from an image tensor `I`
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
in the original image coordinates in the format (x0, y0, width, height).
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
from `bbox_xywh` due to padding (which can happen e.g. due to
setting `JsonIndexDataset.box_crop_context > 0`)
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.LongTensor] = None
effective_image_size_hw: Optional[torch.LongTensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
# masks out padding added due to cropping the square bit
mask_crop: Optional[torch.Tensor] = None
depth_path: Union[str, List[str], None] = None
depth_map: Optional[torch.Tensor] = None
depth_mask: Optional[torch.Tensor] = None
mask_path: Union[str, List[str], None] = None
fg_probability: Optional[torch.Tensor] = None
bbox_xywh: Optional[torch.Tensor] = None
crop_bbox_xywh: Optional[torch.Tensor] = None
camera: Optional[PerspectiveCameras] = None
camera_quality_score: Optional[torch.Tensor] = None
point_cloud_quality_score: Optional[torch.Tensor] = None
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # known | unseen
meta: dict = field(default_factory=lambda: {})
# NOTE that batching resets this attribute
_uncropped: bool = field(init=False, default=True)
def to(self, *args, **kwargs):
new_params = {}
for field_name in iter(self):
value = getattr(self, field_name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[field_name] = value.to(*args, **kwargs)
else:
new_params[field_name] = value
frame_data = type(self)(**new_params)
frame_data._uncropped = self._uncropped
return frame_data
def cpu(self):
return self.to(device=torch.device("cpu"))
def cuda(self):
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def __iter__(self):
for f in fields(self):
if f.name.startswith("_"):
continue
yield f.name
def __getitem__(self, key):
return getattr(self, key)
def __len__(self):
return sum(1 for f in iter(self))
def crop_by_metadata_bbox_(
self,
box_crop_context: float,
) -> None:
"""Crops the frame data in-place by (possibly expanded) bounding box.
The bounding box is taken from the object state (usually taken from
the frame annotation or estimated from the foregroubnd mask).
If the expanded bounding box does not fit the image, it is clamped,
i.e. the image is *not* padded.
Args:
box_crop_context: rate of expansion for bbox; 0 means no expansion,
Raises:
ValueError: If the object does not contain a bounding box (usually when no
mask annotation is provided)
ValueError: If the frame data have been cropped or resized, thus the intrinsic
bounding box is not valid for the current image size.
ValueError: If the frame does not have an image size (usually a corner case
when no image has been loaded)
"""
if self.bbox_xywh is None:
raise ValueError("Attempted cropping by metadata with empty bounding box")
if not self._uncropped:
raise ValueError(
"Trying to apply the metadata bounding box to already cropped "
"or resized image; coordinates have changed."
)
self._crop_by_bbox_(
box_crop_context,
self.bbox_xywh,
)
def crop_by_given_bbox_(
self,
box_crop_context: float,
bbox_xywh: torch.Tensor,
) -> None:
"""Crops the frame data in-place by (possibly expanded) bounding box.
If the expanded bounding box does not fit the image, it is clamped,
i.e. the image is *not* padded.
Args:
box_crop_context: rate of expansion for bbox; 0 means no expansion,
bbox_xywh: bounding box in [x0, y0, width, height] format. If float
tensor, values are floored (after converting to [x0, y0, x1, y1]).
Raises:
ValueError: If the frame does not have an image size (usually a corner case
when no image has been loaded)
"""
self._crop_by_bbox_(
box_crop_context,
bbox_xywh,
)
def _crop_by_bbox_(
self,
box_crop_context: float,
bbox_xywh: torch.Tensor,
) -> None:
"""Crops the frame data in-place by (possibly expanded) bounding box.
If the expanded bounding box does not fit the image, it is clamped,
i.e. the image is *not* padded.
Args:
box_crop_context: rate of expansion for bbox; 0 means no expansion,
bbox_xywh: bounding box in [x0, y0, width, height] format. If float
tensor, values are floored (after converting to [x0, y0, x1, y1]).
Raises:
ValueError: If the frame does not have an image size (usually a corner case
when no image has been loaded)
"""
effective_image_size_hw = self.effective_image_size_hw
if effective_image_size_hw is None:
raise ValueError("Calling crop on image-less FrameData")
bbox_xyxy = get_clamp_bbox(
bbox_xywh,
image_path=self.image_path, # pyre-ignore
box_crop_context=box_crop_context,
)
clamp_bbox_xyxy = clamp_box_to_image_bounds_and_round(
bbox_xyxy,
image_size_hw=tuple(self.effective_image_size_hw), # pyre-ignore
)
crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy)
if self.fg_probability is not None:
self.fg_probability = crop_around_box(
self.fg_probability,
clamp_bbox_xyxy,
self.mask_path, # pyre-ignore
)
if self.image_rgb is not None:
self.image_rgb = crop_around_box(
self.image_rgb,
clamp_bbox_xyxy,
self.image_path, # pyre-ignore
)
depth_map = self.depth_map
if depth_map is not None:
clamp_bbox_xyxy_depth = rescale_bbox(
clamp_bbox_xyxy, tuple(depth_map.shape[-2:]), effective_image_size_hw
).long()
self.depth_map = crop_around_box(
depth_map,
clamp_bbox_xyxy_depth,
self.depth_path, # pyre-ignore
)
depth_mask = self.depth_mask
if depth_mask is not None:
clamp_bbox_xyxy_depth = rescale_bbox(
clamp_bbox_xyxy, tuple(depth_mask.shape[-2:]), effective_image_size_hw
).long()
self.depth_mask = crop_around_box(
depth_mask,
clamp_bbox_xyxy_depth,
self.mask_path, # pyre-ignore
)
# changing principal_point according to bbox_crop
if self.camera is not None:
adjust_camera_to_bbox_crop_(
camera=self.camera,
image_size_wh=effective_image_size_hw.flip(dims=[-1]),
clamp_bbox_xywh=crop_bbox_xywh,
)
# pyre-ignore
self.effective_image_size_hw = crop_bbox_xywh[..., 2:].flip(dims=[-1])
self._uncropped = False
def resize_frame_(self, new_size_hw: torch.LongTensor) -> None:
"""Resizes frame data in-place according to given dimensions.
Args:
new_size_hw: target image size [height, width], a LongTensor of shape (2,)
Raises:
ValueError: If the frame does not have an image size (usually a corner case
when no image has been loaded)
"""
effective_image_size_hw = self.effective_image_size_hw
if effective_image_size_hw is None:
raise ValueError("Calling resize on image-less FrameData")
image_height, image_width = new_size_hw.tolist()
if self.fg_probability is not None:
self.fg_probability, _, _ = resize_image(
self.fg_probability,
image_height=image_height,
image_width=image_width,
mode="nearest",
)
if self.image_rgb is not None:
self.image_rgb, _, self.mask_crop = resize_image(
self.image_rgb, image_height=image_height, image_width=image_width
)
if self.depth_map is not None:
self.depth_map, _, _ = resize_image(
self.depth_map,
image_height=image_height,
image_width=image_width,
mode="nearest",
)
if self.depth_mask is not None:
self.depth_mask, _, _ = resize_image(
self.depth_mask,
image_height=image_height,
image_width=image_width,
mode="nearest",
)
if self.camera is not None:
if self.image_size_hw is None:
raise ValueError(
"image_size_hw has to be defined for resizing FrameData with cameras."
)
adjust_camera_to_image_scale_(
camera=self.camera,
original_size_wh=effective_image_size_hw.flip(dims=[-1]),
new_size_wh=new_size_hw.flip(dims=[-1]), # pyre-ignore
)
self.effective_image_size_hw = new_size_hw
self._uncropped = False
@classmethod
def collate(cls, batch):
"""
Given a list objects `batch` of class `cls`, collates them into a batched
representation suitable for processing with deep networks.
"""
elem = batch[0]
if isinstance(elem, cls):
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
id_to_idx = defaultdict(list)
for i, pc_id in enumerate(pointcloud_ids):
id_to_idx[pc_id].append(i)
sequence_point_cloud = []
sequence_point_cloud_idx = -np.ones((len(batch),))
for i, ind in enumerate(id_to_idx.values()):
sequence_point_cloud_idx[ind] = i
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
assert (sequence_point_cloud_idx >= 0).all()
override_fields = {
"sequence_point_cloud": sequence_point_cloud,
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
}
# note that the pre-collate value of sequence_point_cloud_idx is unused
collated = {}
for f in fields(elem):
if not f.init:
continue
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
collated[f.name] = (
cls.collate(list_values)
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
else:
return torch.utils.data._utils.collate.default_collate(batch)
FrameDataSubtype = TypeVar("FrameDataSubtype", bound=FrameData)
class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
"""A base class for FrameDataBuilders that build a FrameData object, load and
process the binary data (crop and resize). Implementations should parametrize
the class with a subtype of FrameData and set frame_data_type class variable to
that type. They have to also implement `build` method.
"""
# To be initialised to FrameDataSubtype
frame_data_type: ClassVar[Type[FrameDataSubtype]]
@abstractmethod
def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
) -> FrameDataSubtype:
"""An abstract method to build the frame data based on raw frame/sequence
annotations, load the binary data and adjust them according to the metadata.
"""
raise NotImplementedError()
class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
"""
A class to build a FrameData object, load and process the binary data (crop and
resize). This is an abstract class for extending to build FrameData subtypes. Most
users need to use concrete `FrameDataBuilder` class instead.
Beware that modifications of frame data are done in-place.
Args:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
depth values used for evaluation (the points consistent across views).
load_masks: Enable loading frame foreground masks.
load_point_clouds: Enable loading sequence-level point clouds.
max_points: Cap on the number of loaded points in the point cloud;
if reached, they are randomly sampled without replacement.
mask_images: Whether to mask the images with the loaded foreground masks;
0 value is used for background.
mask_depths: Whether to mask the depth maps with the loaded foreground
masks; 0 value is used for background.
image_height: The height of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
image_width: The width of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
box_crop: Enable cropping of the image around the bounding box inferred
from the foreground region of the loaded segmentation mask; masks
and depth maps are cropped accordingly; cameras are corrected.
box_crop_mask_thr: The threshold used to separate pixels into foreground
and background based on the foreground_probability mask; if no value
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
path_manager: Optionally a PathManager for interpreting paths in a special way.
"""
dataset_root: str = ""
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
load_masks: bool = True
load_point_clouds: bool = False
max_points: int = 0
mask_images: bool = False
mask_depths: bool = False
image_height: Optional[int] = 800
image_width: Optional[int] = 800
box_crop: bool = True
box_crop_mask_thr: float = 0.4
box_crop_context: float = 0.3
path_manager: Any = None
def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
load_blobs: bool = True,
) -> FrameDataSubtype:
"""Builds the frame data based on raw frame/sequence annotations, loads the
binary data and adjust them according to the metadata. The processing includes:
* if box_crop is set, the image/mask/depth are cropped with the bounding
box provided or estimated from MaskAnnotation,
* if image_height/image_width are set, the image/mask/depth are resized to
fit that resolution. Note that the aspect ratio is preserved, and the
(possibly cropped) image is pasted into the top-left corner. In the
resulting frame_data, mask_crop field corresponds to the mask of the
pasted image.
Args:
frame_annotation: frame annotation
sequence_annotation: sequence annotation
load_blobs: if the function should attempt loading the image, depth map
and mask, and foreground mask
Returns:
The constructed FrameData object.
"""
point_cloud = sequence_annotation.point_cloud
frame_data = self.frame_data_type(
frame_number=safe_as_tensor(frame_annotation.frame_number, torch.long),
frame_timestamp=safe_as_tensor(
frame_annotation.frame_timestamp, torch.float
),
sequence_name=frame_annotation.sequence_name,
sequence_category=sequence_annotation.category,
camera_quality_score=safe_as_tensor(
sequence_annotation.viewpoint_quality_score, torch.float
),
point_cloud_quality_score=safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
if load_blobs and self.load_masks and frame_annotation.mask is not None:
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
) = self._load_fg_probability(frame_annotation)
if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
frame_data.image_size_hw = image_size_hw # original image size
# image size after crop/resize
frame_data.effective_image_size_hw = image_size_hw
if load_blobs and self.load_images:
(
frame_data.image_rgb,
frame_data.image_path,
) = self._load_images(frame_annotation, frame_data.fg_probability)
if load_blobs and self.load_depths and frame_annotation.depth is not None:
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(frame_annotation, frame_data.fg_probability)
if load_blobs and self.load_point_clouds and point_cloud is not None:
pcl_path = self._fix_point_cloud_path(point_cloud.path)
frame_data.sequence_point_cloud = load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
frame_data.sequence_point_cloud_path = pcl_path
if frame_annotation.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
if self.box_crop:
frame_data.crop_by_metadata_bbox_(self.box_crop_context)
if self.image_height is not None and self.image_width is not None:
new_size = (self.image_height, self.image_width)
frame_data.resize_frame_(
new_size_hw=torch.tensor(new_size, dtype=torch.long), # pyre-ignore
)
return frame_data
def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
fg_probability = load_mask(self._local_path(full_path))
# we can use provided bbox_xywh or calculate it based on mask
# saves time to skip bbox calculation
# pyre-ignore
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
fg_probability, self.box_crop_mask_thr
)
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)
return (
safe_as_tensor(fg_probability, torch.float),
full_path,
safe_as_tensor(bbox_xywh, torch.long),
)
def _load_images(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = load_image(self._local_path(path))
if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)
if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability
return image_rgb, path
def _load_mask_depth(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability
if self.load_depth_masks:
assert entry_depth.mask_path is not None
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = load_depth_mask(self._local_path(mask_path))
else:
depth_mask = torch.ones_like(depth_map)
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
format = entry_viewpoint.intrinsics_format
if entry_viewpoint.intrinsics_format == "ndc_norm_image_bounds":
# legacy PyTorch3D NDC format
# convert to pixels unequally and convert to ndc equally
image_size_as_list = list(reversed(entry.image.size))
image_size_wh = torch.tensor(image_size_as_list, dtype=torch.float)
per_axis_scale = image_size_wh / image_size_wh.min()
focal_length = focal_length * per_axis_scale
principal_point = principal_point * per_axis_scale
elif entry_viewpoint.intrinsics_format != "ndc_isotropic":
raise ValueError(f"Unknown intrinsics format: {format}")
return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _fix_point_cloud_path(self, path: str) -> str:
"""
Fix up a point cloud path from the dataset.
Some files in Co3Dv2 have an accidental absolute path stored.
"""
unwanted_prefix = (
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
return os.path.join(self.dataset_root, path)
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)
@registry.register
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
"""
A concrete class to build a FrameData object, load and process the binary data (crop
and resize). Beware that modifications of frame data are done in-place. Please see
the documentation for `GenericFrameDataBuilder` for the description of parameters
and methods.
"""
frame_data_type: ClassVar[Type[FrameData]] = FrameData

View File

@ -15,7 +15,6 @@ import random
import warnings
from collections import defaultdict
from itertools import islice
from pathlib import Path
from typing import (
Any,
ClassVar,
@ -30,19 +29,15 @@ from typing import (
Union,
)
import numpy as np
import torch
from PIL import Image
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import FrameData, FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
from tqdm import tqdm
from pytorch3d.renderer.cameras import CamerasBase
from . import types
from .dataset_base import DatasetBase, FrameData
from .utils import is_known_frame_scalar
from tqdm import tqdm
logger = logging.getLogger(__name__)
@ -65,7 +60,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
A dataset with annotations in json files like the Common Objects in 3D
(CO3D) dataset.
Args:
Metadata-related args::
frame_annotations_file: A zipped json file containing metadata of the
frames in the dataset, serialized List[types.FrameAnnotation].
sequence_annotations_file: A zipped json file containing metadata of the
@ -83,6 +78,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
pick_sequence: A list of sequence names to restrict the dataset to.
exclude_sequence: A list of the names of the sequences to exclude.
limit_category_to: Restrict the dataset to the given list of categories.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask after thresholding (see box_crop_mask_thr).
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
frames in each sequences uniformly without replacement if it has
more frames than that; applied before other frame-level filters.
seed: The seed of the random generator sampling #n_frames_per_sequence
random frames per sequence.
sort_frames: Enable frame annotations sorting to group frames from the
same sequences together and order them by timestamps
eval_batches: A list of batches that form the evaluation set;
list of batch-sized lists of indices corresponding to __getitem__
of this class, thus it can be used directly as a batch sampler.
eval_batch_index:
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
A list of batches of frames described as (sequence_name, frame_idx)
that can form the evaluation set, `eval_batches` will be set from this.
Blob-loading parameters:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
@ -109,23 +122,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask after thresholding (see box_crop_mask_thr).
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
frames in each sequences uniformly without replacement if it has
more frames than that; applied before other frame-level filters.
seed: The seed of the random generator sampling #n_frames_per_sequence
random frames per sequence.
sort_frames: Enable frame annotations sorting to group frames from the
same sequences together and order them by timestamps
eval_batches: A list of batches that form the evaluation set;
list of batch-sized lists of indices corresponding to __getitem__
of this class, thus it can be used directly as a batch sampler.
eval_batch_index:
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
A list of batches of frames described as (sequence_name, frame_idx)
that can form the evaluation set, `eval_batches` will be set from this.
"""
frame_annotations_type: ClassVar[
@ -162,12 +158,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
sort_frames: bool = False
eval_batches: Any = None
eval_batch_index: Any = None
# initialised in __post_init__
# commented because of OmegaConf (for tests to pass)
# _frame_data_builder: FrameDataBuilder = field(init=False)
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
def __post_init__(self) -> None:
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
self.subset_to_image_path = None
self._load_frames()
self._load_sequences()
if self.sort_frames:
@ -175,9 +173,27 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
self._load_subset_lists()
self._filter_db() # also computes sequence indices
self._extract_and_set_eval_batches()
# pyre-ignore
self._frame_data_builder = FrameDataBuilder(
dataset_root=self.dataset_root,
load_images=self.load_images,
load_depths=self.load_depths,
load_depth_masks=self.load_depth_masks,
load_masks=self.load_masks,
load_point_clouds=self.load_point_clouds,
max_points=self.max_points,
mask_images=self.mask_images,
mask_depths=self.mask_depths,
image_height=self.image_height,
image_width=self.image_width,
box_crop=self.box_crop,
box_crop_mask_thr=self.box_crop_mask_thr,
box_crop_context=self.box_crop_context,
)
logger.info(str(self))
def _extract_and_set_eval_batches(self):
def _extract_and_set_eval_batches(self) -> None:
"""
Sets eval_batches based on input eval_batch_index.
"""
@ -207,13 +223,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
# https://gist.github.com/treyhunner/f35292e676efa0be1728
functools.reduce(
lambda a, b: {**a, **b},
[d.seq_annots for d in other_datasets], # pyre-ignore[16]
# pyre-ignore[16]
[d.seq_annots for d in other_datasets],
)
)
all_eval_batches = [
self.eval_batches,
# pyre-ignore
*[d.eval_batches for d in other_datasets],
*[d.eval_batches for d in other_datasets], # pyre-ignore[16]
]
if not (
all(ba is None for ba in all_eval_batches)
@ -251,7 +267,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
allow_missing_indices: bool = False,
remove_missing_indices: bool = False,
suppress_missing_index_warning: bool = True,
) -> List[List[Union[Optional[int], int]]]:
) -> Union[List[List[Optional[int]]], List[List[int]]]:
"""
Obtain indices into the dataset object given a list of frame ids.
@ -323,9 +339,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
valid_dataset_idx = [
[b for b in batch if b is not None] for batch in dataset_idx
]
return [ # pyre-ignore[7]
batch for batch in valid_dataset_idx if len(batch) > 0
]
return [batch for batch in valid_dataset_idx if len(batch) > 0]
return dataset_idx
@ -417,255 +431,18 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
entry = self.frame_annots[index]["frame_annotation"]
# pyre-ignore[16]
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
frame_data = FrameData(
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
sequence_name=entry.sequence_name,
sequence_category=self.seq_annots[entry.sequence_name].category,
camera_quality_score=_safe_as_tensor(
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
torch.float,
),
point_cloud_quality_score=_safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
# The rest of the fields are optional
# pyre-ignore
frame_data = self._frame_data_builder.build(
entry,
# pyre-ignore
self.seq_annots[entry.sequence_name],
)
# Optional field
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
clamp_bbox_xyxy,
frame_data.crop_bbox_xywh,
) = self._load_crop_fg_probability(entry)
scale = 1.0
if self.load_images and entry.image is not None:
# original image size
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
(
frame_data.image_rgb,
frame_data.image_path,
frame_data.mask_crop,
scale,
) = self._load_crop_images(
entry, frame_data.fg_probability, clamp_bbox_xyxy
)
if self.load_depths and entry.depth is not None:
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
if entry.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(
entry,
scale,
clamp_bbox_xyxy,
)
if self.load_point_clouds and point_cloud is not None:
pcl_path = self._fix_point_cloud_path(point_cloud.path)
frame_data.sequence_point_cloud = _load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
frame_data.sequence_point_cloud_path = pcl_path
return frame_data
def _fix_point_cloud_path(self, path: str) -> str:
"""
Fix up a point cloud path from the dataset.
Some files in Co3Dv2 have an accidental absolute path stored.
"""
unwanted_prefix = (
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
return os.path.join(self.dataset_root, path)
def _load_crop_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[
Optional[torch.Tensor],
Optional[str],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
fg_probability = None
full_path = None
bbox_xywh = None
clamp_bbox_xyxy = None
crop_box_xywh = None
if (self.load_masks or self.box_crop) and entry.mask is not None:
full_path = os.path.join(self.dataset_root, entry.mask.path)
mask = _load_mask(self._local_path(full_path))
if mask.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
)
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
if self.box_crop:
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
_get_clamp_bbox(
bbox_xywh,
image_path=entry.image.path,
box_crop_context=self.box_crop_context,
),
image_size_hw=tuple(mask.shape[-2:]),
)
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
def _load_crop_images(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = _load_image(self._local_path(path))
if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)
if self.box_crop:
assert clamp_bbox_xyxy is not None
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
image_rgb, scale, mask_crop = self._resize_image(image_rgb)
if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability
return image_rgb, path, mask_crop, scale
def _load_mask_depth(
self,
entry: types.FrameAnnotation,
clamp_bbox_xyxy: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
)
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability
if self.load_depth_masks:
assert entry_depth.mask_path is not None
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = _load_depth_mask(self._local_path(mask_path))
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_mask_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
)
depth_mask = _crop_around_box(
depth_mask, depth_mask_bbox_xyxy, mask_path
)
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
else:
depth_mask = torch.ones_like(depth_map)
return depth_map, path, depth_mask
def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
scale: float,
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
half_image_size_wh_orig = (
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
)
# first, we convert from the dataset's NDC convention to pixels
format = entry_viewpoint.intrinsics_format
if format.lower() == "ndc_norm_image_bounds":
# this is e.g. currently used in CO3D for storing intrinsics
rescale = half_image_size_wh_orig
elif format.lower() == "ndc_isotropic":
rescale = half_image_size_wh_orig.min()
else:
raise ValueError(f"Unknown intrinsics format: {format}")
# principal point and focal length in pixels
principal_point_px = half_image_size_wh_orig - principal_point * rescale
focal_length_px = focal_length * rescale
if self.box_crop:
assert clamp_bbox_xyxy is not None
principal_point_px -= clamp_bbox_xyxy[:2]
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
if self.image_height is None or self.image_width is None:
out_size = list(reversed(entry.image.size))
else:
out_size = [self.image_width, self.image_height]
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
half_min_image_size_output = half_image_size_output.min()
# rescaled principal point and focal length in ndc
principal_point = (
half_image_size_output - principal_point_px * scale
) / half_min_image_size_output
focal_length = focal_length_px * scale / half_min_image_size_output
return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _load_frames(self) -> None:
logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
local_file = self._local_path(self.frame_annotations_file)
@ -853,35 +630,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
# pyre-ignore[16]
self._seq_to_idx = seq_to_idx
def _resize_image(
self, image, mode="bilinear"
) -> Tuple[torch.Tensor, float, torch.Tensor]:
image_height, image_width = self.image_height, self.image_width
if image_height is None or image_width is None:
# skip the resizing
imre_ = torch.from_numpy(image)
return imre_, 1.0, torch.ones_like(imre_[:1])
# takes numpy array, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
torch.from_numpy(image)[None],
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
# pyre-fixme[19]: Expected 1 positional argument.
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
# pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
# pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
mask = torch.zeros(1, self.image_height, self.image_width)
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
return imre_, minscale, mask
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
@ -918,169 +666,3 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
def _seq_name_to_seed(seq_name) -> int:
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
def _load_image(path) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im
def _load_16big_png_depth(depth_png) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
def _load_1bit_png_mask(file: str) -> np.ndarray:
with Image.open(file) as pil_im:
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
return mask
def _load_depth_mask(path: str) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = _load_1bit_png_mask(path)
return m[None] # fake feature channel
def _load_depth(path, scale_adjustment) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)
d = _load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel
def _load_mask(path) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel
def _get_1d_bounds(arr) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1] + 1
def _get_bbox_from_mask(
mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1 - x0, y1 - y0
def _get_clamp_bbox(
bbox: torch.Tensor,
box_crop_context: float = 0.0,
image_path: str = "",
) -> torch.Tensor:
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float
bbox = bbox.clone() # do not edit bbox in place
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
if (bbox[2:] <= 1.0).any():
raise ValueError(
f"squashed image {image_path}!! The bounding box contains no pixels."
)
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
return bbox_xyxy
def _crop_around_box(tensor, bbox, impath: str = ""):
# bbox is xyxy, where the upper bound is corrected with +1
bbox = _clamp_box_to_image_bounds_and_round(
bbox,
image_size_hw=tensor.shape[-2:],
)
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
return tensor
def _clamp_box_to_image_bounds_and_round(
bbox_xyxy: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
if not isinstance(bbox_xyxy, torch.LongTensor):
bbox_xyxy = bbox_xyxy.round().long()
return bbox_xyxy # pyre-ignore [7]
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
assert bbox is not None
assert np.prod(orig_res) > 1e-8
# average ratio of dimensions
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
return bbox * rel_size
def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
wh = xyxy[2:] - xyxy[:2]
xywh = torch.cat([xyxy[:2], wh])
return xywh
def _bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy
def _safe_as_tensor(data, dtype):
if data is None:
return None
return torch.tensor(data, dtype=dtype)
# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
pcl = IO().load_pointcloud(pcl_path)
if max_points > 0:
pcl = pcl.subsample(max_points)
return pcl

View File

@ -20,8 +20,9 @@ from pytorch3d.implicitron.tools.config import (
)
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
from .dataset_base import DatasetBase, FrameData
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .frame_data import FrameData
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
@ -69,7 +70,8 @@ class SingleSceneDataset(DatasetBase, Configurable):
sequence_name=_SINGLE_SEQUENCE_NAME,
sequence_category=self.object_name,
camera=pose,
image_size_hw=torch.tensor(image.shape[1:]),
# pyre-ignore
image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long),
image_rgb=image,
fg_probability=fg_probability,
frame_type=frame_type,

View File

@ -55,6 +55,8 @@ class MaskAnnotation:
path: str
# (soft) number of pixels in the mask; sum(Prob(fg | pixel))
mass: Optional[float] = None
# tight bounding box around the foreground mask
bounding_box_xywh: Optional[Tuple[float, float, float, float]] = None
@dataclass

View File

@ -5,10 +5,18 @@
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import functools
import warnings
from pathlib import Path
from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np
import torch
from PIL import Image
from pytorch3d.io import IO
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
@ -16,6 +24,26 @@ DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"
class GenericWorkaround:
"""
OmegaConf.structured has a weirdness when you try to apply
it to a dataclass whose first base class is a Generic which is not
Dict. The issue is with a function called get_dict_key_value_types
in omegaconf/_utils.py.
For example this fails:
@dataclass(eq=False)
class D(torch.utils.data.Dataset[int]):
a: int = 3
OmegaConf.structured(D)
We avoid the problem by adding this class as an extra base class.
"""
pass
def is_known_frame_scalar(frame_type: str) -> bool:
"""
Given a single frame type corresponding to a single frame, return whether
@ -52,3 +80,286 @@ def is_train_frame(
dtype=torch.bool,
device=device,
)
def get_bbox_from_mask(
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(
f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1
)
x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1 - x0, y1 - y0
def crop_around_box(
tensor: torch.Tensor, bbox: torch.Tensor, impath: str = ""
) -> torch.Tensor:
# bbox is xyxy, where the upper bound is corrected with +1
bbox = clamp_box_to_image_bounds_and_round(
bbox,
image_size_hw=tuple(tensor.shape[-2:]),
)
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
return tensor
def clamp_box_to_image_bounds_and_round(
bbox_xyxy: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
if not isinstance(bbox_xyxy, torch.LongTensor):
bbox_xyxy = bbox_xyxy.round().long()
return bbox_xyxy # pyre-ignore [7]
T = TypeVar("T", bound=torch.Tensor)
def bbox_xyxy_to_xywh(xyxy: T) -> T:
wh = xyxy[2:] - xyxy[:2]
xywh = torch.cat([xyxy[:2], wh])
return xywh # pyre-ignore
def get_clamp_bbox(
bbox: torch.Tensor,
box_crop_context: float = 0.0,
image_path: str = "",
) -> torch.Tensor:
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float
bbox = bbox.clone() # do not edit bbox in place
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
if (bbox[2:] <= 1.0).any():
raise ValueError(
f"squashed image {image_path}!! The bounding box contains no pixels."
)
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
bbox_xyxy = bbox_xywh_to_xyxy(bbox, clamp_size=2)
return bbox_xyxy
def rescale_bbox(
bbox: torch.Tensor,
orig_res: Union[Tuple[int, int], torch.LongTensor],
new_res: Union[Tuple[int, int], torch.LongTensor],
) -> torch.Tensor:
assert bbox is not None
assert np.prod(orig_res) > 1e-8
# average ratio of dimensions
# pyre-ignore
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
return bbox * rel_size
def bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1] + 1
def resize_image(
image: Union[np.ndarray, torch.Tensor],
image_height: Optional[int],
image_width: Optional[int],
mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]:
if type(image) == np.ndarray:
image = torch.from_numpy(image)
if image_height is None or image_width is None:
# skip the resizing
return image, 1.0, torch.ones_like(image[:1])
# takes numpy array or tensor, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
image[None],
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], image_height, image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, image_height, image_width)
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
return imre_, minscale, mask
def load_image(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im
def load_mask(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel
def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)
d = load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel
def load_16big_png_depth(depth_png: str) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
def load_1bit_png_mask(file: str) -> np.ndarray:
with Image.open(file) as pil_im:
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
return mask
def load_depth_mask(path: str) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = load_1bit_png_mask(path)
return m[None] # fake feature channel
def safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None
def _convert_ndc_to_pixels(
focal_length: torch.Tensor,
principal_point: torch.Tensor,
image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
half_image_size = image_size_wh / 2
rescale = half_image_size.min()
principal_point_px = half_image_size - principal_point * rescale
focal_length_px = focal_length * rescale
return focal_length_px, principal_point_px
def _convert_pixels_to_ndc(
focal_length_px: torch.Tensor,
principal_point_px: torch.Tensor,
image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
half_image_size = image_size_wh / 2
rescale = half_image_size.min()
principal_point = (half_image_size - principal_point_px) / rescale
focal_length = focal_length_px / rescale
return focal_length, principal_point
def adjust_camera_to_bbox_crop_(
camera: PerspectiveCameras,
image_size_wh: torch.Tensor,
clamp_bbox_xywh: torch.Tensor,
) -> None:
if len(camera) != 1:
raise ValueError("Adjusting currently works with singleton cameras camera only")
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
camera.principal_point[0], # pyre-ignore
image_size_wh,
)
principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
focal_length, principal_point_cropped = _convert_pixels_to_ndc(
focal_length_px,
principal_point_px_cropped,
clamp_bbox_xywh[2:],
)
camera.focal_length = focal_length[None]
camera.principal_point = principal_point_cropped[None] # pyre-ignore
def adjust_camera_to_image_scale_(
camera: PerspectiveCameras,
original_size_wh: torch.Tensor,
new_size_wh: torch.LongTensor,
) -> PerspectiveCameras:
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
camera.principal_point[0], # pyre-ignore
original_size_wh,
)
# now scale and convert from pixels to NDC
image_size_wh_output = new_size_wh.float()
scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
focal_length_px_scaled = focal_length_px * scale
principal_point_px_scaled = principal_point_px * scale
focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
focal_length_px_scaled,
principal_point_px_scaled,
image_size_wh_output,
)
camera.focal_length = focal_length_scaled[None]
camera.principal_point = principal_point_scaled[None] # pyre-ignore
# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
pcl = IO().load_pointcloud(pcl_path)
if max_points > 0:
pcl = pcl.subsample(max_points)
return pcl

View File

@ -10,7 +10,7 @@ import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds
from .dataset_base import FrameData
from .frame_data import FrameData
from .json_index_dataset import JsonIndexDataset

View File

@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Un
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender
from pytorch3d.implicitron.tools import vis_utils

View File

@ -17,7 +17,8 @@ from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler

View File

@ -9,11 +9,19 @@ import unittest
import numpy as np
import torch
from pytorch3d.implicitron.dataset.json_index_dataset import (
_bbox_xywh_to_xyxy,
_bbox_xyxy_to_xywh,
_get_bbox_from_mask,
from pytorch3d.implicitron.dataset.utils import (
bbox_xywh_to_xyxy,
bbox_xyxy_to_xywh,
clamp_box_to_image_bounds_and_round,
crop_around_box,
get_1d_bounds,
get_bbox_from_mask,
get_clamp_bbox,
rescale_bbox,
resize_image,
)
from tests.common_testing import TestCaseMixin
@ -31,9 +39,9 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
]
)
for bbox_xywh in bbox_xywh_list:
bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh)
bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy)
bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_)
bbox_xyxy = bbox_xywh_to_xyxy(bbox_xywh)
bbox_xywh_ = bbox_xyxy_to_xywh(bbox_xyxy)
bbox_xyxy_ = bbox_xywh_to_xyxy(bbox_xywh_)
self.assertClose(bbox_xywh_, bbox_xywh)
self.assertClose(bbox_xyxy, bbox_xyxy_)
@ -47,8 +55,8 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
]
)
for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected:
self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
self.assertClose(bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
self.assertClose(bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
clamp_amnt = 3
bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor(
@ -61,7 +69,7 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
)
for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected:
self.assertClose(
_bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
bbox_xyxy_expected,
)
@ -74,5 +82,61 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
]
).astype(np.float32)
expected_bbox_xywh = [2, 1, 2, 1]
bbox_xywh = _get_bbox_from_mask(mask, 0.5)
bbox_xywh = get_bbox_from_mask(mask, 0.5)
self.assertClose(bbox_xywh, expected_bbox_xywh)
def test_crop_around_box(self):
bbox = torch.LongTensor([0, 1, 2, 3]) # (x_min, y_min, x_max, y_max)
image = torch.LongTensor(
[
[0, 0, 10, 20],
[10, 20, 5, 1],
[10, 20, 1, 1],
[5, 4, 0, 1],
]
)
cropped = crop_around_box(image, bbox)
self.assertClose(cropped, image[1:3, 0:2])
def test_clamp_box_to_image_bounds_and_round(self):
bbox = torch.LongTensor([0, 1, 10, 12])
image_size = (5, 6)
expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]])
clamped_bbox = clamp_box_to_image_bounds_and_round(bbox, image_size)
self.assertClose(clamped_bbox, expected_clamped_bbox)
def test_get_clamp_bbox(self):
bbox_xywh = torch.LongTensor([1, 1, 4, 5])
clamped_bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=2)
# size multiplied by 2 and added coordinates
self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11]))
def test_rescale_bbox(self):
bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0])
original_resolution = (4, 4)
new_resolution = (8, 8) # twice bigger
rescaled_bbox = rescale_bbox(bbox, original_resolution, new_resolution)
self.assertClose(bbox * 2, rescaled_bbox)
def test_get_1d_bounds(self):
array = [0, 1, 2]
bounds = get_1d_bounds(array)
# make nonzero 1d bounds of image
self.assertClose(bounds, [1, 3])
def test_resize_image(self):
image = np.random.rand(3, 300, 500) # rgb image 300x500
expected_shape = (150, 250)
resized_image, scale, mask_crop = resize_image(
image, image_height=expected_shape[0], image_width=expected_shape[1]
)
original_shape = image.shape[-2:]
expected_scale = min(
expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1]
)
self.assertEqual(scale, expected_scale)
self.assertEqual(resized_image.shape[-2:], expected_shape)
self.assertEqual(mask_crop.shape[-2:], expected_shape)

View File

@ -8,7 +8,7 @@ import os
import unittest
import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
RenderedMeshDatasetMapProvider,
)

View File

@ -13,8 +13,10 @@ import os
import unittest
import lpips
import numpy as np
import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
@ -268,7 +270,7 @@ class TestEvaluation(unittest.TestCase):
for metric in lower_better:
m_better = eval_result[metric]
m_worse = eval_result_bad[metric]
if m_better != m_better or m_worse != m_worse:
if np.isnan(m_better) or np.isnan(m_worse):
continue # metric is missing, i.e. NaN
_assert = (
self.assertLessEqual

View File

@ -0,0 +1,224 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import gzip
import os
import unittest
from typing import List
import numpy as np
import torch
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import (
load_16big_png_depth,
load_1bit_png_mask,
load_depth,
load_depth_mask,
load_image,
load_mask,
safe_as_tensor,
)
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import PerspectiveCameras
from tests.common_testing import TestCaseMixin
from tests.implicitron.common_resources import get_skateboard_data
class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
category = "skateboard"
stack = contextlib.ExitStack()
self.dataset_root, self.path_manager = stack.enter_context(
get_skateboard_data()
)
self.addCleanup(stack.close)
self.image_height = 768
self.image_width = 512
self.frame_data_builder = FrameDataBuilder(
image_height=self.image_height,
image_width=self.image_width,
dataset_root=self.dataset_root,
path_manager=self.path_manager,
)
# loading single frame annotation of dataset (see JsonIndexDataset._load_frames())
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
local_file = self.path_manager.get_local_path(frame_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass(
zipfile, List[types.FrameAnnotation]
)
self.frame_annotation = frame_annots_list[0]
sequence_annotations_file = os.path.join(
self.dataset_root, category, "sequence_annotations.jgz"
)
local_file = self.path_manager.get_local_path(sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots_list = types.load_dataclass(
zipfile, List[types.SequenceAnnotation]
)
seq_annots = {entry.sequence_name: entry for entry in seq_annots_list}
self.seq_annotation = seq_annots[self.frame_annotation.sequence_name]
point_cloud = self.seq_annotation.point_cloud
self.frame_data = FrameData(
frame_number=safe_as_tensor(self.frame_annotation.frame_number, torch.long),
frame_timestamp=safe_as_tensor(
self.frame_annotation.frame_timestamp, torch.float
),
sequence_name=self.frame_annotation.sequence_name,
sequence_category=self.seq_annotation.category,
camera_quality_score=safe_as_tensor(
self.seq_annotation.viewpoint_quality_score, torch.float
),
point_cloud_quality_score=safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
def test_frame_data_builder_args(self):
# test that FrameDataBuilder works with get_default_args
get_default_args(FrameDataBuilder)
def test_fix_point_cloud_path(self):
"""Some files in Co3Dv2 have an accidental absolute path stored."""
original_path = "some_file_path"
modified_path = self.frame_data_builder._fix_point_cloud_path(original_path)
self.assertIn(original_path, modified_path)
self.assertIn(self.frame_data_builder.dataset_root, modified_path)
def test_load_and_adjust_frame_data(self):
self.frame_data.image_size_hw = safe_as_tensor(
self.frame_annotation.image.size, torch.long
)
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
(
self.frame_data.fg_probability,
self.frame_data.mask_path,
self.frame_data.bbox_xywh,
) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
self.assertIsNotNone(self.frame_data.mask_path)
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh))
# assert bboxes shape
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))
(
self.frame_data.image_rgb,
self.frame_data.image_path,
) = self.frame_data_builder._load_images(
self.frame_annotation, self.frame_data.fg_probability
)
self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
self.assertIsNotNone(self.frame_data.image_path)
(
self.frame_data.depth_map,
depth_path,
self.frame_data.depth_mask,
) = self.frame_data_builder._load_mask_depth(
self.frame_annotation,
self.frame_data.fg_probability,
)
self.assertTrue(torch.is_tensor(self.frame_data.depth_map))
self.assertIsNotNone(depth_path)
self.assertTrue(torch.is_tensor(self.frame_data.depth_mask))
new_size = (self.image_height, self.image_width)
if self.frame_data_builder.box_crop:
self.frame_data.crop_by_metadata_bbox_(
self.frame_data_builder.box_crop_context,
)
# assert image and mask shapes after resize
self.frame_data.resize_frame_(
new_size_hw=torch.tensor(new_size, dtype=torch.long),
)
self.assertEqual(
self.frame_data.mask_crop.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.image_rgb.shape,
torch.Size([3, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.mask_crop.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.fg_probability.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.depth_map.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.depth_mask.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.frame_data.camera = self.frame_data_builder._get_pytorch3d_camera(
self.frame_annotation,
)
self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)
def test_load_image(self):
path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
local_path = self.path_manager.get_local_path(path)
image = load_image(local_path)
self.assertEqual(image.dtype, np.float32)
self.assertLessEqual(np.max(image), 1.0)
self.assertGreaterEqual(np.min(image), 0.0)
def test_load_mask(self):
path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
mask = load_mask(path)
self.assertEqual(mask.dtype, np.float32)
self.assertLessEqual(np.max(mask), 1.0)
self.assertGreaterEqual(np.min(mask), 0.0)
def test_load_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment)
self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 3)
def test_load_16big_png_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
depth_map = load_16big_png_depth(path)
self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 2)
def test_load_1bit_png_mask(self):
mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path
)
mask = load_1bit_png_mask(mask_path)
self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 2)
def test_load_depth_mask(self):
mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path
)
mask = load_depth_mask(mask_path)
self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 3)

View File

@ -17,7 +17,7 @@ import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
JsonIndexDatasetMapProviderV2,
)