Loading fg probability from the alpha channel of image_rgb

Summary:
It is often easier to store the mask together with RGB, especially for renders. The logic in this diff:
* if load_mask and mask_path provided, take the mask from mask_path,
* otherwise, check if the image has the alpha channel and take it as a mask.

Reviewed By: antoinetlc

Differential Revision: D68160212

fbshipit-source-id: d9b6779f90027a4987ba96800983f441edff9c74
This commit is contained in:
Roman Shapovalov 2025-01-15 11:53:30 -08:00 committed by Facebook GitHub Bot
parent 89b851e64c
commit 49cf5a0f37
2 changed files with 41 additions and 8 deletions

View File

@ -589,7 +589,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
),
)
fg_mask_np: Optional[np.ndarray] = None
fg_mask_np: np.ndarray | None = None
bbox_xywh: tuple[float, float, float, float] | None = None
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
if load_blobs and self.load_masks:
@ -598,10 +599,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@ -618,11 +615,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
if image_path is None:
raise ValueError("Image path is required to load images.")
image_np = load_image(self._local_path(image_path))
no_mask = fg_mask_np is None # didnt read the mask file
image_np = load_image(
self._local_path(image_path), try_read_alpha=no_mask
)
if image_np.shape[0] == 4: # RGBA image
if no_mask:
fg_mask_np = image_np[3:]
frame_data.fg_probability = safe_as_tensor(
fg_mask_np, torch.float
)
image_np = image_np[:3]
frame_data.image_rgb = self._postprocess_image(
image_np, frame_annotation.image.size, frame_data.fg_probability
)
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
depth_annotation = frame_annotation.depth
if (
load_blobs

View File

@ -87,6 +87,15 @@ def is_train_frame(
def get_bbox_from_mask(
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# these corner cases need to be handled in order to avoid an infinite loop
if mask.size == 0:
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
return 0, 0, 1, 1
if not mask.min() >= 0.0:
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
mask = mask.clip(min=0.0)
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
@ -229,9 +238,20 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
return im.astype(np.float32) / 255.0
def load_image(path: str) -> np.ndarray:
def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
"""
Load an image from a path and return it as a numpy array.
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
returned as the fourth channel.
Otherwise, the image is read as RGB and a three-channel image is returned.
"""
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
# Check if the image has an alpha channel
if try_read_alpha and pil_im.mode == "RGBA":
im = np.array(pil_im)
else:
im = np.array(pil_im.convert("RGB"))
return transpose_normalize_image(im)