From 4300030d7af3d982eb95b6a9a565475ead9f7810 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Sat, 9 Jul 2022 17:16:24 -0700 Subject: [PATCH] Changes for CO3Dv2 release [part1] Summary: Implements several changes needed for the CO3Dv2 release: - FrameData contains crop_bbox_xywh which defines the outline of the image crop corresponding to the image-shaped tensors in FrameData - revised the definition of a bounding box inside JsonDatasetIndex: bbox_xyxy is [xmin, ymin, xmax, ymax], where xmax, ymax are not inclusive; bbox_xywh = [xmin, ymain, xmax-xmin, ymax-ymin] - is_filtered for detecting whether the entries of the dataset were somehow filtered - seq_frame_index_to_dataset_index allows to skip entries that are not present in the dataset Reviewed By: shapovalov Differential Revision: D37687547 fbshipit-source-id: 7842756b0517878cc0964fc0935d3c0769454d78 --- pytorch3d/implicitron/dataset/dataset_base.py | 13 +- .../implicitron/dataset/json_index_dataset.py | 163 ++++++++++++++---- .../implicitron/tools/point_cloud_utils.py | 1 + tests/implicitron/test_bbox.py | 78 +++++++++ tests/implicitron/test_dataset_visualize.py | 2 + 5 files changed, 224 insertions(+), 33 deletions(-) create mode 100644 tests/implicitron/test_bbox.py diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index bd5e5c7e..f3d8d615 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -61,8 +61,16 @@ class FrameData(Mapping[str, Any]): fg_probability: A Tensor of `(1, H, W)` denoting the probability of the pixels belonging to the captured object; elements are floats in [0, 1]. - bbox_xywh: The bounding box capturing the object in the - format (x0, y0, width, height). + bbox_xywh: The bounding box tightly enclosing the foreground object in the + format (x0, y0, width, height). The convention assumes that + `x0+width` and `y0+height` includes the boundary of the box. + I.e., to slice out the corresponding crop from an image tensor `I` + we execute `crop = I[..., y0:y0+height, x0:x0+width]` + crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` + in the original image coordinates in the format (x0, y0, width, height). + The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs + from `bbox_xywh` due to padding (which can happen e.g. due to + setting `JsonIndexDataset.box_crop_context > 0`) camera: A PyTorch3D camera object corresponding the frame's viewpoint, corrected for cropping if it happened. camera_quality_score: The score proportional to the confidence of the @@ -98,6 +106,7 @@ class FrameData(Mapping[str, Any]): mask_path: Union[str, List[str], None] = None fg_probability: Optional[torch.Tensor] = None bbox_xywh: Optional[torch.Tensor] = None + crop_bbox_xywh: Optional[torch.Tensor] = None camera: Optional[PerspectiveCameras] = None camera_quality_score: Optional[torch.Tensor] = None point_cloud_quality_score: Optional[torch.Tensor] = None diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 139abbd6..17ba4cec 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -160,17 +160,50 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): self._filter_db() # also computes sequence indices logger.info(str(self)) + def is_filtered(self): + """ + Returns `True` in case the dataset has been filtered and thus some frame annotations + stored on the disk might be missing in the dataset object. + + Returns: + is_filtered: `True` if the dataset has been filtered, else `False`. + """ + return ( + self.remove_empty_masks + or self.limit_to > 0 + or self.limit_sequences_to > 0 + or len(self.pick_sequence) > 0 + or len(self.exclude_sequence) > 0 + or len(self.limit_category_to) > 0 + or self.n_frames_per_sequence > 0 + ) + def seq_frame_index_to_dataset_index( self, - seq_frame_index: Union[ - List[List[Union[Tuple[str, int, str], Tuple[str, int]]]], - ], - ) -> List[List[int]]: + seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]], + allow_missing_indices: bool = False, + remove_missing_indices: bool = False, + ) -> List[List[Union[Optional[int], int]]]: """ - Obtain indices into the dataset object given a list of frames specified as - `seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`. + Obtain indices into the dataset object given a list of frame ids. + + Args: + seq_frame_index: The list of frame ids specified as + `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally, + Image paths relative to the dataset_root can be stored specified as well: + `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]` + allow_missing_indices: If `False`, throws an IndexError upon reaching the first + entry from `seq_frame_index` which is missing in the dataset. + Otherwise, depending on `remove_missing_indices`, either returns `None` + in place of missing entries or removes the indices of missing entries. + remove_missing_indices: Active when `allow_missing_indices=True`. + If `False`, returns `None` in place of `seq_frame_index` entries that + are not present in the dataset. + If `True` removes missing indices from the returned indices. + + Returns: + dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`. """ - # TODO: check the frame numbers are unique _dataset_seq_frame_n_index = { seq: { # pyre-ignore[16] @@ -181,8 +214,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): for seq, seq_idx in self._seq_to_idx.items() } - def _get_batch_idx(seq_name, frame_no, path=None) -> int: - idx = _dataset_seq_frame_n_index[seq_name][frame_no] + def _get_dataset_idx( + seq_name: str, frame_no: int, path: Optional[str] = None + ) -> Optional[int]: + idx_seq = _dataset_seq_frame_n_index.get(seq_name, None) + idx = idx_seq.get(frame_no, None) if idx_seq is not None else None + if idx is None: + msg = ( + f"sequence_name={seq_name} / frame_number={frame_no}" + " not in the dataset!" + ) + if not allow_missing_indices: + raise IndexError(msg) + warnings.warn(msg) + return idx if path is not None: # Check that the loaded frame path is consistent # with the one stored in self.frame_annots. @@ -191,11 +236,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): self.frame_annots[idx]["frame_annotation"].image.path ) == os.path.normpath( path - ), f"Inconsistent batch {seq_name, frame_no, path}." + ), f"Inconsistent frame indices {seq_name, frame_no, path}." return idx - batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index] - return batches_idx + dataset_idx = [ + [_get_dataset_idx(*b) for b in batch] # pyre-ignore [6] + for batch in seq_frame_index + ] + + if allow_missing_indices and remove_missing_indices: + # remove all None indices, and also batches with only None entries + valid_dataset_idx = [ + [b for b in batch if b is not None] for batch in dataset_idx + ] + return [ # pyre-ignore[7] + batch for batch in valid_dataset_idx if len(batch) > 0 + ] + + return dataset_idx def __str__(self) -> str: # pyre-ignore[16] @@ -254,6 +312,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): frame_data.mask_path, frame_data.bbox_xywh, clamp_bbox_xyxy, + frame_data.crop_bbox_xywh, ) = self._load_crop_fg_probability(entry) scale = 1.0 @@ -301,13 +360,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): Optional[str], Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], ]: - fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = ( - None, - None, - None, - None, - ) + fg_probability = None + full_path = None + bbox_xywh = None + clamp_bbox_xyxy = None + crop_box_xywh = None + if (self.load_masks or self.box_crop) and entry.mask is not None: full_path = os.path.join(self.dataset_root, entry.mask.path) mask = _load_mask(self._local_path(full_path)) @@ -320,11 +380,22 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) if self.box_crop: - clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context) + clamp_bbox_xyxy = _get_clamp_bbox( + bbox_xywh, + image_path=entry.image.path, + box_crop_context=self.box_crop_context, + ) mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) + crop_box_xyxy = _clamp_box_to_image_bounds_and_round( + clamp_bbox_xyxy, + image_size_hw=tuple(mask.shape[-2:]), + ) + crop_box_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy) + fg_probability, _, _ = self._resize_image(mask, mode="nearest") - return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy + + return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh def _load_crop_images( self, @@ -746,7 +817,7 @@ def _load_mask(path) -> np.ndarray: def _get_1d_bounds(arr) -> Tuple[int, int]: nz = np.flatnonzero(arr) - return nz[0], nz[-1] + return nz[0], nz[-1] + 1 def _get_bbox_from_mask( @@ -767,11 +838,15 @@ def _get_bbox_from_mask( def _get_clamp_bbox( - bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = "" + bbox: torch.Tensor, + box_crop_context: float = 0.0, + image_path: str = "", ) -> torch.Tensor: # box_crop_context: rate of expansion for bbox # returns possibly expanded bbox xyxy as float + bbox = bbox.clone() # do not edit bbox in place + # increase box size if box_crop_context > 0.0: c = box_crop_context @@ -783,27 +858,37 @@ def _get_clamp_bbox( if (bbox[2:] <= 1.0).any(): raise ValueError( - f"squashed image {impath}!! The bounding box contains no pixels." + f"squashed image {image_path}!! The bounding box contains no pixels." ) - bbox[2:] = torch.clamp(bbox[2:], 2) - bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax] - # +1 because upper bound is not inclusive + bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes + bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) - return bbox + return bbox_xyxy def _crop_around_box(tensor, bbox, impath: str = ""): # bbox is xyxy, where the upper bound is corrected with +1 - bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1]) - bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2]) - bbox = bbox.round().long() + bbox = _clamp_box_to_image_bounds_and_round( + bbox, + image_size_hw=tensor.shape[-2:], + ) tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" - return tensor +def _clamp_box_to_image_bounds_and_round( + bbox_xyxy: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> torch.LongTensor: + bbox_xyxy = bbox_xyxy.clone() + bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0.0, image_size_hw[-1]) + bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0.0, image_size_hw[-2]) + bbox_xyxy = bbox_xyxy.round().long() + return bbox_xyxy # pyre-ignore [7] + + def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: assert bbox is not None assert np.prod(orig_res) > 1e-8 @@ -812,6 +897,22 @@ def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: return bbox * rel_size +def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: + wh = xyxy[2:] - xyxy[:2] + xywh = torch.cat([xyxy[:2], wh]) + return xywh + + +def _bbox_xywh_to_xyxy( + xywh: torch.Tensor, clamp_size: Optional[int] = None +) -> torch.Tensor: + xyxy = xywh.clone() + if clamp_size is not None: + xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) + xyxy[2:] += xyxy[:2] + return xyxy + + def _safe_as_tensor(data, dtype): if data is None: return None diff --git a/pytorch3d/implicitron/tools/point_cloud_utils.py b/pytorch3d/implicitron/tools/point_cloud_utils.py index edcb16f3..f6533d42 100644 --- a/pytorch3d/implicitron/tools/point_cloud_utils.py +++ b/pytorch3d/implicitron/tools/point_cloud_utils.py @@ -142,6 +142,7 @@ def render_point_cloud_pytorch3d( rendered_blob, size=tuple(render_size), mode="bilinear", + align_corners=False, ) data_rendered, depth_rendered, render_mask = rendered_blob.split( diff --git a/tests/implicitron/test_bbox.py b/tests/implicitron/test_bbox.py new file mode 100644 index 00000000..999dfc92 --- /dev/null +++ b/tests/implicitron/test_bbox.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import numpy as np + +import torch +from pytorch3d.implicitron.dataset.json_index_dataset import ( + _bbox_xywh_to_xyxy, + _bbox_xyxy_to_xywh, + _get_bbox_from_mask, +) +from tests.common_testing import TestCaseMixin + + +class TestBBox(TestCaseMixin, unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + def test_bbox_conversion(self): + bbox_xywh_list = torch.LongTensor( + [ + [0, 0, 10, 20], + [10, 20, 5, 1], + [10, 20, 1, 1], + [5, 4, 0, 1], + ] + ) + for bbox_xywh in bbox_xywh_list: + bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh) + bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy) + bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_) + self.assertClose(bbox_xywh_, bbox_xywh) + self.assertClose(bbox_xyxy, bbox_xyxy_) + + def test_compare_to_expected(self): + bbox_xywh_to_xyxy_expected = torch.LongTensor( + [ + [[0, 0, 10, 20], [0, 0, 10, 20]], + [[10, 20, 5, 1], [10, 20, 15, 21]], + [[10, 20, 1, 1], [10, 20, 11, 21]], + [[5, 4, 0, 1], [5, 4, 5, 5]], + ] + ) + for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected: + self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected) + self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh) + + clamp_amnt = 3 + bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor( + [ + [[0, 0, 10, 20], [0, 0, 10, 20]], + [[10, 20, 5, 1], [10, 20, 15, 20 + clamp_amnt]], + [[10, 20, 1, 1], [10, 20, 10 + clamp_amnt, 20 + clamp_amnt]], + [[5, 4, 0, 1], [5, 4, 5 + clamp_amnt, 4 + clamp_amnt]], + ] + ) + for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected: + self.assertClose( + _bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt), + bbox_xyxy_expected, + ) + + def test_mask_to_bbox(self): + mask = np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32) + expected_bbox_xywh = [2, 1, 2, 1] + bbox_xywh = _get_bbox_from_mask(mask, 0.5) + self.assertClose(bbox_xywh, expected_bbox_xywh) diff --git a/tests/implicitron/test_dataset_visualize.py b/tests/implicitron/test_dataset_visualize.py index 2eb4d5c5..c7b66d8b 100644 --- a/tests/implicitron/test_dataset_visualize.py +++ b/tests/implicitron/test_dataset_visualize.py @@ -14,6 +14,7 @@ import torch import torchvision from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud +from pytorch3d.implicitron.tools.config import expand_args_fields from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d from pytorch3d.vis.plotly_vis import plot_scene @@ -37,6 +38,7 @@ class TestDatasetVisualize(unittest.TestCase): frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") self.image_size = 256 + expand_args_fields(JsonIndexDataset) self.datasets = { "simple": JsonIndexDataset( frame_annotations_file=frame_file,