From 872ff8c796e0947469fda76a028439e1dc8d3696 Mon Sep 17 00:00:00 2001
From: Amitav Baruah
Date: Mon, 14 Sep 2020 10:36:56 -0700
Subject: [PATCH] Add background color support to compositors

Summary: Support rendering different color backgrounds for pointclouds for both compositors

Reviewed By: nikhilaravi

Differential Revision: D23611043

fbshipit-source-id: ab029650d51349340372c5bd66700e6577d48851
---
 pytorch3d/renderer/points/compositor.py | 75 ++++++++++++++++++++++++-
 tests/test_render_points.py             | 52 +++++++++++++++++
 2 files changed, 125 insertions(+), 2 deletions(-)

diff --git a/pytorch3d/renderer/points/compositor.py b/pytorch3d/renderer/points/compositor.py
index fe07c2c1..6f6c274c 100644
--- a/pytorch3d/renderer/points/compositor.py
+++ b/pytorch3d/renderer/points/compositor.py
@@ -1,5 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

+import warnings
+from typing import List, Optional, Tuple, Union
+
 import torch
 import torch.nn as nn

@@ -16,11 +19,20 @@ class AlphaCompositor(nn.Module):
     Accumulate points using alpha compositing.
     """

-    def __init__(self):
+    def __init__(
+        self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
+    ):
         super().__init__()
+        self.background_color = background_color

     def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
+        background_color = kwargs.get("background_color", self.background_color)
         images = alpha_composite(fragments, alphas, ptclds)
+
+        # images are of shape (N, C, H, W)
+        # check for background color & feature size C (C=4 indicates rgba)
+        if background_color is not None and images.shape[1] == 4:
+            return _add_background_color_to_images(fragments, images, background_color)
         return images


@@ -29,9 +41,68 @@ class NormWeightedCompositor(nn.Module):
     Accumulate points using a normalized weighted sum.
     """

-    def __init__(self):
+    def __init__(
+        self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
+    ):
         super().__init__()
+        self.background_color = background_color

     def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
+        background_color = kwargs.get("background_color", self.background_color)
         images = norm_weighted_sum(fragments, alphas, ptclds)
+
+        # images are of shape (N, C, H, W)
+        # check for background color & feature size C (C=4 indicates rgba)
+        if background_color is not None and images.shape[1] == 4:
+            return _add_background_color_to_images(fragments, images, background_color)
         return images
+
+
+def _add_background_color_to_images(pix_idxs, images, background_color):
+    """
+    Mask pixels in images without corresponding points with a given background_color.
+
+    Args:
+        pix_idxs: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
+            giving the indices of the nearest points at each pixel, sorted in z-order.
+        images: Tensor of shape (N, 4, image_size, image_size) giving the
+            accumulated features at each point, where 4 refers to an rgba feature.
+        background_color: Tensor, list, or tuple with 3 or 4 values indicating the rgb/rgba
+            value for the new background. Values should be in the interval [0,1].
+    Returns:
+        images: Tensor of shape (N, 4, image_size, image_size), where pixels with
+            no nearest points have features set to the background color, and other
+            pixels with accumulated features have unchanged values.
+    """
+    # Initialize background mask
+    background_mask = pix_idxs[:, 0] < 0  # (N, image_size, image_size)
+
+    # Convert background_color to an appropriate tensor and check shape
+    if not torch.is_tensor(background_color):
+        background_color = images.new_tensor(background_color)
+
+    background_shape = background_color.shape
+
+    if len(background_shape) != 1 or background_shape[0] not in (3, 4):
+        warnings.warn(
+            "Background color should be size (3) or (4), but is size %s instead"
+            % (background_shape,)
+        )
+        return images
+
+    background_color = background_color.to(images)
+
+    # add alpha channel
+    if background_shape[0] == 3:
+        alpha = images.new_ones(1)
+        background_color = torch.cat([background_color, alpha])
+
+    num_background_pixels = background_mask.sum()
+
+    # permute so that features are the last dimension for masked_scatter to work
+    masked_images = images.permute(0, 2, 3, 1)[..., :4].masked_scatter(
+        background_mask[..., None],
+        background_color[None, :].expand(num_background_pixels, -1),
+    )
+
+    return masked_images.permute(0, 3, 1, 2)
diff --git a/tests/test_render_points.py b/tests/test_render_points.py
index c5820267..13778ac4 100644
--- a/tests/test_render_points.py
+++ b/tests/test_render_points.py
@@ -18,6 +18,7 @@ from pytorch3d.renderer.cameras import (
     FoVPerspectiveCameras,
     look_at_view_transform,
 )
+from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
 from pytorch3d.renderer.points import (
     AlphaCompositor,
     NormWeightedCompositor,
@@ -171,3 +172,54 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
                     DATA_DIR / filename
                 )
             self.assertClose(rgb, image_ref)
+
+    def test_compositor_background_color(self):
+
+        N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
+        ptclds = torch.randn((C, P))
+        alphas = torch.rand((N, K, H, W))
+        pix_idxs = torch.randint(-1, 20, (N, K, H, W))  # 20 < P; -1 marks pixels with no points
+        background_color = [0.5, 0, 1]
+
+        compositor_funcs = [
+            (NormWeightedCompositor, norm_weighted_sum),
+            (AlphaCompositor, alpha_composite),
+        ]
+
+        for (compositor_class, composite_func) in compositor_funcs:
+
+            compositor = compositor_class(background_color)
+
+            # run the forward method to generate masked images
+            masked_images = compositor.forward(pix_idxs, alphas, ptclds)
+
+            # generate unmasked images for testing purposes
+            images = composite_func(pix_idxs, alphas, ptclds)
+
+            is_foreground = pix_idxs[:, 0] >= 0
+
+            # make sure foreground values are unchanged
+            self.assertClose(
+                torch.masked_select(masked_images, is_foreground[:, None]),
+                torch.masked_select(images, is_foreground[:, None]),
+            )
+
+            is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
+
+            # permute masked_images to correctly get rgb values
+            masked_images = masked_images.permute(0, 2, 3, 1)
+            for i in range(3):
+                channel_color = background_color[i]
+
+                # check if background colors are properly changed
+                self.assertTrue(
+                    masked_images[is_background]
+                    .view(-1, 4)[..., i]
+                    .eq(channel_color)
+                    .all()
+                )
+
+                # check background color alpha values
+                self.assertTrue(
+                    masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
+                )
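
Usage sketch: a minimal example of how the new background_color option can be exercised end to end through PointsRenderer. The camera, raster settings, and random point cloud below are illustrative assumptions rather than anything taken from this diff; the pieces added by the patch are the AlphaCompositor(background_color=...) constructor argument and the per-call background_color override read from kwargs in forward().

import torch
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    AlphaCompositor,
    FoVOrthographicCameras,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
)

device = torch.device("cpu")

# Hypothetical inputs: 1000 random points with rgba features. A 4-channel
# feature tensor is what enables the background path added in this patch
# (the compositors only call _add_background_color_to_images when
# images.shape[1] == 4).
points = torch.rand(1, 1000, 3, device=device) - 0.5
rgba = torch.rand(1, 1000, 4, device=device)
point_cloud = Pointclouds(points=points, features=rgba)

cameras = FoVOrthographicCameras(device=device)
raster_settings = PointsRasterizationSettings(
    image_size=256, radius=0.01, points_per_pixel=8
)

renderer = PointsRenderer(
    rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
    # Default background for every render; an rgb triple (3 values) gets an
    # alpha of 1 appended by the helper, an rgba 4-tuple is used as-is.
    compositor=AlphaCompositor(background_color=(0.5, 0.0, 1.0)),
)

# Pixels whose nearest point index is -1 (no points rasterized there) come out
# as the background color; all other pixels keep their composited features.
# PointsRenderer returns images of shape (N, H, W, 4).
images = renderer(point_cloud)

# The background can also be overridden per call, since forward() falls back to
# kwargs.get("background_color", self.background_color).
white_bg = renderer(point_cloud, background_color=(1.0, 1.0, 1.0))

The same constructor argument and override apply to NormWeightedCompositor, since both compositors share the _add_background_color_to_images helper introduced here.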