mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Changes for CO3Dv2 release [part1]
Summary: Implements several changes needed for the CO3Dv2 release: - FrameData contains crop_bbox_xywh which defines the outline of the image crop corresponding to the image-shaped tensors in FrameData - revised the definition of a bounding box inside JsonDatasetIndex: bbox_xyxy is [xmin, ymin, xmax, ymax], where xmax, ymax are not inclusive; bbox_xywh = [xmin, ymain, xmax-xmin, ymax-ymin] - is_filtered for detecting whether the entries of the dataset were somehow filtered - seq_frame_index_to_dataset_index allows to skip entries that are not present in the dataset Reviewed By: shapovalov Differential Revision: D37687547 fbshipit-source-id: 7842756b0517878cc0964fc0935d3c0769454d78
This commit is contained in:
parent
00acf0b0c7
commit
4300030d7a
@ -61,8 +61,16 @@ class FrameData(Mapping[str, Any]):
|
||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||
pixels belonging to the captured object; elements are floats
|
||||
in [0, 1].
|
||||
bbox_xywh: The bounding box capturing the object in the
|
||||
format (x0, y0, width, height).
|
||||
bbox_xywh: The bounding box tightly enclosing the foreground object in the
|
||||
format (x0, y0, width, height). The convention assumes that
|
||||
`x0+width` and `y0+height` includes the boundary of the box.
|
||||
I.e., to slice out the corresponding crop from an image tensor `I`
|
||||
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
|
||||
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
|
||||
in the original image coordinates in the format (x0, y0, width, height).
|
||||
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
|
||||
from `bbox_xywh` due to padding (which can happen e.g. due to
|
||||
setting `JsonIndexDataset.box_crop_context > 0`)
|
||||
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
||||
corrected for cropping if it happened.
|
||||
camera_quality_score: The score proportional to the confidence of the
|
||||
@ -98,6 +106,7 @@ class FrameData(Mapping[str, Any]):
|
||||
mask_path: Union[str, List[str], None] = None
|
||||
fg_probability: Optional[torch.Tensor] = None
|
||||
bbox_xywh: Optional[torch.Tensor] = None
|
||||
crop_bbox_xywh: Optional[torch.Tensor] = None
|
||||
camera: Optional[PerspectiveCameras] = None
|
||||
camera_quality_score: Optional[torch.Tensor] = None
|
||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||
|
@ -160,17 +160,50 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self._filter_db() # also computes sequence indices
|
||||
logger.info(str(self))
|
||||
|
||||
def is_filtered(self):
|
||||
"""
|
||||
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
||||
stored on the disk might be missing in the dataset object.
|
||||
|
||||
Returns:
|
||||
is_filtered: `True` if the dataset has been filtered, else `False`.
|
||||
"""
|
||||
return (
|
||||
self.remove_empty_masks
|
||||
or self.limit_to > 0
|
||||
or self.limit_sequences_to > 0
|
||||
or len(self.pick_sequence) > 0
|
||||
or len(self.exclude_sequence) > 0
|
||||
or len(self.limit_category_to) > 0
|
||||
or self.n_frames_per_sequence > 0
|
||||
)
|
||||
|
||||
def seq_frame_index_to_dataset_index(
|
||||
self,
|
||||
seq_frame_index: Union[
|
||||
List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||
],
|
||||
) -> List[List[int]]:
|
||||
seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||
allow_missing_indices: bool = False,
|
||||
remove_missing_indices: bool = False,
|
||||
) -> List[List[Union[Optional[int], int]]]:
|
||||
"""
|
||||
Obtain indices into the dataset object given a list of frames specified as
|
||||
`seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
|
||||
Obtain indices into the dataset object given a list of frame ids.
|
||||
|
||||
Args:
|
||||
seq_frame_index: The list of frame ids specified as
|
||||
`List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
|
||||
Image paths relative to the dataset_root can be stored specified as well:
|
||||
`List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
|
||||
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
|
||||
entry from `seq_frame_index` which is missing in the dataset.
|
||||
Otherwise, depending on `remove_missing_indices`, either returns `None`
|
||||
in place of missing entries or removes the indices of missing entries.
|
||||
remove_missing_indices: Active when `allow_missing_indices=True`.
|
||||
If `False`, returns `None` in place of `seq_frame_index` entries that
|
||||
are not present in the dataset.
|
||||
If `True` removes missing indices from the returned indices.
|
||||
|
||||
Returns:
|
||||
dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`.
|
||||
"""
|
||||
# TODO: check the frame numbers are unique
|
||||
_dataset_seq_frame_n_index = {
|
||||
seq: {
|
||||
# pyre-ignore[16]
|
||||
@ -181,8 +214,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
for seq, seq_idx in self._seq_to_idx.items()
|
||||
}
|
||||
|
||||
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
|
||||
idx = _dataset_seq_frame_n_index[seq_name][frame_no]
|
||||
def _get_dataset_idx(
|
||||
seq_name: str, frame_no: int, path: Optional[str] = None
|
||||
) -> Optional[int]:
|
||||
idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
|
||||
idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
|
||||
if idx is None:
|
||||
msg = (
|
||||
f"sequence_name={seq_name} / frame_number={frame_no}"
|
||||
" not in the dataset!"
|
||||
)
|
||||
if not allow_missing_indices:
|
||||
raise IndexError(msg)
|
||||
warnings.warn(msg)
|
||||
return idx
|
||||
if path is not None:
|
||||
# Check that the loaded frame path is consistent
|
||||
# with the one stored in self.frame_annots.
|
||||
@ -191,11 +236,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self.frame_annots[idx]["frame_annotation"].image.path
|
||||
) == os.path.normpath(
|
||||
path
|
||||
), f"Inconsistent batch {seq_name, frame_no, path}."
|
||||
), f"Inconsistent frame indices {seq_name, frame_no, path}."
|
||||
return idx
|
||||
|
||||
batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
|
||||
return batches_idx
|
||||
dataset_idx = [
|
||||
[_get_dataset_idx(*b) for b in batch] # pyre-ignore [6]
|
||||
for batch in seq_frame_index
|
||||
]
|
||||
|
||||
if allow_missing_indices and remove_missing_indices:
|
||||
# remove all None indices, and also batches with only None entries
|
||||
valid_dataset_idx = [
|
||||
[b for b in batch if b is not None] for batch in dataset_idx
|
||||
]
|
||||
return [ # pyre-ignore[7]
|
||||
batch for batch in valid_dataset_idx if len(batch) > 0
|
||||
]
|
||||
|
||||
return dataset_idx
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pyre-ignore[16]
|
||||
@ -254,6 +312,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
frame_data.mask_path,
|
||||
frame_data.bbox_xywh,
|
||||
clamp_bbox_xyxy,
|
||||
frame_data.crop_bbox_xywh,
|
||||
) = self._load_crop_fg_probability(entry)
|
||||
|
||||
scale = 1.0
|
||||
@ -301,13 +360,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
Optional[str],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
fg_probability = None
|
||||
full_path = None
|
||||
bbox_xywh = None
|
||||
clamp_bbox_xyxy = None
|
||||
crop_box_xywh = None
|
||||
|
||||
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
mask = _load_mask(self._local_path(full_path))
|
||||
@ -320,11 +380,22 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
||||
|
||||
if self.box_crop:
|
||||
clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
|
||||
clamp_bbox_xyxy = _get_clamp_bbox(
|
||||
bbox_xywh,
|
||||
image_path=entry.image.path,
|
||||
box_crop_context=self.box_crop_context,
|
||||
)
|
||||
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
||||
|
||||
crop_box_xyxy = _clamp_box_to_image_bounds_and_round(
|
||||
clamp_bbox_xyxy,
|
||||
image_size_hw=tuple(mask.shape[-2:]),
|
||||
)
|
||||
crop_box_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy)
|
||||
|
||||
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
||||
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy
|
||||
|
||||
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
|
||||
|
||||
def _load_crop_images(
|
||||
self,
|
||||
@ -746,7 +817,7 @@ def _load_mask(path) -> np.ndarray:
|
||||
|
||||
def _get_1d_bounds(arr) -> Tuple[int, int]:
|
||||
nz = np.flatnonzero(arr)
|
||||
return nz[0], nz[-1]
|
||||
return nz[0], nz[-1] + 1
|
||||
|
||||
|
||||
def _get_bbox_from_mask(
|
||||
@ -767,11 +838,15 @@ def _get_bbox_from_mask(
|
||||
|
||||
|
||||
def _get_clamp_bbox(
|
||||
bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
|
||||
bbox: torch.Tensor,
|
||||
box_crop_context: float = 0.0,
|
||||
image_path: str = "",
|
||||
) -> torch.Tensor:
|
||||
# box_crop_context: rate of expansion for bbox
|
||||
# returns possibly expanded bbox xyxy as float
|
||||
|
||||
bbox = bbox.clone() # do not edit bbox in place
|
||||
|
||||
# increase box size
|
||||
if box_crop_context > 0.0:
|
||||
c = box_crop_context
|
||||
@ -783,27 +858,37 @@ def _get_clamp_bbox(
|
||||
|
||||
if (bbox[2:] <= 1.0).any():
|
||||
raise ValueError(
|
||||
f"squashed image {impath}!! The bounding box contains no pixels."
|
||||
f"squashed image {image_path}!! The bounding box contains no pixels."
|
||||
)
|
||||
|
||||
bbox[2:] = torch.clamp(bbox[2:], 2)
|
||||
bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax]
|
||||
# +1 because upper bound is not inclusive
|
||||
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
|
||||
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
|
||||
|
||||
return bbox
|
||||
return bbox_xyxy
|
||||
|
||||
|
||||
def _crop_around_box(tensor, bbox, impath: str = ""):
|
||||
# bbox is xyxy, where the upper bound is corrected with +1
|
||||
bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
|
||||
bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
|
||||
bbox = bbox.round().long()
|
||||
bbox = _clamp_box_to_image_bounds_and_round(
|
||||
bbox,
|
||||
image_size_hw=tensor.shape[-2:],
|
||||
)
|
||||
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
||||
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _clamp_box_to_image_bounds_and_round(
|
||||
bbox_xyxy: torch.Tensor,
|
||||
image_size_hw: Tuple[int, int],
|
||||
) -> torch.LongTensor:
|
||||
bbox_xyxy = bbox_xyxy.clone()
|
||||
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0.0, image_size_hw[-1])
|
||||
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0.0, image_size_hw[-2])
|
||||
bbox_xyxy = bbox_xyxy.round().long()
|
||||
return bbox_xyxy # pyre-ignore [7]
|
||||
|
||||
|
||||
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
||||
assert bbox is not None
|
||||
assert np.prod(orig_res) > 1e-8
|
||||
@ -812,6 +897,22 @@ def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
||||
return bbox * rel_size
|
||||
|
||||
|
||||
def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
|
||||
wh = xyxy[2:] - xyxy[:2]
|
||||
xywh = torch.cat([xyxy[:2], wh])
|
||||
return xywh
|
||||
|
||||
|
||||
def _bbox_xywh_to_xyxy(
|
||||
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
xyxy = xywh.clone()
|
||||
if clamp_size is not None:
|
||||
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
||||
xyxy[2:] += xyxy[:2]
|
||||
return xyxy
|
||||
|
||||
|
||||
def _safe_as_tensor(data, dtype):
|
||||
if data is None:
|
||||
return None
|
||||
|
@ -142,6 +142,7 @@ def render_point_cloud_pytorch3d(
|
||||
rendered_blob,
|
||||
size=tuple(render_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
data_rendered, depth_rendered, render_mask = rendered_blob.split(
|
||||
|
78
tests/implicitron/test_bbox.py
Normal file
78
tests/implicitron/test_bbox.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import (
|
||||
_bbox_xywh_to_xyxy,
|
||||
_bbox_xyxy_to_xywh,
|
||||
_get_bbox_from_mask,
|
||||
)
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestBBox(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_bbox_conversion(self):
|
||||
bbox_xywh_list = torch.LongTensor(
|
||||
[
|
||||
[0, 0, 10, 20],
|
||||
[10, 20, 5, 1],
|
||||
[10, 20, 1, 1],
|
||||
[5, 4, 0, 1],
|
||||
]
|
||||
)
|
||||
for bbox_xywh in bbox_xywh_list:
|
||||
bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh)
|
||||
bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy)
|
||||
bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_)
|
||||
self.assertClose(bbox_xywh_, bbox_xywh)
|
||||
self.assertClose(bbox_xyxy, bbox_xyxy_)
|
||||
|
||||
def test_compare_to_expected(self):
|
||||
bbox_xywh_to_xyxy_expected = torch.LongTensor(
|
||||
[
|
||||
[[0, 0, 10, 20], [0, 0, 10, 20]],
|
||||
[[10, 20, 5, 1], [10, 20, 15, 21]],
|
||||
[[10, 20, 1, 1], [10, 20, 11, 21]],
|
||||
[[5, 4, 0, 1], [5, 4, 5, 5]],
|
||||
]
|
||||
)
|
||||
for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected:
|
||||
self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
|
||||
self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
|
||||
|
||||
clamp_amnt = 3
|
||||
bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor(
|
||||
[
|
||||
[[0, 0, 10, 20], [0, 0, 10, 20]],
|
||||
[[10, 20, 5, 1], [10, 20, 15, 20 + clamp_amnt]],
|
||||
[[10, 20, 1, 1], [10, 20, 10 + clamp_amnt, 20 + clamp_amnt]],
|
||||
[[5, 4, 0, 1], [5, 4, 5 + clamp_amnt, 4 + clamp_amnt]],
|
||||
]
|
||||
)
|
||||
for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected:
|
||||
self.assertClose(
|
||||
_bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
|
||||
bbox_xyxy_expected,
|
||||
)
|
||||
|
||||
def test_mask_to_bbox(self):
|
||||
mask = np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
]
|
||||
).astype(np.float32)
|
||||
expected_bbox_xywh = [2, 1, 2, 1]
|
||||
bbox_xywh = _get_bbox_from_mask(mask, 0.5)
|
||||
self.assertClose(bbox_xywh, expected_bbox_xywh)
|
@ -14,6 +14,7 @@ import torch
|
||||
import torchvision
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
|
||||
from pytorch3d.vis.plotly_vis import plot_scene
|
||||
|
||||
@ -37,6 +38,7 @@ class TestDatasetVisualize(unittest.TestCase):
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
self.image_size = 256
|
||||
expand_args_fields(JsonIndexDataset)
|
||||
self.datasets = {
|
||||
"simple": JsonIndexDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
|
Loading…
x
Reference in New Issue
Block a user