mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Extract BlobLoader class from JsonIndexDataset and move crop_by_bbox to FrameData
Summary: extracted blob loader, added documentation for blob_loader, and did some refactoring on fields. For detailed steps and discussions see: https://github.com/facebookresearch/pytorch3d/pull/1463 and https://github.com/fairinternal/pixar_replay/pull/160 Reviewed By: bottler Differential Revision: D44061728 fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51
This commit is contained in:
parent
c759fc560f
commit
ebdbfde0ce
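
The refactor moves FrameData and the blob-loading/cropping logic into the new module pytorch3d/implicitron/dataset/frame_data.py, behind a FrameDataBuilder that JsonIndexDataset now instantiates in __post_init__. A minimal usage sketch of the builder on its own (the dataset path and annotation objects are placeholders, and calling expand_args_fields first is an assumption about the config system, not part of this diff):

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.frame_data import FrameData, FrameDataBuilder
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(FrameDataBuilder)  # make the configurable class constructible

builder = FrameDataBuilder(
    dataset_root="/path/to/dataset",  # placeholder path
    image_height=800,
    image_width=800,
    box_crop=True,
)

# Placeholders: in JsonIndexDataset these come from the zipped json annotation files.
frame_annotation: types.FrameAnnotation = ...
sequence_annotation: types.SequenceAnnotation = ...

frame_data: FrameData = builder.build(frame_annotation, sequence_annotation)
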
@ -18,8 +18,9 @@ from torch.utils.data import (
Sampler,
)

from .dataset_base import DatasetBase, FrameData
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap
from .frame_data import FrameData
from .scene_batch_sampler import SceneBatchSampler
from .utils import is_known_frame_scalar
@ -5,217 +5,27 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
"""
|
||||
A type of the elements returned by indexing the dataset object.
|
||||
It can represent both individual frames and batches thereof;
|
||||
in this documentation, the sizes of tensors refer to single frames;
|
||||
add the first batch dimension for the collation result.
|
||||
|
||||
Args:
|
||||
frame_number: The number of the frame within its sequence.
|
||||
0-based continuous integers.
|
||||
sequence_name: The unique name of the frame's sequence.
|
||||
sequence_category: The object category of the sequence.
|
||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||
image_size_hw: The size of the image in pixels; (height, width) tensor
|
||||
of shape (2,).
|
||||
image_path: The qualified path to the loaded image (with dataset_root).
|
||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||
of the frame; elements are floats in [0, 1].
|
||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||
are a result of zero-padding of the image after cropping around
|
||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||
depth_path: The qualified path to the frame's depth map.
|
||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||
of the frame; values correspond to distances from the camera;
|
||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||
depth map that are valid for evaluation; they have been checked for
|
||||
consistency across views; elements are floats in {0.0, 1.0}.
|
||||
mask_path: A qualified path to the foreground probability mask.
|
||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||
pixels belonging to the captured object; elements are floats
|
||||
in [0, 1].
|
||||
bbox_xywh: The bounding box tightly enclosing the foreground object in the
|
||||
format (x0, y0, width, height). The convention assumes that
|
||||
`x0+width` and `y0+height` includes the boundary of the box.
|
||||
I.e., to slice out the corresponding crop from an image tensor `I`
|
||||
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
|
||||
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
|
||||
in the original image coordinates in the format (x0, y0, width, height).
|
||||
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
|
||||
from `bbox_xywh` due to padding (which can happen e.g. due to
|
||||
setting `JsonIndexDataset.box_crop_context > 0`)
|
||||
camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
|
||||
corrected for cropping if it happened.
|
||||
camera_quality_score: The score proportional to the confidence of the
|
||||
frame's camera estimation (the higher the more accurate).
|
||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||
frame's sequence point cloud (the higher the more accurate).
|
||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||
point cloud corresponding to the frame's sequence. When the object
|
||||
represents a batch of frames, point clouds may be deduplicated;
|
||||
see `sequence_point_cloud_idx`.
|
||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||
corresponding point cloud to `image_rgb[i]`, use
|
||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||
frame_type: The type of the loaded frame specified in
|
||||
`subset_lists_file`, if provided.
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
frame_timestamp: Optional[torch.Tensor] = None
|
||||
image_size_hw: Optional[torch.Tensor] = None
|
||||
image_path: Union[str, List[str], None] = None
|
||||
image_rgb: Optional[torch.Tensor] = None
|
||||
# masks out padding added due to cropping the square bit
|
||||
mask_crop: Optional[torch.Tensor] = None
|
||||
depth_path: Union[str, List[str], None] = None
|
||||
depth_map: Optional[torch.Tensor] = None
|
||||
depth_mask: Optional[torch.Tensor] = None
|
||||
mask_path: Union[str, List[str], None] = None
|
||||
fg_probability: Optional[torch.Tensor] = None
|
||||
bbox_xywh: Optional[torch.Tensor] = None
|
||||
crop_bbox_xywh: Optional[torch.Tensor] = None
|
||||
camera: Optional[PerspectiveCameras] = None
|
||||
camera_quality_score: Optional[torch.Tensor] = None
|
||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||
sequence_point_cloud: Optional[Pointclouds] = None
|
||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||
frame_type: Union[str, List[str], None] = None # known | unseen
|
||||
meta: dict = field(default_factory=lambda: {})
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new_params = {}
|
||||
for f in fields(self):
|
||||
value = getattr(self, f.name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
new_params[f.name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[f.name] = value
|
||||
return type(self)(**new_params)
|
||||
|
||||
def cpu(self):
|
||||
return self.to(device=torch.device("cpu"))
|
||||
|
||||
def cuda(self):
|
||||
return self.to(device=torch.device("cuda"))
|
||||
|
||||
# the following functions make sure **frame_data can be passed to functions
|
||||
def __iter__(self):
|
||||
for f in fields(self):
|
||||
yield f.name
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __len__(self):
|
||||
return len(fields(self))
|
||||
|
||||
@classmethod
|
||||
def collate(cls, batch):
|
||||
"""
|
||||
Given a list of objects `batch` of class `cls`, collates them into a batched
|
||||
representation suitable for processing with deep networks.
|
||||
"""
|
||||
|
||||
elem = batch[0]
|
||||
|
||||
if isinstance(elem, cls):
|
||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||
id_to_idx = defaultdict(list)
|
||||
for i, pc_id in enumerate(pointcloud_ids):
|
||||
id_to_idx[pc_id].append(i)
|
||||
|
||||
sequence_point_cloud = []
|
||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||
for i, ind in enumerate(id_to_idx.values()):
|
||||
sequence_point_cloud_idx[ind] = i
|
||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||
assert (sequence_point_cloud_idx >= 0).all()
|
||||
|
||||
override_fields = {
|
||||
"sequence_point_cloud": sequence_point_cloud,
|
||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||
}
|
||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||
|
||||
collated = {}
|
||||
for f in fields(elem):
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
collated[f.name] = (
|
||||
cls.collate(list_values)
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return cls(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data._utils.collate.default_collate(batch)
|
||||
|
||||
|
||||
class _GenericWorkaround:
|
||||
"""
|
||||
OmegaConf.structured has a weirdness when you try to apply
|
||||
it to a dataclass whose first base class is a Generic which is not
|
||||
Dict. The issue is with a function called get_dict_key_value_types
|
||||
in omegaconf/_utils.py.
|
||||
For example this fails:
|
||||
|
||||
@dataclass(eq=False)
|
||||
class D(torch.utils.data.Dataset[int]):
|
||||
a: int = 3
|
||||
|
||||
OmegaConf.structured(D)
|
||||
|
||||
We avoid the problem by adding this class as an extra base class.
|
||||
"""
|
||||
|
||||
pass
|
||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||
class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
Base class to describe a dataset to be used with Implicitron.
|
||||
|
||||
|
728 pytorch3d/implicitron/dataset/frame_data.py Normal file
@ -0,0 +1,728 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Generic,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.utils import (
|
||||
adjust_camera_to_bbox_crop_,
|
||||
adjust_camera_to_image_scale_,
|
||||
bbox_xyxy_to_xywh,
|
||||
clamp_box_to_image_bounds_and_round,
|
||||
crop_around_box,
|
||||
GenericWorkaround,
|
||||
get_bbox_from_mask,
|
||||
get_clamp_bbox,
|
||||
load_depth,
|
||||
load_depth_mask,
|
||||
load_image,
|
||||
load_mask,
|
||||
load_pointcloud,
|
||||
rescale_bbox,
|
||||
resize_image,
|
||||
safe_as_tensor,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
"""
|
||||
A type of the elements returned by indexing the dataset object.
|
||||
It can represent both individual frames and batches thereof;
|
||||
in this documentation, the sizes of tensors refer to single frames;
|
||||
add the first batch dimension for the collation result.
|
||||
|
||||
Args:
|
||||
frame_number: The number of the frame within its sequence.
|
||||
0-based continuous integers.
|
||||
sequence_name: The unique name of the frame's sequence.
|
||||
sequence_category: The object category of the sequence.
|
||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||
image_size_hw: The size of the original image in pixels; (height, width)
|
||||
tensor of shape (2,). Note that it is optional, e.g. it can be `None`
|
||||
if the frame annotation has no size and image_rgb has not [yet] been
|
||||
loaded. Image-less FrameData is valid but mutators like crop/resize
|
||||
may fail if the original image size cannot be deduced.
|
||||
effective_image_size_hw: The size of the image after mutations such as
|
||||
crop/resize in pixels; (height, width). If the image has not been mutated,
|
||||
it is equal to `image_size_hw`. Note that it is also optional, for the
|
||||
same reason as `image_size_hw`.
|
||||
image_path: The qualified path to the loaded image (with dataset_root).
|
||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||
of the frame; elements are floats in [0, 1].
|
||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||
are a result of zero-padding of the image after cropping around
|
||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||
depth_path: The qualified path to the frame's depth map.
|
||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||
of the frame; values correspond to distances from the camera;
|
||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||
depth map that are valid for evaluation; they have been checked for
|
||||
consistency across views; elements are floats in {0.0, 1.0}.
|
||||
mask_path: A qualified path to the foreground probability mask.
|
||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||
pixels belonging to the captured object; elements are floats
|
||||
in [0, 1].
|
||||
bbox_xywh: The bounding box tightly enclosing the foreground object in the
|
||||
format (x0, y0, width, height). The convention assumes that
|
||||
`x0+width` and `y0+height` includes the boundary of the box.
|
||||
I.e., to slice out the corresponding crop from an image tensor `I`
|
||||
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
|
||||
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
|
||||
in the original image coordinates in the format (x0, y0, width, height).
|
||||
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
|
||||
from `bbox_xywh` due to padding (which can happen e.g. due to
|
||||
setting `JsonIndexDataset.box_crop_context > 0`)
|
||||
camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
|
||||
corrected for cropping if it happened.
|
||||
camera_quality_score: The score proportional to the confidence of the
|
||||
frame's camera estimation (the higher the more accurate).
|
||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||
frame's sequence point cloud (the higher the more accurate).
|
||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||
point cloud corresponding to the frame's sequence. When the object
|
||||
represents a batch of frames, point clouds may be deduplicated;
|
||||
see `sequence_point_cloud_idx`.
|
||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||
corresponding point cloud to `image_rgb[i]`, use
|
||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||
frame_type: The type of the loaded frame specified in
|
||||
`subset_lists_file`, if provided.
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
frame_timestamp: Optional[torch.Tensor] = None
|
||||
image_size_hw: Optional[torch.LongTensor] = None
|
||||
effective_image_size_hw: Optional[torch.LongTensor] = None
|
||||
image_path: Union[str, List[str], None] = None
|
||||
image_rgb: Optional[torch.Tensor] = None
|
||||
# masks out padding added due to cropping the square bit
|
||||
mask_crop: Optional[torch.Tensor] = None
|
||||
depth_path: Union[str, List[str], None] = None
|
||||
depth_map: Optional[torch.Tensor] = None
|
||||
depth_mask: Optional[torch.Tensor] = None
|
||||
mask_path: Union[str, List[str], None] = None
|
||||
fg_probability: Optional[torch.Tensor] = None
|
||||
bbox_xywh: Optional[torch.Tensor] = None
|
||||
crop_bbox_xywh: Optional[torch.Tensor] = None
|
||||
camera: Optional[PerspectiveCameras] = None
|
||||
camera_quality_score: Optional[torch.Tensor] = None
|
||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||
sequence_point_cloud: Optional[Pointclouds] = None
|
||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||
frame_type: Union[str, List[str], None] = None # known | unseen
|
||||
meta: dict = field(default_factory=lambda: {})
|
||||
|
||||
# NOTE that batching resets this attribute
|
||||
_uncropped: bool = field(init=False, default=True)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new_params = {}
|
||||
for field_name in iter(self):
|
||||
value = getattr(self, field_name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
new_params[field_name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[field_name] = value
|
||||
frame_data = type(self)(**new_params)
|
||||
frame_data._uncropped = self._uncropped
|
||||
return frame_data
|
||||
|
||||
def cpu(self):
|
||||
return self.to(device=torch.device("cpu"))
|
||||
|
||||
def cuda(self):
|
||||
return self.to(device=torch.device("cuda"))
|
||||
|
||||
# the following functions make sure **frame_data can be passed to functions
|
||||
def __iter__(self):
|
||||
for f in fields(self):
|
||||
if f.name.startswith("_"):
|
||||
continue
|
||||
|
||||
yield f.name
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __len__(self):
|
||||
return sum(1 for f in iter(self))
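
Because of the Mapping protocol above (with private fields skipped in __iter__), a FrameData instance can be unpacked into keyword arguments or turned into a plain dict. An illustrative sketch; the consumer function is hypothetical:

def describe(image_rgb=None, camera=None, **other_fields):
    # a made-up consumer that only looks at a couple of fields
    return image_rgb is not None, camera is not None

has_image, has_camera = describe(**frame_data)  # works via __iter__/__getitem__
as_dict = dict(frame_data)                      # keys are the public field names
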
|
||||
|
||||
def crop_by_metadata_bbox_(
|
||||
self,
|
||||
box_crop_context: float,
|
||||
) -> None:
|
||||
"""Crops the frame data in-place by (possibly expanded) bounding box.
|
||||
The bounding box is taken from the object state (usually taken from
|
||||
the frame annotation or estimated from the foreground mask).
|
||||
If the expanded bounding box does not fit the image, it is clamped,
|
||||
i.e. the image is *not* padded.
|
||||
|
||||
Args:
|
||||
box_crop_context: rate of expansion for bbox; 0 means no expansion.
|
||||
|
||||
Raises:
|
||||
ValueError: If the object does not contain a bounding box (usually when no
|
||||
mask annotation is provided)
|
||||
ValueError: If the frame data have been cropped or resized, thus the intrinsic
|
||||
bounding box is not valid for the current image size.
|
||||
ValueError: If the frame does not have an image size (usually a corner case
|
||||
when no image has been loaded)
|
||||
"""
|
||||
if self.bbox_xywh is None:
|
||||
raise ValueError("Attempted cropping by metadata with empty bounding box")
|
||||
|
||||
if not self._uncropped:
|
||||
raise ValueError(
|
||||
"Trying to apply the metadata bounding box to already cropped "
|
||||
"or resized image; coordinates have changed."
|
||||
)
|
||||
|
||||
self._crop_by_bbox_(
|
||||
box_crop_context,
|
||||
self.bbox_xywh,
|
||||
)
|
||||
|
||||
def crop_by_given_bbox_(
|
||||
self,
|
||||
box_crop_context: float,
|
||||
bbox_xywh: torch.Tensor,
|
||||
) -> None:
|
||||
"""Crops the frame data in-place by (possibly expanded) bounding box.
|
||||
If the expanded bounding box does not fit the image, it is clamped,
|
||||
i.e. the image is *not* padded.
|
||||
|
||||
Args:
|
||||
box_crop_context: rate of expansion for bbox; 0 means no expansion.
|
||||
bbox_xywh: bounding box in [x0, y0, width, height] format. If float
|
||||
tensor, values are floored (after converting to [x0, y0, x1, y1]).
|
||||
|
||||
Raises:
|
||||
ValueError: If the frame does not have an image size (usually a corner case
|
||||
when no image has been loaded)
|
||||
"""
|
||||
self._crop_by_bbox_(
|
||||
box_crop_context,
|
||||
bbox_xywh,
|
||||
)
|
||||
|
||||
def _crop_by_bbox_(
|
||||
self,
|
||||
box_crop_context: float,
|
||||
bbox_xywh: torch.Tensor,
|
||||
) -> None:
|
||||
"""Crops the frame data in-place by (possibly expanded) bounding box.
|
||||
If the expanded bounding box does not fit the image, it is clamped,
|
||||
i.e. the image is *not* padded.
|
||||
|
||||
Args:
|
||||
box_crop_context: rate of expansion for bbox; 0 means no expansion.
|
||||
bbox_xywh: bounding box in [x0, y0, width, height] format. If float
|
||||
tensor, values are floored (after converting to [x0, y0, x1, y1]).
|
||||
|
||||
Raises:
|
||||
ValueError: If the frame does not have an image size (usually a corner case
|
||||
when no image has been loaded)
|
||||
"""
|
||||
effective_image_size_hw = self.effective_image_size_hw
|
||||
if effective_image_size_hw is None:
|
||||
raise ValueError("Calling crop on image-less FrameData")
|
||||
|
||||
bbox_xyxy = get_clamp_bbox(
|
||||
bbox_xywh,
|
||||
image_path=self.image_path, # pyre-ignore
|
||||
box_crop_context=box_crop_context,
|
||||
)
|
||||
clamp_bbox_xyxy = clamp_box_to_image_bounds_and_round(
|
||||
bbox_xyxy,
|
||||
image_size_hw=tuple(self.effective_image_size_hw), # pyre-ignore
|
||||
)
|
||||
crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
||||
|
||||
if self.fg_probability is not None:
|
||||
self.fg_probability = crop_around_box(
|
||||
self.fg_probability,
|
||||
clamp_bbox_xyxy,
|
||||
self.mask_path, # pyre-ignore
|
||||
)
|
||||
if self.image_rgb is not None:
|
||||
self.image_rgb = crop_around_box(
|
||||
self.image_rgb,
|
||||
clamp_bbox_xyxy,
|
||||
self.image_path, # pyre-ignore
|
||||
)
|
||||
|
||||
depth_map = self.depth_map
|
||||
if depth_map is not None:
|
||||
clamp_bbox_xyxy_depth = rescale_bbox(
|
||||
clamp_bbox_xyxy, tuple(depth_map.shape[-2:]), effective_image_size_hw
|
||||
).long()
|
||||
self.depth_map = crop_around_box(
|
||||
depth_map,
|
||||
clamp_bbox_xyxy_depth,
|
||||
self.depth_path, # pyre-ignore
|
||||
)
|
||||
|
||||
depth_mask = self.depth_mask
|
||||
if depth_mask is not None:
|
||||
clamp_bbox_xyxy_depth = rescale_bbox(
|
||||
clamp_bbox_xyxy, tuple(depth_mask.shape[-2:]), effective_image_size_hw
|
||||
).long()
|
||||
self.depth_mask = crop_around_box(
|
||||
depth_mask,
|
||||
clamp_bbox_xyxy_depth,
|
||||
self.mask_path, # pyre-ignore
|
||||
)
|
||||
|
||||
# changing principal_point according to bbox_crop
|
||||
if self.camera is not None:
|
||||
adjust_camera_to_bbox_crop_(
|
||||
camera=self.camera,
|
||||
image_size_wh=effective_image_size_hw.flip(dims=[-1]),
|
||||
clamp_bbox_xywh=crop_bbox_xywh,
|
||||
)
|
||||
|
||||
# pyre-ignore
|
||||
self.effective_image_size_hw = crop_bbox_xywh[..., 2:].flip(dims=[-1])
|
||||
self._uncropped = False
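
The cropping above is driven by the bbox helpers imported from pytorch3d.implicitron.dataset.utils, with the signatures used in _crop_by_bbox_. A sketch of the coordinate flow with invented numbers:

import torch
from pytorch3d.implicitron.dataset.utils import (
    bbox_xyxy_to_xywh,
    clamp_box_to_image_bounds_and_round,
    get_clamp_bbox,
)

bbox_xywh = torch.tensor([10.0, 20.0, 100.0, 50.0])  # x0, y0, width, height (invented)
# expand the box by 30% context and convert to (x0, y0, x1, y1)
bbox_xyxy = get_clamp_bbox(bbox_xywh, image_path="frame.jpg", box_crop_context=0.3)
# clamp to the image bounds so the crop never needs padding
clamped_xyxy = clamp_box_to_image_bounds_and_round(bbox_xyxy, image_size_hw=(200, 300))
crop_bbox_xywh = bbox_xyxy_to_xywh(clamped_xyxy)  # what _crop_by_bbox_ stores back
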
|
||||
|
||||
def resize_frame_(self, new_size_hw: torch.LongTensor) -> None:
|
||||
"""Resizes frame data in-place according to given dimensions.
|
||||
|
||||
Args:
|
||||
new_size_hw: target image size [height, width], a LongTensor of shape (2,)
|
||||
|
||||
Raises:
|
||||
ValueError: If the frame does not have an image size (usually a corner case
|
||||
when no image has been loaded)
|
||||
"""
|
||||
|
||||
effective_image_size_hw = self.effective_image_size_hw
|
||||
if effective_image_size_hw is None:
|
||||
raise ValueError("Calling resize on image-less FrameData")
|
||||
|
||||
image_height, image_width = new_size_hw.tolist()
|
||||
|
||||
if self.fg_probability is not None:
|
||||
self.fg_probability, _, _ = resize_image(
|
||||
self.fg_probability,
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
mode="nearest",
|
||||
)
|
||||
|
||||
if self.image_rgb is not None:
|
||||
self.image_rgb, _, self.mask_crop = resize_image(
|
||||
self.image_rgb, image_height=image_height, image_width=image_width
|
||||
)
|
||||
|
||||
if self.depth_map is not None:
|
||||
self.depth_map, _, _ = resize_image(
|
||||
self.depth_map,
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
mode="nearest",
|
||||
)
|
||||
|
||||
if self.depth_mask is not None:
|
||||
self.depth_mask, _, _ = resize_image(
|
||||
self.depth_mask,
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
mode="nearest",
|
||||
)
|
||||
|
||||
if self.camera is not None:
|
||||
if self.image_size_hw is None:
|
||||
raise ValueError(
|
||||
"image_size_hw has to be defined for resizing FrameData with cameras."
|
||||
)
|
||||
adjust_camera_to_image_scale_(
|
||||
camera=self.camera,
|
||||
original_size_wh=effective_image_size_hw.flip(dims=[-1]),
|
||||
new_size_wh=new_size_hw.flip(dims=[-1]), # pyre-ignore
|
||||
)
|
||||
|
||||
self.effective_image_size_hw = new_size_hw
|
||||
self._uncropped = False
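
Together, the crop and resize mutators reproduce the old in-dataset crop-and-resize pipeline directly on a FrameData instance (this is also what GenericFrameDataBuilder.build does below). A sketch, assuming frame_data carries bbox_xywh and effective_image_size_hw:

import torch

# crop around the annotated object with 30% context, then resize to 800x800;
# the aspect ratio is preserved, the image is pasted into the top-left corner,
# and mask_crop marks the valid (non-padded) region.
frame_data.crop_by_metadata_bbox_(box_crop_context=0.3)
frame_data.resize_frame_(new_size_hw=torch.tensor([800, 800], dtype=torch.long))
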
|
||||
|
||||
@classmethod
|
||||
def collate(cls, batch):
|
||||
"""
|
||||
Given a list of objects `batch` of class `cls`, collates them into a batched
|
||||
representation suitable for processing with deep networks.
|
||||
"""
|
||||
|
||||
elem = batch[0]
|
||||
|
||||
if isinstance(elem, cls):
|
||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||
id_to_idx = defaultdict(list)
|
||||
for i, pc_id in enumerate(pointcloud_ids):
|
||||
id_to_idx[pc_id].append(i)
|
||||
|
||||
sequence_point_cloud = []
|
||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||
for i, ind in enumerate(id_to_idx.values()):
|
||||
sequence_point_cloud_idx[ind] = i
|
||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||
assert (sequence_point_cloud_idx >= 0).all()
|
||||
|
||||
override_fields = {
|
||||
"sequence_point_cloud": sequence_point_cloud,
|
||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||
}
|
||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||
|
||||
collated = {}
|
||||
for f in fields(elem):
|
||||
if not f.init:
|
||||
continue
|
||||
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
collated[f.name] = (
|
||||
cls.collate(list_values)
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return cls(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data._utils.collate.default_collate(batch)
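
FrameData.collate deduplicates per-sequence point clouds and otherwise defers to the default collate, so it can be used directly as a DataLoader collate_fn. A sketch, assuming dataset yields FrameData elements:

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,                       # any dataset returning FrameData
    batch_size=4,
    collate_fn=FrameData.collate,  # one point cloud per unique sequence
)
batch = next(iter(loader))
# batch.image_rgb has shape (4, 3, H, W); batch.sequence_point_cloud_idx maps
# each frame to its (deduplicated) entry in batch.sequence_point_cloud.
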
|
||||
|
||||
|
||||
FrameDataSubtype = TypeVar("FrameDataSubtype", bound=FrameData)
|
||||
|
||||
|
||||
class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
||||
"""A base class for FrameDataBuilders that build a FrameData object, load and
|
||||
process the binary data (crop and resize). Implementations should parametrize
|
||||
the class with a subtype of FrameData and set the frame_data_type class variable to
|
||||
that type. They also have to implement the `build` method.
|
||||
"""
|
||||
|
||||
# To be initialised to FrameDataSubtype
|
||||
frame_data_type: ClassVar[Type[FrameDataSubtype]]
|
||||
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
) -> FrameDataSubtype:
|
||||
"""An abstract method to build the frame data based on raw frame/sequence
|
||||
annotations, load the binary data and adjust them according to the metadata.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
"""
|
||||
A class to build a FrameData object, load and process the binary data (crop and
|
||||
resize). This is an abstract class for extending to build FrameData subtypes. Most
|
||||
users should use the concrete `FrameDataBuilder` class instead.
|
||||
Beware that modifications of frame data are done in-place.
|
||||
|
||||
Args:
|
||||
dataset_root: The root folder of the dataset; all the paths in jsons are
|
||||
specified relative to this root (but not json paths themselves).
|
||||
load_images: Enable loading the frame RGB data.
|
||||
load_depths: Enable loading the frame depth maps.
|
||||
load_depth_masks: Enable loading the frame depth map masks denoting the
|
||||
depth values used for evaluation (the points consistent across views).
|
||||
load_masks: Enable loading frame foreground masks.
|
||||
load_point_clouds: Enable loading sequence-level point clouds.
|
||||
max_points: Cap on the number of loaded points in the point cloud;
|
||||
if reached, they are randomly sampled without replacement.
|
||||
mask_images: Whether to mask the images with the loaded foreground masks;
|
||||
0 value is used for background.
|
||||
mask_depths: Whether to mask the depth maps with the loaded foreground
|
||||
masks; 0 value is used for background.
|
||||
image_height: The height of the returned images, masks, and depth maps;
|
||||
aspect ratio is preserved during cropping/resizing.
|
||||
image_width: The width of the returned images, masks, and depth maps;
|
||||
aspect ratio is preserved during cropping/resizing.
|
||||
box_crop: Enable cropping of the image around the bounding box inferred
|
||||
from the foreground region of the loaded segmentation mask; masks
|
||||
and depth maps are cropped accordingly; cameras are corrected.
|
||||
box_crop_mask_thr: The threshold used to separate pixels into foreground
|
||||
and background based on the foreground_probability mask; if no value
|
||||
is greater than this threshold, the loader lowers it and repeats.
|
||||
box_crop_context: The amount of additional padding added to each
|
||||
dimension of the cropping bounding box, relative to box size.
|
||||
path_manager: Optionally a PathManager for interpreting paths in a special way.
|
||||
"""
|
||||
|
||||
dataset_root: str = ""
|
||||
load_images: bool = True
|
||||
load_depths: bool = True
|
||||
load_depth_masks: bool = True
|
||||
load_masks: bool = True
|
||||
load_point_clouds: bool = False
|
||||
max_points: int = 0
|
||||
mask_images: bool = False
|
||||
mask_depths: bool = False
|
||||
image_height: Optional[int] = 800
|
||||
image_width: Optional[int] = 800
|
||||
box_crop: bool = True
|
||||
box_crop_mask_thr: float = 0.4
|
||||
box_crop_context: float = 0.3
|
||||
path_manager: Any = None
|
||||
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
load_blobs: bool = True,
|
||||
) -> FrameDataSubtype:
|
||||
"""Builds the frame data based on raw frame/sequence annotations, loads the
|
||||
binary data and adjusts them according to the metadata. The processing includes:
|
||||
* if box_crop is set, the image/mask/depth are cropped with the bounding
|
||||
box provided or estimated from MaskAnnotation,
|
||||
* if image_height/image_width are set, the image/mask/depth are resized to
|
||||
fit that resolution. Note that the aspect ratio is preserved, and the
|
||||
(possibly cropped) image is pasted into the top-left corner. In the
|
||||
resulting frame_data, mask_crop field corresponds to the mask of the
|
||||
pasted image.
|
||||
|
||||
Args:
|
||||
frame_annotation: frame annotation
|
||||
sequence_annotation: sequence annotation
|
||||
load_blobs: whether the function should attempt loading the image, depth map
|
||||
and mask, and foreground mask
|
||||
|
||||
Returns:
|
||||
The constructed FrameData object.
|
||||
"""
|
||||
|
||||
point_cloud = sequence_annotation.point_cloud
|
||||
|
||||
frame_data = self.frame_data_type(
|
||||
frame_number=safe_as_tensor(frame_annotation.frame_number, torch.long),
|
||||
frame_timestamp=safe_as_tensor(
|
||||
frame_annotation.frame_timestamp, torch.float
|
||||
),
|
||||
sequence_name=frame_annotation.sequence_name,
|
||||
sequence_category=sequence_annotation.category,
|
||||
camera_quality_score=safe_as_tensor(
|
||||
sequence_annotation.viewpoint_quality_score, torch.float
|
||||
),
|
||||
point_cloud_quality_score=safe_as_tensor(
|
||||
point_cloud.quality_score, torch.float
|
||||
)
|
||||
if point_cloud is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
if load_blobs and self.load_masks and frame_annotation.mask is not None:
|
||||
(
|
||||
frame_data.fg_probability,
|
||||
frame_data.mask_path,
|
||||
frame_data.bbox_xywh,
|
||||
) = self._load_fg_probability(frame_annotation)
|
||||
|
||||
if frame_annotation.image is not None:
|
||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||
frame_data.image_size_hw = image_size_hw # original image size
|
||||
# image size after crop/resize
|
||||
frame_data.effective_image_size_hw = image_size_hw
|
||||
|
||||
if load_blobs and self.load_images:
|
||||
(
|
||||
frame_data.image_rgb,
|
||||
frame_data.image_path,
|
||||
) = self._load_images(frame_annotation, frame_data.fg_probability)
|
||||
|
||||
if load_blobs and self.load_depths and frame_annotation.depth is not None:
|
||||
(
|
||||
frame_data.depth_map,
|
||||
frame_data.depth_path,
|
||||
frame_data.depth_mask,
|
||||
) = self._load_mask_depth(frame_annotation, frame_data.fg_probability)
|
||||
|
||||
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
||||
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
||||
frame_data.sequence_point_cloud = load_pointcloud(
|
||||
self._local_path(pcl_path), max_points=self.max_points
|
||||
)
|
||||
frame_data.sequence_point_cloud_path = pcl_path
|
||||
|
||||
if frame_annotation.viewpoint is not None:
|
||||
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
||||
|
||||
if self.box_crop:
|
||||
frame_data.crop_by_metadata_bbox_(self.box_crop_context)
|
||||
|
||||
if self.image_height is not None and self.image_width is not None:
|
||||
new_size = (self.image_height, self.image_width)
|
||||
frame_data.resize_frame_(
|
||||
new_size_hw=torch.tensor(new_size, dtype=torch.long), # pyre-ignore
|
||||
)
|
||||
|
||||
return frame_data
|
||||
|
||||
def _load_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
|
||||
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
|
||||
fg_probability = load_mask(self._local_path(full_path))
|
||||
# we can use provided bbox_xywh or calculate it based on mask
|
||||
# saves time to skip bbox calculation
|
||||
# pyre-ignore
|
||||
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
|
||||
fg_probability, self.box_crop_mask_thr
|
||||
)
|
||||
if fg_probability.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
return (
|
||||
safe_as_tensor(fg_probability, torch.float),
|
||||
full_path,
|
||||
safe_as_tensor(bbox_xywh, torch.long),
|
||||
)
|
||||
|
||||
def _load_images(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, str]:
|
||||
assert self.dataset_root is not None and entry.image is not None
|
||||
path = os.path.join(self.dataset_root, entry.image.path)
|
||||
image_rgb = load_image(self._local_path(path))
|
||||
|
||||
if image_rgb.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
|
||||
if self.mask_images:
|
||||
assert fg_probability is not None
|
||||
image_rgb *= fg_probability
|
||||
|
||||
return image_rgb, path
|
||||
|
||||
def _load_mask_depth(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||
entry_depth = entry.depth
|
||||
assert entry_depth is not None
|
||||
path = os.path.join(self.dataset_root, entry_depth.path)
|
||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||
|
||||
if self.mask_depths:
|
||||
assert fg_probability is not None
|
||||
depth_map *= fg_probability
|
||||
|
||||
if self.load_depth_masks:
|
||||
assert entry_depth.mask_path is not None
|
||||
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
||||
depth_mask = load_depth_mask(self._local_path(mask_path))
|
||||
else:
|
||||
depth_mask = torch.ones_like(depth_map)
|
||||
|
||||
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
|
||||
|
||||
def _get_pytorch3d_camera(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
) -> PerspectiveCameras:
|
||||
entry_viewpoint = entry.viewpoint
|
||||
assert entry_viewpoint is not None
|
||||
# principal point and focal length
|
||||
principal_point = torch.tensor(
|
||||
entry_viewpoint.principal_point, dtype=torch.float
|
||||
)
|
||||
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
|
||||
|
||||
format = entry_viewpoint.intrinsics_format
|
||||
if entry_viewpoint.intrinsics_format == "ndc_norm_image_bounds":
|
||||
# legacy PyTorch3D NDC format
|
||||
# convert to pixels unequally and convert to ndc equally
|
||||
image_size_as_list = list(reversed(entry.image.size))
|
||||
image_size_wh = torch.tensor(image_size_as_list, dtype=torch.float)
|
||||
per_axis_scale = image_size_wh / image_size_wh.min()
|
||||
focal_length = focal_length * per_axis_scale
|
||||
principal_point = principal_point * per_axis_scale
|
||||
elif entry_viewpoint.intrinsics_format != "ndc_isotropic":
|
||||
raise ValueError(f"Unknown intrinsics format: {format}")
|
||||
|
||||
return PerspectiveCameras(
|
||||
focal_length=focal_length[None],
|
||||
principal_point=principal_point[None],
|
||||
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
|
||||
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
||||
)
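
The "ndc_norm_image_bounds" branch above rescales the legacy intrinsics into the "ndc_isotropic" convention by multiplying each axis by size / min(size). A worked sketch with invented numbers:

import torch

focal_length = torch.tensor([1.2, 1.2])
principal_point = torch.tensor([0.1, -0.05])
image_size_wh = torch.tensor([400.0, 300.0])  # (W, H), i.e. entry.image.size reversed

per_axis_scale = image_size_wh / image_size_wh.min()  # -> [1.3333, 1.0000]
focal_length = focal_length * per_axis_scale          # -> [1.6000, 1.2000]
principal_point = principal_point * per_axis_scale    # -> [0.1333, -0.0500]
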
|
||||
|
||||
def _fix_point_cloud_path(self, path: str) -> str:
|
||||
"""
|
||||
Fix up a point cloud path from the dataset.
|
||||
Some files in Co3Dv2 have an accidental absolute path stored.
|
||||
"""
|
||||
unwanted_prefix = (
|
||||
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
||||
)
|
||||
if path.startswith(unwanted_prefix):
|
||||
path = path[len(unwanted_prefix) :]
|
||||
return os.path.join(self.dataset_root, path)
|
||||
|
||||
def _local_path(self, path: str) -> str:
|
||||
if self.path_manager is None:
|
||||
return path
|
||||
return self.path_manager.get_local_path(path)
|
||||
|
||||
|
||||
@registry.register
|
||||
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
|
||||
"""
|
||||
A concrete class to build a FrameData object, load and process the binary data (crop
|
||||
and resize). Beware that modifications of frame data are done in-place. Please see
|
||||
the documentation for `GenericFrameDataBuilder` for the description of parameters
|
||||
and methods.
|
||||
"""
|
||||
|
||||
frame_data_type: ClassVar[Type[FrameData]] = FrameData
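
Per the FrameDataBuilderBase contract, a custom builder parametrizes GenericFrameDataBuilder with its own FrameData subtype and points frame_data_type at it, mirroring the FrameDataBuilder definition above. A hedged sketch (MyFrameData and its extra field are hypothetical), assuming it lives alongside the definitions in frame_data.py:

from dataclasses import dataclass
from typing import ClassVar, Optional, Type

import torch


@dataclass
class MyFrameData(FrameData):
    my_feature: Optional[torch.Tensor] = None  # hypothetical extra per-frame field


@registry.register
class MyFrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[MyFrameData]):
    frame_data_type: ClassVar[Type[MyFrameData]] = MyFrameData
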
|
@ -15,7 +15,6 @@ import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
@ -30,19 +29,15 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.frame_data import FrameData, FrameDataBuilder
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from tqdm import tqdm
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from . import types
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .utils import is_known_frame_scalar
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -65,7 +60,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
A dataset with annotations in json files like the Common Objects in 3D
|
||||
(CO3D) dataset.
|
||||
|
||||
Args:
|
||||
Metadata-related args::
|
||||
frame_annotations_file: A zipped json file containing metadata of the
|
||||
frames in the dataset, serialized List[types.FrameAnnotation].
|
||||
sequence_annotations_file: A zipped json file containing metadata of the
|
||||
@ -83,6 +78,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
pick_sequence: A list of sequence names to restrict the dataset to.
|
||||
exclude_sequence: A list of the names of the sequences to exclude.
|
||||
limit_category_to: Restrict the dataset to the given list of categories.
|
||||
remove_empty_masks: Removes the frames with no active foreground pixels
|
||||
in the segmentation mask after thresholding (see box_crop_mask_thr).
|
||||
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
|
||||
frames in each sequence uniformly without replacement if it has
|
||||
more frames than that; applied before other frame-level filters.
|
||||
seed: The seed of the random generator sampling #n_frames_per_sequence
|
||||
random frames per sequence.
|
||||
sort_frames: Enable frame annotations sorting to group frames from the
|
||||
same sequences together and order them by timestamps
|
||||
eval_batches: A list of batches that form the evaluation set;
|
||||
list of batch-sized lists of indices corresponding to __getitem__
|
||||
of this class, thus it can be used directly as a batch sampler.
|
||||
eval_batch_index:
|
||||
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
|
||||
A list of batches of frames described as (sequence_name, frame_idx)
|
||||
that can form the evaluation set, `eval_batches` will be set from this.
|
||||
|
||||
Blob-loading parameters:
|
||||
dataset_root: The root folder of the dataset; all the paths in jsons are
|
||||
specified relative to this root (but not json paths themselves).
|
||||
load_images: Enable loading the frame RGB data.
|
||||
@ -109,23 +122,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
is greater than this threshold, the loader lowers it and repeats.
|
||||
box_crop_context: The amount of additional padding added to each
|
||||
dimension of the cropping bounding box, relative to box size.
|
||||
remove_empty_masks: Removes the frames with no active foreground pixels
|
||||
in the segmentation mask after thresholding (see box_crop_mask_thr).
|
||||
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
|
||||
frames in each sequence uniformly without replacement if it has
|
||||
more frames than that; applied before other frame-level filters.
|
||||
seed: The seed of the random generator sampling #n_frames_per_sequence
|
||||
random frames per sequence.
|
||||
sort_frames: Enable frame annotations sorting to group frames from the
|
||||
same sequences together and order them by timestamps
|
||||
eval_batches: A list of batches that form the evaluation set;
|
||||
list of batch-sized lists of indices corresponding to __getitem__
|
||||
of this class, thus it can be used directly as a batch sampler.
|
||||
eval_batch_index:
|
||||
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
|
||||
A list of batches of frames described as (sequence_name, frame_idx)
|
||||
that can form the evaluation set, `eval_batches` will be set from this.
|
||||
|
||||
"""
|
||||
|
||||
frame_annotations_type: ClassVar[
|
||||
@ -162,12 +158,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
sort_frames: bool = False
|
||||
eval_batches: Any = None
|
||||
eval_batch_index: Any = None
|
||||
# initialised in __post_init__
|
||||
# commented because of OmegaConf (for tests to pass)
|
||||
# _frame_data_builder: FrameDataBuilder = field(init=False)
|
||||
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
||||
self.subset_to_image_path = None
|
||||
self._load_frames()
|
||||
self._load_sequences()
|
||||
if self.sort_frames:
|
||||
@ -175,9 +173,27 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self._load_subset_lists()
|
||||
self._filter_db() # also computes sequence indices
|
||||
self._extract_and_set_eval_batches()
|
||||
|
||||
# pyre-ignore
|
||||
self._frame_data_builder = FrameDataBuilder(
|
||||
dataset_root=self.dataset_root,
|
||||
load_images=self.load_images,
|
||||
load_depths=self.load_depths,
|
||||
load_depth_masks=self.load_depth_masks,
|
||||
load_masks=self.load_masks,
|
||||
load_point_clouds=self.load_point_clouds,
|
||||
max_points=self.max_points,
|
||||
mask_images=self.mask_images,
|
||||
mask_depths=self.mask_depths,
|
||||
image_height=self.image_height,
|
||||
image_width=self.image_width,
|
||||
box_crop=self.box_crop,
|
||||
box_crop_mask_thr=self.box_crop_mask_thr,
|
||||
box_crop_context=self.box_crop_context,
|
||||
)
|
||||
logger.info(str(self))
|
||||
|
||||
def _extract_and_set_eval_batches(self):
|
||||
def _extract_and_set_eval_batches(self) -> None:
|
||||
"""
|
||||
Sets eval_batches based on input eval_batch_index.
|
||||
"""
|
||||
@ -207,13 +223,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
# https://gist.github.com/treyhunner/f35292e676efa0be1728
|
||||
functools.reduce(
|
||||
lambda a, b: {**a, **b},
|
||||
[d.seq_annots for d in other_datasets], # pyre-ignore[16]
|
||||
# pyre-ignore[16]
|
||||
[d.seq_annots for d in other_datasets],
|
||||
)
|
||||
)
|
||||
all_eval_batches = [
|
||||
self.eval_batches,
|
||||
# pyre-ignore
|
||||
*[d.eval_batches for d in other_datasets],
|
||||
*[d.eval_batches for d in other_datasets], # pyre-ignore[16]
|
||||
]
|
||||
if not (
|
||||
all(ba is None for ba in all_eval_batches)
|
||||
@ -251,7 +267,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
allow_missing_indices: bool = False,
|
||||
remove_missing_indices: bool = False,
|
||||
suppress_missing_index_warning: bool = True,
|
||||
) -> List[List[Union[Optional[int], int]]]:
|
||||
) -> Union[List[List[Optional[int]]], List[List[int]]]:
|
||||
"""
|
||||
Obtain indices into the dataset object given a list of frame ids.
|
||||
|
||||
@ -323,9 +339,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
valid_dataset_idx = [
|
||||
[b for b in batch if b is not None] for batch in dataset_idx
|
||||
]
|
||||
return [ # pyre-ignore[7]
|
||||
batch for batch in valid_dataset_idx if len(batch) > 0
|
||||
]
|
||||
return [batch for batch in valid_dataset_idx if len(batch) > 0]
|
||||
|
||||
return dataset_idx
|
||||
|
||||
@ -417,255 +431,18 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
||||
|
||||
entry = self.frame_annots[index]["frame_annotation"]
|
||||
# pyre-ignore[16]
|
||||
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
||||
frame_data = FrameData(
|
||||
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
||||
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
|
||||
sequence_name=entry.sequence_name,
|
||||
sequence_category=self.seq_annots[entry.sequence_name].category,
|
||||
camera_quality_score=_safe_as_tensor(
|
||||
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
|
||||
torch.float,
|
||||
),
|
||||
point_cloud_quality_score=_safe_as_tensor(
|
||||
point_cloud.quality_score, torch.float
|
||||
)
|
||||
if point_cloud is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
# The rest of the fields are optional
|
||||
# pyre-ignore
|
||||
frame_data = self._frame_data_builder.build(
|
||||
entry,
|
||||
# pyre-ignore
|
||||
self.seq_annots[entry.sequence_name],
|
||||
)
|
||||
# Optional field
|
||||
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
|
||||
|
||||
(
|
||||
frame_data.fg_probability,
|
||||
frame_data.mask_path,
|
||||
frame_data.bbox_xywh,
|
||||
clamp_bbox_xyxy,
|
||||
frame_data.crop_bbox_xywh,
|
||||
) = self._load_crop_fg_probability(entry)
|
||||
|
||||
scale = 1.0
|
||||
if self.load_images and entry.image is not None:
|
||||
# original image size
|
||||
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
|
||||
|
||||
(
|
||||
frame_data.image_rgb,
|
||||
frame_data.image_path,
|
||||
frame_data.mask_crop,
|
||||
scale,
|
||||
) = self._load_crop_images(
|
||||
entry, frame_data.fg_probability, clamp_bbox_xyxy
|
||||
)
|
||||
|
||||
if self.load_depths and entry.depth is not None:
|
||||
(
|
||||
frame_data.depth_map,
|
||||
frame_data.depth_path,
|
||||
frame_data.depth_mask,
|
||||
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
|
||||
|
||||
if entry.viewpoint is not None:
|
||||
frame_data.camera = self._get_pytorch3d_camera(
|
||||
entry,
|
||||
scale,
|
||||
clamp_bbox_xyxy,
|
||||
)
|
||||
|
||||
if self.load_point_clouds and point_cloud is not None:
|
||||
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
||||
frame_data.sequence_point_cloud = _load_pointcloud(
|
||||
self._local_path(pcl_path), max_points=self.max_points
|
||||
)
|
||||
frame_data.sequence_point_cloud_path = pcl_path
|
||||
|
||||
return frame_data
|
||||
|
||||
def _fix_point_cloud_path(self, path: str) -> str:
|
||||
"""
|
||||
Fix up a point cloud path from the dataset.
|
||||
Some files in Co3Dv2 have an accidental absolute path stored.
|
||||
"""
|
||||
unwanted_prefix = (
|
||||
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
||||
)
|
||||
if path.startswith(unwanted_prefix):
|
||||
path = path[len(unwanted_prefix) :]
|
||||
return os.path.join(self.dataset_root, path)
|
||||
|
||||
def _load_crop_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[
|
||||
Optional[torch.Tensor],
|
||||
Optional[str],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
fg_probability = None
|
||||
full_path = None
|
||||
bbox_xywh = None
|
||||
clamp_bbox_xyxy = None
|
||||
crop_box_xywh = None
|
||||
|
||||
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
mask = _load_mask(self._local_path(full_path))
|
||||
|
||||
if mask.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
|
||||
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
||||
|
||||
if self.box_crop:
|
||||
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
|
||||
_get_clamp_bbox(
|
||||
bbox_xywh,
|
||||
image_path=entry.image.path,
|
||||
box_crop_context=self.box_crop_context,
|
||||
),
|
||||
image_size_hw=tuple(mask.shape[-2:]),
|
||||
)
|
||||
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
||||
|
||||
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
||||
|
||||
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
||||
|
||||
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
|
||||
|
||||
def _load_crop_images(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
clamp_bbox_xyxy: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
|
||||
assert self.dataset_root is not None and entry.image is not None
|
||||
path = os.path.join(self.dataset_root, entry.image.path)
|
        image_rgb = _load_image(self._local_path(path))

        if image_rgb.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
            )

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)

        image_rgb, scale, mask_crop = self._resize_image(image_rgb)

        if self.mask_images:
            assert fg_probability is not None
            image_rgb *= fg_probability

        return image_rgb, path, mask_crop, scale

    def _load_mask_depth(
        self,
        entry: types.FrameAnnotation,
        clamp_bbox_xyxy: Optional[torch.Tensor],
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        entry_depth = entry.depth
        assert entry_depth is not None
        path = os.path.join(self.dataset_root, entry_depth.path)
        depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            depth_bbox_xyxy = _rescale_bbox(
                clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
            )
            depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)

        depth_map, _, _ = self._resize_image(depth_map, mode="nearest")

        if self.mask_depths:
            assert fg_probability is not None
            depth_map *= fg_probability

        if self.load_depth_masks:
            assert entry_depth.mask_path is not None
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
            depth_mask = _load_depth_mask(self._local_path(mask_path))

            if self.box_crop:
                assert clamp_bbox_xyxy is not None
                depth_mask_bbox_xyxy = _rescale_bbox(
                    clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
                )
                depth_mask = _crop_around_box(
                    depth_mask, depth_mask_bbox_xyxy, mask_path
                )

            depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
        else:
            depth_mask = torch.ones_like(depth_map)

        return depth_map, path, depth_mask

    def _get_pytorch3d_camera(
        self,
        entry: types.FrameAnnotation,
        scale: float,
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        half_image_size_wh_orig = (
            torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale
        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            principal_point_px -= clamp_bbox_xyxy[:2]

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        if self.image_height is None or self.image_width is None:
            out_size = list(reversed(entry.image.size))
        else:
            out_size = [self.image_width, self.image_height]

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )

    def _load_frames(self) -> None:
        logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
        local_file = self._local_path(self.frame_annotations_file)
@@ -853,35 +630,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        # pyre-ignore[16]
        self._seq_to_idx = seq_to_idx

    def _resize_image(
        self, image, mode="bilinear"
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        image_height, image_width = self.image_height, self.image_width
        if image_height is None or image_width is None:
            # skip the resizing
            imre_ = torch.from_numpy(image)
            return imre_, 1.0, torch.ones_like(imre_[:1])
        # takes numpy array, returns pytorch tensor
        minscale = min(
            image_height / image.shape[-2],
            image_width / image.shape[-1],
        )
        imre = torch.nn.functional.interpolate(
            torch.from_numpy(image)[None],
            scale_factor=minscale,
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
            recompute_scale_factor=True,
        )[0]
        # pyre-fixme[19]: Expected 1 positional argument.
        imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
        imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
        # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
        # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
        mask = torch.zeros(1, self.image_height, self.image_width)
        mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
        return imre_, minscale, mask

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path
@@ -918,169 +666,3 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):

def _seq_name_to_seed(seq_name) -> int:
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)


def _load_image(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _load_16big_png_depth(depth_png) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def _load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def _load_depth_mask(path: str) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = _load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def _load_depth(path, scale_adjustment) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)

    d = _load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def _load_mask(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def _get_1d_bounds(arr) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1


def _get_bbox_from_mask(
    mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
        if thr <= 0.0:
            warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")

    x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
    image_path: str = "",
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {image_path}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def _crop_around_box(tensor, bbox, impath: str = ""):
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox = _clamp_box_to_image_bounds_and_round(
        bbox,
        image_size_hw=tensor.shape[-2:],
    )
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def _clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size


def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
    return xywh


def _bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy


def _safe_as_tensor(data, dtype):
    if data is None:
        return None
    return torch.tensor(data, dtype=dtype)


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl

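Illustrative sketch (not part of the commit): the intrinsics conversion that _get_pytorch3d_camera above performs, spelled out with made-up numbers for a hypothetical 400x600 (H x W) frame stored in the "ndc_norm_image_bounds" convention, with no crop and no resize. All values below are assumptions used only to make the arithmetic concrete.

import torch

image_size_hw = (400, 600)                         # hypothetical frame size
principal_point_ndc = torch.tensor([0.05, -0.10])  # hypothetical dataset intrinsics
focal_length_ndc = torch.tensor([2.0, 2.0])

half_wh = torch.tensor([image_size_hw[1], image_size_hw[0]], dtype=torch.float) / 2.0

# dataset NDC -> pixels ("ndc_norm_image_bounds" rescales per axis by the half-size)
principal_point_px = half_wh - principal_point_ndc * half_wh
focal_length_px = focal_length_ndc * half_wh

# pixels -> PyTorch3D v0.5+ NDC (normalized by half of the smaller output dimension)
half_min = half_wh.min()
principal_point = (half_wh - principal_point_px) / half_min
focal_length = focal_length_px / half_min

print(principal_point, focal_length)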
@@ -20,8 +20,9 @@ from pytorch3d.implicitron.tools.config import (
)
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras

from .dataset_base import DatasetBase, FrameData
from .dataset_base import DatasetBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .frame_data import FrameData
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN

_SINGLE_SEQUENCE_NAME: str = "one_sequence"
@@ -69,7 +70,8 @@ class SingleSceneDataset(DatasetBase, Configurable):
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=pose,
            image_size_hw=torch.tensor(image.shape[1:]),
            # pyre-ignore
            image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long),
            image_rgb=image,
            fg_probability=fg_probability,
            frame_type=frame_type,

@@ -55,6 +55,8 @@ class MaskAnnotation:
    path: str
    # (soft) number of pixels in the mask; sum(Prob(fg | pixel))
    mass: Optional[float] = None
    # tight bounding box around the foreground mask
    bounding_box_xywh: Optional[Tuple[float, float, float, float]] = None


@dataclass

@@ -5,10 +5,18 @@
# LICENSE file in the root directory of this source tree.


from typing import List, Optional
import functools
import warnings
from pathlib import Path
from typing import List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
from PIL import Image

from pytorch3d.io import IO
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds

DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
@@ -16,6 +24,26 @@ DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"


class GenericWorkaround:
    """
    OmegaConf.structured has a weirdness when you try to apply
    it to a dataclass whose first base class is a Generic which is not
    Dict. The issue is with a function called get_dict_key_value_types
    in omegaconf/_utils.py.
    For example this fails:

    @dataclass(eq=False)
    class D(torch.utils.data.Dataset[int]):
        a: int = 3

    OmegaConf.structured(D)

    We avoid the problem by adding this class as an extra base class.
    """

    pass


def is_known_frame_scalar(frame_type: str) -> bool:
    """
    Given a single frame type corresponding to a single frame, return whether
@@ -52,3 +80,286 @@ def is_train_frame(
        dtype=torch.bool,
        device=device,
    )


def get_bbox_from_mask(
    mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
        if thr <= 0.0:
            warnings.warn(
                f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1
            )

    x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def crop_around_box(
    tensor: torch.Tensor, bbox: torch.Tensor, impath: str = ""
) -> torch.Tensor:
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox = clamp_box_to_image_bounds_and_round(
        bbox,
        image_size_hw=tuple(tensor.shape[-2:]),
    )
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


T = TypeVar("T", bound=torch.Tensor)


def bbox_xyxy_to_xywh(xyxy: T) -> T:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
    return xywh  # pyre-ignore


def get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
    image_path: str = "",
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {image_path}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def rescale_bbox(
    bbox: torch.Tensor,
    orig_res: Union[Tuple[int, int], torch.LongTensor],
    new_res: Union[Tuple[int, int], torch.LongTensor],
) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    # pyre-ignore
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size


def bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy


def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1


def resize_image(
    image: Union[np.ndarray, torch.Tensor],
    image_height: Optional[int],
    image_width: Optional[int],
    mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]:

    if type(image) == np.ndarray:
        image = torch.from_numpy(image)

    if image_height is None or image_width is None:
        # skip the resizing
        return image, 1.0, torch.ones_like(image[:1])
    # takes numpy array or tensor, returns pytorch tensor
    minscale = min(
        image_height / image.shape[-2],
        image_width / image.shape[-1],
    )
    imre = torch.nn.functional.interpolate(
        image[None],
        scale_factor=minscale,
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
        recompute_scale_factor=True,
    )[0]
    imre_ = torch.zeros(image.shape[0], image_height, image_width)
    imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
    mask = torch.zeros(1, image_height, image_width)
    mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
    return imre_, minscale, mask


def load_image(path: str) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def load_mask(path: str) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)

    d = load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def load_16big_png_depth(depth_png: str) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def load_depth_mask(path: str) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def safe_as_tensor(data, dtype):
    return torch.tensor(data, dtype=dtype) if data is not None else None


def _convert_ndc_to_pixels(
    focal_length: torch.Tensor,
    principal_point: torch.Tensor,
    image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point_px = half_image_size - principal_point * rescale
    focal_length_px = focal_length * rescale
    return focal_length_px, principal_point_px


def _convert_pixels_to_ndc(
    focal_length_px: torch.Tensor,
    principal_point_px: torch.Tensor,
    image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point = (half_image_size - principal_point_px) / rescale
    focal_length = focal_length_px / rescale
    return focal_length, principal_point


def adjust_camera_to_bbox_crop_(
    camera: PerspectiveCameras,
    image_size_wh: torch.Tensor,
    clamp_bbox_xywh: torch.Tensor,
) -> None:
    if len(camera) != 1:
        raise ValueError("Adjusting currently works with singleton cameras camera only")

    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
        camera.focal_length[0],
        camera.principal_point[0],  # pyre-ignore
        image_size_wh,
    )
    principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]

    focal_length, principal_point_cropped = _convert_pixels_to_ndc(
        focal_length_px,
        principal_point_px_cropped,
        clamp_bbox_xywh[2:],
    )

    camera.focal_length = focal_length[None]
    camera.principal_point = principal_point_cropped[None]  # pyre-ignore


def adjust_camera_to_image_scale_(
    camera: PerspectiveCameras,
    original_size_wh: torch.Tensor,
    new_size_wh: torch.LongTensor,
) -> PerspectiveCameras:
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
        camera.focal_length[0],
        camera.principal_point[0],  # pyre-ignore
        original_size_wh,
    )

    # now scale and convert from pixels to NDC
    image_size_wh_output = new_size_wh.float()
    scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
    focal_length_px_scaled = focal_length_px * scale
    principal_point_px_scaled = principal_point_px * scale

    focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
        focal_length_px_scaled,
        principal_point_px_scaled,
        image_size_wh_output,
    )
    camera.focal_length = focal_length_scaled[None]
    camera.principal_point = principal_point_scaled[None]  # pyre-ignore


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl

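Illustrative sketch (not part of the commit): how the helpers above compose into the box-crop and resize pipeline that previously lived in JsonIndexDataset. The mask, image, threshold and crop-context values below are synthetic, and the import path assumes the functions land in pytorch3d.implicitron.dataset.utils as shown in this diff.

import numpy as np
import torch

from pytorch3d.implicitron.dataset.utils import (
    bbox_xyxy_to_xywh,
    crop_around_box,
    get_bbox_from_mask,
    get_clamp_bbox,
    resize_image,
)

fg_mask = np.zeros((1, 8, 8), dtype=np.float32)
fg_mask[:, 2:6, 3:7] = 1.0                      # synthetic foreground blob
image = torch.rand(3, 8, 8)                     # synthetic RGB image

bbox_xywh = torch.tensor(get_bbox_from_mask(fg_mask[0], thr=0.4))
bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=0.3)  # expand and convert to xyxy
cropped = crop_around_box(image, bbox_xyxy)                  # clamps and rounds internally
resized, scale, mask_crop = resize_image(cropped, image_height=16, image_width=16)

print(bbox_xyxy_to_xywh(bbox_xyxy), resized.shape, scale)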
@@ -10,7 +10,7 @@ import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds

from .dataset_base import FrameData
from .frame_data import FrameData
from .json_index_dataset import JsonIndexDataset


@@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Un
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender
from pytorch3d.implicitron.tools import vis_utils

@@ -17,7 +17,8 @@ from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DoublePoolBatchSampler,
)

from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler


@@ -9,11 +9,19 @@ import unittest
import numpy as np

import torch
from pytorch3d.implicitron.dataset.json_index_dataset import (
    _bbox_xywh_to_xyxy,
    _bbox_xyxy_to_xywh,
    _get_bbox_from_mask,

from pytorch3d.implicitron.dataset.utils import (
    bbox_xywh_to_xyxy,
    bbox_xyxy_to_xywh,
    clamp_box_to_image_bounds_and_round,
    crop_around_box,
    get_1d_bounds,
    get_bbox_from_mask,
    get_clamp_bbox,
    rescale_bbox,
    resize_image,
)

from tests.common_testing import TestCaseMixin


@@ -31,9 +39,9 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
            ]
        )
        for bbox_xywh in bbox_xywh_list:
            bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh)
            bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy)
            bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_)
            bbox_xyxy = bbox_xywh_to_xyxy(bbox_xywh)
            bbox_xywh_ = bbox_xyxy_to_xywh(bbox_xyxy)
            bbox_xyxy_ = bbox_xywh_to_xyxy(bbox_xywh_)
            self.assertClose(bbox_xywh_, bbox_xywh)
            self.assertClose(bbox_xyxy, bbox_xyxy_)

@@ -47,8 +55,8 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
            ]
        )
        for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected:
            self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
            self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
            self.assertClose(bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
            self.assertClose(bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)

        clamp_amnt = 3
        bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor(
@@ -61,7 +69,7 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
        )
        for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected:
            self.assertClose(
                _bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
                bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
                bbox_xyxy_expected,
            )

@@ -74,5 +82,61 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
            ]
        ).astype(np.float32)
        expected_bbox_xywh = [2, 1, 2, 1]
        bbox_xywh = _get_bbox_from_mask(mask, 0.5)
        bbox_xywh = get_bbox_from_mask(mask, 0.5)
        self.assertClose(bbox_xywh, expected_bbox_xywh)

    def test_crop_around_box(self):
        bbox = torch.LongTensor([0, 1, 2, 3])  # (x_min, y_min, x_max, y_max)
        image = torch.LongTensor(
            [
                [0, 0, 10, 20],
                [10, 20, 5, 1],
                [10, 20, 1, 1],
                [5, 4, 0, 1],
            ]
        )
        cropped = crop_around_box(image, bbox)
        self.assertClose(cropped, image[1:3, 0:2])

    def test_clamp_box_to_image_bounds_and_round(self):
        bbox = torch.LongTensor([0, 1, 10, 12])
        image_size = (5, 6)
        expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]])
        clamped_bbox = clamp_box_to_image_bounds_and_round(bbox, image_size)
        self.assertClose(clamped_bbox, expected_clamped_bbox)

    def test_get_clamp_bbox(self):
        bbox_xywh = torch.LongTensor([1, 1, 4, 5])
        clamped_bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=2)
        # size multiplied by 2 and added coordinates
        self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11]))

    def test_rescale_bbox(self):
        bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0])
        original_resolution = (4, 4)
        new_resolution = (8, 8)  # twice bigger
        rescaled_bbox = rescale_bbox(bbox, original_resolution, new_resolution)
        self.assertClose(bbox * 2, rescaled_bbox)

    def test_get_1d_bounds(self):
        array = [0, 1, 2]
        bounds = get_1d_bounds(array)
        # make nonzero 1d bounds of image
        self.assertClose(bounds, [1, 3])

    def test_resize_image(self):
        image = np.random.rand(3, 300, 500)  # rgb image 300x500
        expected_shape = (150, 250)

        resized_image, scale, mask_crop = resize_image(
            image, image_height=expected_shape[0], image_width=expected_shape[1]
        )

        original_shape = image.shape[-2:]
        expected_scale = min(
            expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1]
        )

        self.assertEqual(scale, expected_scale)
        self.assertEqual(resized_image.shape[-2:], expected_shape)
        self.assertEqual(mask_crop.shape[-2:], expected_shape)

@@ -8,7 +8,7 @@ import os
import unittest

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
    RenderedMeshDatasetMapProvider,
)

@@ -13,8 +13,10 @@ import os
import unittest

import lpips
import numpy as np
import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData

from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
@@ -268,7 +270,7 @@ class TestEvaluation(unittest.TestCase):
        for metric in lower_better:
            m_better = eval_result[metric]
            m_worse = eval_result_bad[metric]
            if m_better != m_better or m_worse != m_worse:
            if np.isnan(m_better) or np.isnan(m_worse):
                continue  # metric is missing, i.e. NaN
            _assert = (
                self.assertLessEqual

224 tests/implicitron/test_frame_data_builder.py Normal file
@@ -0,0 +1,224 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gzip
import os
import unittest
from typing import List

import numpy as np
import torch

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import (
    load_16big_png_depth,
    load_1bit_png_mask,
    load_depth,
    load_depth_mask,
    load_image,
    load_mask,
    safe_as_tensor,
)
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import PerspectiveCameras

from tests.common_testing import TestCaseMixin
from tests.implicitron.common_resources import get_skateboard_data


class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

        category = "skateboard"
        stack = contextlib.ExitStack()
        self.dataset_root, self.path_manager = stack.enter_context(
            get_skateboard_data()
        )
        self.addCleanup(stack.close)
        self.image_height = 768
        self.image_width = 512

        self.frame_data_builder = FrameDataBuilder(
            image_height=self.image_height,
            image_width=self.image_width,
            dataset_root=self.dataset_root,
            path_manager=self.path_manager,
        )

        # loading single frame annotation of dataset (see JsonIndexDataset._load_frames())
        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
        local_file = self.path_manager.get_local_path(frame_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            frame_annots_list = types.load_dataclass(
                zipfile, List[types.FrameAnnotation]
            )
            self.frame_annotation = frame_annots_list[0]

        sequence_annotations_file = os.path.join(
            self.dataset_root, category, "sequence_annotations.jgz"
        )
        local_file = self.path_manager.get_local_path(sequence_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            seq_annots_list = types.load_dataclass(
                zipfile, List[types.SequenceAnnotation]
            )
            seq_annots = {entry.sequence_name: entry for entry in seq_annots_list}
            self.seq_annotation = seq_annots[self.frame_annotation.sequence_name]

        point_cloud = self.seq_annotation.point_cloud
        self.frame_data = FrameData(
            frame_number=safe_as_tensor(self.frame_annotation.frame_number, torch.long),
            frame_timestamp=safe_as_tensor(
                self.frame_annotation.frame_timestamp, torch.float
            ),
            sequence_name=self.frame_annotation.sequence_name,
            sequence_category=self.seq_annotation.category,
            camera_quality_score=safe_as_tensor(
                self.seq_annotation.viewpoint_quality_score, torch.float
            ),
            point_cloud_quality_score=safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )

    def test_frame_data_builder_args(self):
        # test that FrameDataBuilder works with get_default_args
        get_default_args(FrameDataBuilder)

    def test_fix_point_cloud_path(self):
        """Some files in Co3Dv2 have an accidental absolute path stored."""
        original_path = "some_file_path"
        modified_path = self.frame_data_builder._fix_point_cloud_path(original_path)
        self.assertIn(original_path, modified_path)
        self.assertIn(self.frame_data_builder.dataset_root, modified_path)

    def test_load_and_adjust_frame_data(self):
        self.frame_data.image_size_hw = safe_as_tensor(
            self.frame_annotation.image.size, torch.long
        )
        self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw

        (
            self.frame_data.fg_probability,
            self.frame_data.mask_path,
            self.frame_data.bbox_xywh,
        ) = self.frame_data_builder._load_fg_probability(self.frame_annotation)

        self.assertIsNotNone(self.frame_data.mask_path)
        self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
        self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh))
        # assert bboxes shape
        self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))

        (
            self.frame_data.image_rgb,
            self.frame_data.image_path,
        ) = self.frame_data_builder._load_images(
            self.frame_annotation, self.frame_data.fg_probability
        )
        self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
        self.assertIsNotNone(self.frame_data.image_path)

        (
            self.frame_data.depth_map,
            depth_path,
            self.frame_data.depth_mask,
        ) = self.frame_data_builder._load_mask_depth(
            self.frame_annotation,
            self.frame_data.fg_probability,
        )
        self.assertTrue(torch.is_tensor(self.frame_data.depth_map))
        self.assertIsNotNone(depth_path)
        self.assertTrue(torch.is_tensor(self.frame_data.depth_mask))

        new_size = (self.image_height, self.image_width)

        if self.frame_data_builder.box_crop:
            self.frame_data.crop_by_metadata_bbox_(
                self.frame_data_builder.box_crop_context,
            )

        # assert image and mask shapes after resize
        self.frame_data.resize_frame_(
            new_size_hw=torch.tensor(new_size, dtype=torch.long),
        )
        self.assertEqual(
            self.frame_data.mask_crop.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.image_rgb.shape,
            torch.Size([3, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.mask_crop.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.fg_probability.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.depth_map.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.assertEqual(
            self.frame_data.depth_mask.shape,
            torch.Size([1, self.image_height, self.image_width]),
        )
        self.frame_data.camera = self.frame_data_builder._get_pytorch3d_camera(
            self.frame_annotation,
        )
        self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)

    def test_load_image(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
        local_path = self.path_manager.get_local_path(path)
        image = load_image(local_path)
        self.assertEqual(image.dtype, np.float32)
        self.assertLessEqual(np.max(image), 1.0)
        self.assertGreaterEqual(np.min(image), 0.0)

    def test_load_mask(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
        mask = load_mask(path)
        self.assertEqual(mask.dtype, np.float32)
        self.assertLessEqual(np.max(mask), 1.0)
        self.assertGreaterEqual(np.min(mask), 0.0)

    def test_load_depth(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
        depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment)
        self.assertEqual(depth_map.dtype, np.float32)
        self.assertEqual(len(depth_map.shape), 3)

    def test_load_16big_png_depth(self):
        path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
        depth_map = load_16big_png_depth(path)
        self.assertEqual(depth_map.dtype, np.float32)
        self.assertEqual(len(depth_map.shape), 2)

    def test_load_1bit_png_mask(self):
        mask_path = os.path.join(
            self.dataset_root, self.frame_annotation.depth.mask_path
        )
        mask = load_1bit_png_mask(mask_path)
        self.assertEqual(mask.dtype, np.float32)
        self.assertEqual(len(mask.shape), 2)

    def test_load_depth_mask(self):
        mask_path = os.path.join(
            self.dataset_root, self.frame_annotation.depth.mask_path
        )
        mask = load_depth_mask(mask_path)
        self.assertEqual(mask.dtype, np.float32)
        self.assertEqual(len(mask.shape), 3)
@@ -17,7 +17,7 @@ import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)