When bounding boxes are cached in metadata, don’t crash on load_masks=False

Summary:
We currently support caching bounding boxes in MaskAnnotation. If present, they are not re-computed from the mask. However, the bbox is currently only set when the masks are actually loaded, so a cached bbox is ignored whenever load_masks or load_blobs is False.

This diff fixes that: even if load_masks / load_blobs are unset, the cached bounding box is picked up from the metadata. It is computed from the mask only when no cached value exists and the mask has been loaded.

Reviewed By: bottler

Differential Revision: D45144918

fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd
This commit is contained in:
Roman Shapovalov 2023-04-20 07:28:45 -07:00 committed by Facebook GitHub Bot
parent 0e3138eca8
commit 7aeedd17a4
2 changed files with 25 additions and 24 deletions

View File

@ -555,12 +555,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
else None, else None,
) )
if load_blobs and self.load_masks and frame_annotation.mask is not None: mask_annotation = frame_annotation.mask
( if mask_annotation is not None:
frame_data.fg_probability, fg_mask_np: Optional[np.ndarray] = None
frame_data.mask_path, if load_blobs and self.load_masks:
frame_data.bbox_xywh, fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
) = self._load_fg_probability(frame_annotation) frame_data.mask_path = mask_path
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
if frame_annotation.image is not None: if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long) image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@ -604,25 +611,15 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_fg_probability( def _load_fg_probability(
self, entry: types.FrameAnnotation self, entry: types.FrameAnnotation
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]: ) -> Tuple[np.ndarray, str]:
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
fg_probability = load_mask(self._local_path(full_path)) fg_probability = load_mask(self._local_path(full_path))
# we can use provided bbox_xywh or calculate it based on mask
# saves time to skip bbox calculation
# pyre-ignore
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
fg_probability, self.box_crop_mask_thr
)
if fg_probability.shape[-2:] != entry.image.size: if fg_probability.shape[-2:] != entry.image.size:
raise ValueError( raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
) )
return (
safe_as_tensor(fg_probability, torch.float), return fg_probability, full_path
full_path,
safe_as_tensor(bbox_xywh, torch.long),
)
def _load_images( def _load_images(
self, self,

View File

@ -17,6 +17,7 @@ from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import ( from pytorch3d.implicitron.dataset.utils import (
get_bbox_from_mask,
load_16big_png_depth, load_16big_png_depth,
load_1bit_png_mask, load_1bit_png_mask,
load_depth, load_depth,
@ -107,11 +108,14 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
) )
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
( fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
self.frame_data.fg_probability, self.frame_annotation
self.frame_data.mask_path, )
self.frame_data.bbox_xywh, self.frame_data.mask_path = mask_path
) = self.frame_data_builder._load_fg_probability(self.frame_annotation) self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
mask_thr = self.frame_data_builder.box_crop_mask_thr
bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
self.assertIsNotNone(self.frame_data.mask_path) self.assertIsNotNone(self.frame_data.mask_path)
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability)) self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))