mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Loading fg probability from the alpha channel of image_rgb
Summary: It is often easier to store the mask together with RGB, especially for renders. The logic in this diff: * if load_mask and mask_path provided, take the mask from mask_path, * otherwise, check if the image has the alpha channel and take it as a mask. Reviewed By: antoinetlc Differential Revision: D68160212 fbshipit-source-id: d9b6779f90027a4987ba96800983f441edff9c74
This commit is contained in:
parent
89b851e64c
commit
49cf5a0f37
@ -589,7 +589,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
),
|
||||
)
|
||||
|
||||
fg_mask_np: Optional[np.ndarray] = None
|
||||
fg_mask_np: np.ndarray | None = None
|
||||
bbox_xywh: tuple[float, float, float, float] | None = None
|
||||
mask_annotation = frame_annotation.mask
|
||||
if mask_annotation is not None:
|
||||
if load_blobs and self.load_masks:
|
||||
@ -598,10 +599,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
|
||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
if frame_annotation.image is not None:
|
||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||
@ -618,11 +615,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
if image_path is None:
|
||||
raise ValueError("Image path is required to load images.")
|
||||
|
||||
image_np = load_image(self._local_path(image_path))
|
||||
no_mask = fg_mask_np is None # didn’t read the mask file
|
||||
image_np = load_image(
|
||||
self._local_path(image_path), try_read_alpha=no_mask
|
||||
)
|
||||
if image_np.shape[0] == 4: # RGBA image
|
||||
if no_mask:
|
||||
fg_mask_np = image_np[3:]
|
||||
frame_data.fg_probability = safe_as_tensor(
|
||||
fg_mask_np, torch.float
|
||||
)
|
||||
|
||||
image_np = image_np[:3]
|
||||
|
||||
frame_data.image_rgb = self._postprocess_image(
|
||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||
)
|
||||
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
depth_annotation = frame_annotation.depth
|
||||
if (
|
||||
load_blobs
|
||||
|
@ -87,6 +87,15 @@ def is_train_frame(
|
||||
def get_bbox_from_mask(
|
||||
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
||||
) -> Tuple[int, int, int, int]:
|
||||
# these corner cases need to be handled in order to avoid an infinite loop
|
||||
if mask.size == 0:
|
||||
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
|
||||
return 0, 0, 1, 1
|
||||
|
||||
if not mask.min() >= 0.0:
|
||||
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
|
||||
mask = mask.clip(min=0.0)
|
||||
|
||||
# bbox in xywh
|
||||
masks_for_box = np.zeros_like(mask)
|
||||
while masks_for_box.sum() <= 1.0:
|
||||
@ -229,9 +238,20 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
||||
return im.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
def load_image(path: str) -> np.ndarray:
|
||||
def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Load an image from a path and return it as a numpy array.
|
||||
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
|
||||
returned as the fourth channel.
|
||||
Otherwise, the image is read as RGB and a three-channel image is returned.
|
||||
"""
|
||||
|
||||
with Image.open(path) as pil_im:
|
||||
im = np.array(pil_im.convert("RGB"))
|
||||
# Check if the image has an alpha channel
|
||||
if try_read_alpha and pil_im.mode == "RGBA":
|
||||
im = np.array(pil_im)
|
||||
else:
|
||||
im = np.array(pil_im.convert("RGB"))
|
||||
|
||||
return transpose_normalize_image(im)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user