mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
When bounding boxes are cached in metadata, don’t crash on load_masks=False
Summary: We currently support caching bounding boxes in MaskAnnotation. If present, they are not re-computed from the mask. However, the masks need to be loaded for the bbox to be set. This diff fixes that. Even if load_masks / load_blobs are unset, the bounding box can be picked up from the metadata. Reviewed By: bottler Differential Revision: D45144918 fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd
This commit is contained in:
parent
0e3138eca8
commit
7aeedd17a4
@ -555,12 +555,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
else None,
|
||||
)
|
||||
|
||||
if load_blobs and self.load_masks and frame_annotation.mask is not None:
|
||||
(
|
||||
frame_data.fg_probability,
|
||||
frame_data.mask_path,
|
||||
frame_data.bbox_xywh,
|
||||
) = self._load_fg_probability(frame_annotation)
|
||||
mask_annotation = frame_annotation.mask
|
||||
if mask_annotation is not None:
|
||||
fg_mask_np: Optional[np.ndarray] = None
|
||||
if load_blobs and self.load_masks:
|
||||
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
|
||||
frame_data.mask_path = mask_path
|
||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
|
||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
|
||||
|
||||
if frame_annotation.image is not None:
|
||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||
@ -604,25 +611,15 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def _load_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
|
||||
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
|
||||
fg_probability = load_mask(self._local_path(full_path))
|
||||
# we can use provided bbox_xywh or calculate it based on mask
|
||||
# saves time to skip bbox calculation
|
||||
# pyre-ignore
|
||||
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
|
||||
fg_probability, self.box_crop_mask_thr
|
||||
)
|
||||
if fg_probability.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
return (
|
||||
safe_as_tensor(fg_probability, torch.float),
|
||||
full_path,
|
||||
safe_as_tensor(bbox_xywh, torch.long),
|
||||
)
|
||||
|
||||
return fg_probability, full_path
|
||||
|
||||
def _load_images(
|
||||
self,
|
||||
|
@ -17,6 +17,7 @@ from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
|
||||
from pytorch3d.implicitron.dataset.utils import (
|
||||
get_bbox_from_mask,
|
||||
load_16big_png_depth,
|
||||
load_1bit_png_mask,
|
||||
load_depth,
|
||||
@ -107,11 +108,14 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
|
||||
|
||||
(
|
||||
self.frame_data.fg_probability,
|
||||
self.frame_data.mask_path,
|
||||
self.frame_data.bbox_xywh,
|
||||
) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
|
||||
fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
|
||||
self.frame_annotation
|
||||
)
|
||||
self.frame_data.mask_path = mask_path
|
||||
self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
mask_thr = self.frame_data_builder.box_crop_mask_thr
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
|
||||
self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
|
||||
|
||||
self.assertIsNotNone(self.frame_data.mask_path)
|
||||
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
|
||||
|
Loading…
x
Reference in New Issue
Block a user