mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Loading fg probability from the alpha channel of image_rgb
Summary: It is often easier to store the mask together with the RGB image, especially for renders. The logic in this diff:
* if `load_masks` is set and `mask_path` is provided, take the mask from `mask_path`;
* otherwise, check whether the image has an alpha channel and, if so, use it as the mask.

Reviewed By: antoinetlc
Differential Revision: D68160212
fbshipit-source-id: d9b6779f90027a4987ba96800983f441edff9c74
This commit is contained in:
		
							parent
							
								
									89b851e64c
								
							
						
					
					
						commit
						49cf5a0f37
					
				@ -589,7 +589,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        fg_mask_np: Optional[np.ndarray] = None
 | 
			
		||||
        fg_mask_np: np.ndarray | None = None
 | 
			
		||||
        bbox_xywh: tuple[float, float, float, float] | None = None
 | 
			
		||||
        mask_annotation = frame_annotation.mask
 | 
			
		||||
        if mask_annotation is not None:
 | 
			
		||||
            if load_blobs and self.load_masks:
 | 
			
		||||
@ -598,10 +599,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
                frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
 | 
			
		||||
 | 
			
		||||
            bbox_xywh = mask_annotation.bounding_box_xywh
 | 
			
		||||
            if bbox_xywh is None and fg_mask_np is not None:
 | 
			
		||||
                bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
 | 
			
		||||
 | 
			
		||||
            frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
 | 
			
		||||
 | 
			
		||||
        if frame_annotation.image is not None:
 | 
			
		||||
            image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
 | 
			
		||||
@ -618,11 +615,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
                if image_path is None:
 | 
			
		||||
                    raise ValueError("Image path is required to load images.")
 | 
			
		||||
 | 
			
		||||
                image_np = load_image(self._local_path(image_path))
 | 
			
		||||
                no_mask = fg_mask_np is None  # didn’t read the mask file
 | 
			
		||||
                image_np = load_image(
 | 
			
		||||
                    self._local_path(image_path), try_read_alpha=no_mask
 | 
			
		||||
                )
 | 
			
		||||
                if image_np.shape[0] == 4:  # RGBA image
 | 
			
		||||
                    if no_mask:
 | 
			
		||||
                        fg_mask_np = image_np[3:]
 | 
			
		||||
                        frame_data.fg_probability = safe_as_tensor(
 | 
			
		||||
                            fg_mask_np, torch.float
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    image_np = image_np[:3]
 | 
			
		||||
 | 
			
		||||
                frame_data.image_rgb = self._postprocess_image(
 | 
			
		||||
                    image_np, frame_annotation.image.size, frame_data.fg_probability
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        if bbox_xywh is None and fg_mask_np is not None:
 | 
			
		||||
            bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
 | 
			
		||||
        frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
 | 
			
		||||
 | 
			
		||||
        depth_annotation = frame_annotation.depth
 | 
			
		||||
        if (
 | 
			
		||||
            load_blobs
 | 
			
		||||
 | 
			
		||||
@ -87,6 +87,15 @@ def is_train_frame(
 | 
			
		||||
def get_bbox_from_mask(
 | 
			
		||||
    mask: np.ndarray, thr: float, decrease_quant: float = 0.05
 | 
			
		||||
) -> Tuple[int, int, int, int]:
 | 
			
		||||
    # these corner cases need to be handled in order to avoid an infinite loop
 | 
			
		||||
    if mask.size == 0:
 | 
			
		||||
        warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
 | 
			
		||||
        return 0, 0, 1, 1
 | 
			
		||||
 | 
			
		||||
    if not mask.min() >= 0.0:
 | 
			
		||||
        warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
 | 
			
		||||
        mask = mask.clip(min=0.0)
 | 
			
		||||
 | 
			
		||||
    # bbox in xywh
 | 
			
		||||
    masks_for_box = np.zeros_like(mask)
 | 
			
		||||
    while masks_for_box.sum() <= 1.0:
 | 
			
		||||
@ -229,9 +238,20 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
 | 
			
		||||
    return im.astype(np.float32) / 255.0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_image(path: str) -> np.ndarray:
 | 
			
		||||
def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
 | 
			
		||||
    """
 | 
			
		||||
    Load an image from a path and return it as a numpy array.
 | 
			
		||||
    If try_read_alpha is True, the image is read as RGBA and the alpha channel is
 | 
			
		||||
    returned as the fourth channel.
 | 
			
		||||
    Otherwise, the image is read as RGB and a three-channel image is returned.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    with Image.open(path) as pil_im:
 | 
			
		||||
        im = np.array(pil_im.convert("RGB"))
 | 
			
		||||
        # Check if the image has an alpha channel
 | 
			
		||||
        if try_read_alpha and pil_im.mode == "RGBA":
 | 
			
		||||
            im = np.array(pil_im)
 | 
			
		||||
        else:
 | 
			
		||||
            im = np.array(pil_im.convert("RGB"))
 | 
			
		||||
 | 
			
		||||
    return transpose_normalize_image(im)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user