Changes for CO3Dv2 release [part1]

Summary:
Implements several changes needed for the CO3Dv2 release:
- FrameData contains crop_bbox_xywh, which defines the outline of the image crop corresponding to the image-shaped tensors in FrameData
- Revised the definition of a bounding box inside JsonIndexDataset: bbox_xyxy is [xmin, ymin, xmax, ymax], where xmax and ymax are not inclusive; bbox_xywh = [xmin, ymin, xmax-xmin, ymax-ymin] (see the sketch after this list)
- Added is_filtered for detecting whether the entries of the dataset have been filtered
- seq_frame_index_to_dataset_index can now skip entries that are not present in the dataset
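
A minimal sketch (not part of the diff) of the revised box convention; the values mirror the new unit tests and the tensors are illustrative:

    import torch

    # bbox_xyxy = [xmin, ymin, xmax, ymax] with exclusive xmax/ymax ...
    bbox_xyxy = torch.tensor([10, 20, 15, 21])
    # ... so bbox_xywh = [xmin, ymin, xmax - xmin, ymax - ymin].
    bbox_xywh = torch.cat([bbox_xyxy[:2], bbox_xyxy[2:] - bbox_xyxy[:2]])
    assert bbox_xywh.tolist() == [10, 20, 5, 1]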

Reviewed By: shapovalov

Differential Revision: D37687547

fbshipit-source-id: 7842756b0517878cc0964fc0935d3c0769454d78
Authored by David Novotny on 2022-07-09 17:16:24 -07:00; committed by Facebook GitHub Bot
parent 00acf0b0c7
commit 4300030d7a
5 changed files with 224 additions and 33 deletions


@@ -61,8 +61,16 @@ class FrameData(Mapping[str, Any]):
         fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
             pixels belonging to the captured object; elements are floats
             in [0, 1].
-        bbox_xywh: The bounding box capturing the object in the
-            format (x0, y0, width, height).
+        bbox_xywh: The bounding box tightly enclosing the foreground object in the
+            format (x0, y0, width, height). The convention assumes that
+            `x0+width` and `y0+height` includes the boundary of the box.
+            I.e., to slice out the corresponding crop from an image tensor `I`
+            we execute `crop = I[..., y0:y0+height, x0:x0+width]`
+        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
+            in the original image coordinates in the format (x0, y0, width, height).
+            The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
+            from `bbox_xywh` due to padding (which can happen e.g. due to
+            setting `JsonIndexDataset.box_crop_context > 0`)
         camera: A PyTorch3D camera object corresponding the frame's viewpoint,
             corrected for cropping if it happened.
         camera_quality_score: The score proportional to the confidence of the
@@ -98,6 +106,7 @@ class FrameData(Mapping[str, Any]):
     mask_path: Union[str, List[str], None] = None
     fg_probability: Optional[torch.Tensor] = None
     bbox_xywh: Optional[torch.Tensor] = None
+    crop_bbox_xywh: Optional[torch.Tensor] = None
     camera: Optional[PerspectiveCameras] = None
     camera_quality_score: Optional[torch.Tensor] = None
     point_cloud_quality_score: Optional[torch.Tensor] = None
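
A minimal sketch of the slicing convention described above; `full_image` is a hypothetical uncropped source image tensor and `frame_data` a loaded FrameData instance:

    # bbox_xywh = (x0, y0, width, height); x0+width and y0+height are exclusive bounds,
    # so the object crop can be sliced directly with them.
    x0, y0, w, h = frame_data.bbox_xywh.round().long().tolist()
    object_crop = full_image[..., y0 : y0 + h, x0 : x0 + w]

    # crop_bbox_xywh uses the same convention but locates frame_data.image_rgb in the
    # original image; it can be larger than bbox_xywh when box_crop_context > 0.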


@@ -160,17 +160,50 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         self._filter_db()  # also computes sequence indices
         logger.info(str(self))
 
+    def is_filtered(self):
+        """
+        Returns `True` in case the dataset has been filtered and thus some frame annotations
+        stored on the disk might be missing in the dataset object.
+
+        Returns:
+            is_filtered: `True` if the dataset has been filtered, else `False`.
+        """
+        return (
+            self.remove_empty_masks
+            or self.limit_to > 0
+            or self.limit_sequences_to > 0
+            or len(self.pick_sequence) > 0
+            or len(self.exclude_sequence) > 0
+            or len(self.limit_category_to) > 0
+            or self.n_frames_per_sequence > 0
+        )
+
     def seq_frame_index_to_dataset_index(
         self,
-        seq_frame_index: Union[
-            List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
-        ],
-    ) -> List[List[int]]:
+        seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
+        allow_missing_indices: bool = False,
+        remove_missing_indices: bool = False,
+    ) -> List[List[Union[Optional[int], int]]]:
         """
-        Obtain indices into the dataset object given a list of frames specified as
-        `seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
+        Obtain indices into the dataset object given a list of frame ids.
+
+        Args:
+            seq_frame_index: The list of frame ids specified as
+                `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
+                Image paths relative to the dataset_root can be stored specified as well:
+                `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+                entry from `seq_frame_index` which is missing in the dataset.
+                Otherwise, depending on `remove_missing_indices`, either returns `None`
+                in place of missing entries or removes the indices of missing entries.
+            remove_missing_indices: Active when `allow_missing_indices=True`.
+                If `False`, returns `None` in place of `seq_frame_index` entries that
+                are not present in the dataset.
+                If `True` removes missing indices from the returned indices.
+
+        Returns:
+            dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
         """
-        # TODO: check the frame numbers are unique
         _dataset_seq_frame_n_index = {
             seq: {
                 # pyre-ignore[16]
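
A short sketch of how the new `is_filtered` flag might be consulted (the `dataset` variable is a hypothetical JsonIndexDataset instance):

    # True whenever remove_empty_masks, limit_to, pick_sequence, etc. pruned entries,
    # i.e. some frame annotations stored on disk may be absent from the dataset object.
    if dataset.is_filtered():
        print("Dataset was filtered; tolerate missing (sequence, frame) entries.")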
@@ -181,8 +214,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             for seq, seq_idx in self._seq_to_idx.items()
         }
 
-        def _get_batch_idx(seq_name, frame_no, path=None) -> int:
-            idx = _dataset_seq_frame_n_index[seq_name][frame_no]
+        def _get_dataset_idx(
+            seq_name: str, frame_no: int, path: Optional[str] = None
+        ) -> Optional[int]:
+            idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
+            idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
+            if idx is None:
+                msg = (
+                    f"sequence_name={seq_name} / frame_number={frame_no}"
+                    " not in the dataset!"
+                )
+                if not allow_missing_indices:
+                    raise IndexError(msg)
+                warnings.warn(msg)
+                return idx
             if path is not None:
                 # Check that the loaded frame path is consistent
                 # with the one stored in self.frame_annots.
@@ -191,11 +236,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                     self.frame_annots[idx]["frame_annotation"].image.path
                 ) == os.path.normpath(
                     path
-                ), f"Inconsistent batch {seq_name, frame_no, path}."
+                ), f"Inconsistent frame indices {seq_name, frame_no, path}."
             return idx
 
-        batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
-        return batches_idx
+        dataset_idx = [
+            [_get_dataset_idx(*b) for b in batch]  # pyre-ignore [6]
+            for batch in seq_frame_index
+        ]
+
+        if allow_missing_indices and remove_missing_indices:
+            # remove all None indices, and also batches with only None entries
+            valid_dataset_idx = [
+                [b for b in batch if b is not None] for batch in dataset_idx
+            ]
+            return [  # pyre-ignore[7]
+                batch for batch in valid_dataset_idx if len(batch) > 0
+            ]
+
+        return dataset_idx
 
     def __str__(self) -> str:
         # pyre-ignore[16]
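
A hedged usage sketch of the extended method; `dataset`, the sequence name and the frame numbers are made up:

    seq_frame_index = [[("sofa_001", 0), ("sofa_001", 9999)]]  # 9999 is not in the dataset

    # Default behaviour: IndexError on the first missing (sequence_name, frame_number).
    # dataset.seq_frame_index_to_dataset_index(seq_frame_index)

    # Returns [[<idx>, None]] and warns about the missing frame.
    idx = dataset.seq_frame_index_to_dataset_index(
        seq_frame_index, allow_missing_indices=True
    )

    # Returns [[<idx>]]; None entries (and batches left empty) are dropped.
    idx = dataset.seq_frame_index_to_dataset_index(
        seq_frame_index, allow_missing_indices=True, remove_missing_indices=True
    )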
@@ -254,6 +312,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             frame_data.mask_path,
             frame_data.bbox_xywh,
             clamp_bbox_xyxy,
+            frame_data.crop_bbox_xywh,
         ) = self._load_crop_fg_probability(entry)
 
         scale = 1.0
@@ -301,13 +360,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         Optional[str],
         Optional[torch.Tensor],
         Optional[torch.Tensor],
+        Optional[torch.Tensor],
     ]:
-        fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (
-            None,
-            None,
-            None,
-            None,
-        )
+        fg_probability = None
+        full_path = None
+        bbox_xywh = None
+        clamp_bbox_xyxy = None
+        crop_box_xywh = None
+
         if (self.load_masks or self.box_crop) and entry.mask is not None:
             full_path = os.path.join(self.dataset_root, entry.mask.path)
             mask = _load_mask(self._local_path(full_path))
@@ -320,11 +380,22 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
             if self.box_crop:
-                clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
+                clamp_bbox_xyxy = _get_clamp_bbox(
+                    bbox_xywh,
+                    image_path=entry.image.path,
+                    box_crop_context=self.box_crop_context,
+                )
                 mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
+                crop_box_xyxy = _clamp_box_to_image_bounds_and_round(
+                    clamp_bbox_xyxy,
+                    image_size_hw=tuple(mask.shape[-2:]),
+                )
+                crop_box_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy)
 
             fg_probability, _, _ = self._resize_image(mask, mode="nearest")
 
-        return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy
+        return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
 
     def _load_crop_images(
         self,
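
A sketch (illustrative box and image size) of how `crop_bbox_xywh` relates to `bbox_xywh` via context padding and clamping, using the module helpers added in this diff:

    import torch
    from pytorch3d.implicitron.dataset.json_index_dataset import (
        _bbox_xyxy_to_xywh,
        _clamp_box_to_image_bounds_and_round,
        _get_clamp_bbox,
    )

    bbox_xywh = torch.tensor([40.0, 30.0, 100.0, 80.0])  # tight box from the mask
    clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, box_crop_context=0.3)  # padded, float xyxy
    crop_box_xyxy = _clamp_box_to_image_bounds_and_round(
        clamp_bbox_xyxy, image_size_hw=(200, 300)  # clip to the source image extent
    )
    crop_bbox_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy)  # what ends up in FrameData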
@@ -746,7 +817,7 @@ def _load_mask(path) -> np.ndarray:
 def _get_1d_bounds(arr) -> Tuple[int, int]:
     nz = np.flatnonzero(arr)
-    return nz[0], nz[-1]
+    return nz[0], nz[-1] + 1
 
 
 def _get_bbox_from_mask(
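
The `+ 1` makes the returned upper bound exclusive, in line with the revised bbox convention; a tiny self-contained example:

    import numpy as np

    arr = np.array([0, 0, 1, 1, 0, 0])
    nz = np.flatnonzero(arr)
    x0, x1 = nz[0], nz[-1] + 1  # (2, 4): the run covers columns 2..3, width x1 - x0 == 2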
@@ -767,11 +838,15 @@ def _get_bbox_from_mask(
 def _get_clamp_bbox(
-    bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
+    bbox: torch.Tensor,
+    box_crop_context: float = 0.0,
+    image_path: str = "",
 ) -> torch.Tensor:
     # box_crop_context: rate of expansion for bbox
     # returns possibly expanded bbox xyxy as float
+
+    bbox = bbox.clone()  # do not edit bbox in place
+
     # increase box size
     if box_crop_context > 0.0:
         c = box_crop_context
@@ -783,27 +858,37 @@ def _get_clamp_bbox(
     if (bbox[2:] <= 1.0).any():
         raise ValueError(
-            f"squashed image {impath}!! The bounding box contains no pixels."
+            f"squashed image {image_path}!! The bounding box contains no pixels."
         )
 
-    bbox[2:] = torch.clamp(bbox[2:], 2)
-    bbox[2:] += bbox[0:2] + 1  # convert to [xmin, ymin, xmax, ymax]
-    # +1 because upper bound is not inclusive
-    return bbox
+    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
+    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
+
+    return bbox_xyxy
 
 
 def _crop_around_box(tensor, bbox, impath: str = ""):
     # bbox is xyxy, where the upper bound is corrected with +1
-    bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
-    bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
-    bbox = bbox.round().long()
+    bbox = _clamp_box_to_image_bounds_and_round(
+        bbox,
+        image_size_hw=tensor.shape[-2:],
+    )
     tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
     assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
     return tensor
 
 
+def _clamp_box_to_image_bounds_and_round(
+    bbox_xyxy: torch.Tensor,
+    image_size_hw: Tuple[int, int],
+) -> torch.LongTensor:
+    bbox_xyxy = bbox_xyxy.clone()
+    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0.0, image_size_hw[-1])
+    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0.0, image_size_hw[-2])
+    bbox_xyxy = bbox_xyxy.round().long()
+    return bbox_xyxy  # pyre-ignore [7]
+
+
 def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
     assert bbox is not None
     assert np.prod(orig_res) > 1e-8
@@ -812,6 +897,22 @@ def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
     return bbox * rel_size
 
 
+def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
+    wh = xyxy[2:] - xyxy[:2]
+    xywh = torch.cat([xyxy[:2], wh])
+    return xywh
+
+
+def _bbox_xywh_to_xyxy(
+    xywh: torch.Tensor, clamp_size: Optional[int] = None
+) -> torch.Tensor:
+    xyxy = xywh.clone()
+    if clamp_size is not None:
+        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
+    xyxy[2:] += xyxy[:2]
+    return xyxy
+
+
 def _safe_as_tensor(data, dtype):
     if data is None:
         return None


@@ -142,6 +142,7 @@ def render_point_cloud_pytorch3d(
         rendered_blob,
         size=tuple(render_size),
         mode="bilinear",
+        align_corners=False,
     )
 
     data_rendered, depth_rendered, render_mask = rendered_blob.split(
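
Passing `align_corners=False` keeps the existing bilinear behaviour but makes it explicit; a minimal reproduction of the call (input shape is illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 5, 64, 64)
    # Without align_corners, PyTorch emits a UserWarning for mode="bilinear";
    # align_corners=False matches the default behaviour and silences it.
    y = F.interpolate(x, size=(128, 128), mode="bilinear", align_corners=False)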


@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np
import torch
from pytorch3d.implicitron.dataset.json_index_dataset import (
    _bbox_xywh_to_xyxy,
    _bbox_xyxy_to_xywh,
    _get_bbox_from_mask,
)

from tests.common_testing import TestCaseMixin


class TestBBox(TestCaseMixin, unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_bbox_conversion(self):
        bbox_xywh_list = torch.LongTensor(
            [
                [0, 0, 10, 20],
                [10, 20, 5, 1],
                [10, 20, 1, 1],
                [5, 4, 0, 1],
            ]
        )
        for bbox_xywh in bbox_xywh_list:
            bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh)
            bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy)
            bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_)
            self.assertClose(bbox_xywh_, bbox_xywh)
            self.assertClose(bbox_xyxy, bbox_xyxy_)

    def test_compare_to_expected(self):
        bbox_xywh_to_xyxy_expected = torch.LongTensor(
            [
                [[0, 0, 10, 20], [0, 0, 10, 20]],
                [[10, 20, 5, 1], [10, 20, 15, 21]],
                [[10, 20, 1, 1], [10, 20, 11, 21]],
                [[5, 4, 0, 1], [5, 4, 5, 5]],
            ]
        )
        for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected:
            self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
            self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)

        clamp_amnt = 3
        bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor(
            [
                [[0, 0, 10, 20], [0, 0, 10, 20]],
                [[10, 20, 5, 1], [10, 20, 15, 20 + clamp_amnt]],
                [[10, 20, 1, 1], [10, 20, 10 + clamp_amnt, 20 + clamp_amnt]],
                [[5, 4, 0, 1], [5, 4, 5 + clamp_amnt, 4 + clamp_amnt]],
            ]
        )
        for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected:
            self.assertClose(
                _bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
                bbox_xyxy_expected,
            )

    def test_mask_to_bbox(self):
        mask = np.array(
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ]
        ).astype(np.float32)
        expected_bbox_xywh = [2, 1, 2, 1]
        bbox_xywh = _get_bbox_from_mask(mask, 0.5)
        self.assertClose(bbox_xywh, expected_bbox_xywh)


@@ -14,6 +14,7 @@ import torch
 import torchvision
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
+from pytorch3d.implicitron.tools.config import expand_args_fields
 from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
 from pytorch3d.vis.plotly_vis import plot_scene
@@ -37,6 +38,7 @@ class TestDatasetVisualize(unittest.TestCase):
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
         self.image_size = 256
+        expand_args_fields(JsonIndexDataset)
         self.datasets = {
             "simple": JsonIndexDataset(
                 frame_annotations_file=frame_file,
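
`expand_args_fields` turns the configurable `JsonIndexDataset` class into a dataclass so it can be constructed directly; a hedged sketch of the pattern the test now follows (paths are placeholders):

    from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # The implicitron config system expects this call before direct instantiation.
    expand_args_fields(JsonIndexDataset)
    dataset = JsonIndexDataset(
        frame_annotations_file="<.../frame_annotations.jgz>",
        sequence_annotations_file="<.../sequence_annotations.jgz>",
        dataset_root="<CO3D dataset root>",
    )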