diff --git a/pytorch3d/renderer/points/compositor.py b/pytorch3d/renderer/points/compositor.py index 23c73916..17483ec7 100644 --- a/pytorch3d/renderer/points/compositor.py +++ b/pytorch3d/renderer/points/compositor.py @@ -35,7 +35,7 @@ class AlphaCompositor(nn.Module): # images are of shape (N, C, H, W) # check for background color & feature size C (C=4 indicates rgba) - if background_color is not None and images.shape[1] == 4: + if background_color is not None: return _add_background_color_to_images(fragments, images, background_color) return images @@ -57,7 +57,7 @@ class NormWeightedCompositor(nn.Module): # images are of shape (N, C, H, W) # check for background color & feature size C (C=4 indicates rgba) - if background_color is not None and images.shape[1] == 4: + if background_color is not None: return _add_background_color_to_images(fragments, images, background_color) return images @@ -85,22 +85,22 @@ def _add_background_color_to_images(pix_idxs, images, background_color): if not torch.is_tensor(background_color): background_color = images.new_tensor(background_color) - background_shape = background_color.shape - - if len(background_shape) != 1 or background_shape[0] not in (3, 4): - warnings.warn( - "Background color should be size (3) or (4), but is size %s instead" - % (background_shape,) - ) - return images + if len(background_color.shape) != 1: + raise ValueError("Wrong shape of background_color") background_color = background_color.to(images) # add alpha channel - if background_shape[0] == 3: + if background_color.shape[0] == 3 and images.shape[1] == 4: + # special case to allow giving RGB background for RGBA alpha = images.new_ones(1) background_color = torch.cat([background_color, alpha]) + if images.shape[1] != background_color.shape[0]: + raise ValueError( + f"background color has {background_color.shape[0] } channels not {images.shape[1]}" + ) + num_background_pixels = background_mask.sum() # permute so that features are the last dimension for masked_scatter to work diff --git a/tests/test_render_points.py b/tests/test_render_points.py index dfe9f431..f2ff9bb3 100644 --- a/tests/test_render_points.py +++ b/tests/test_render_points.py @@ -326,7 +326,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): ) self.assertClose(rgb, image_ref) - def test_compositor_background_color(self): + def test_compositor_background_color_rgba(self): N, H, W, K, C, P = 1, 15, 15, 20, 4, 225 ptclds = torch.randn((C, P)) @@ -357,7 +357,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): torch.masked_select(images, is_foreground[:, None]), ) - is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4) + is_background = ~is_foreground[..., None].expand(-1, -1, -1, C) # permute masked_images to correctly get rgb values masked_images = masked_images.permute(0, 2, 3, 1) @@ -367,12 +367,58 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): # check if background colors are properly changed self.assertTrue( masked_images[is_background] - .view(-1, 4)[..., i] + .view(-1, C)[..., i] .eq(channel_color) .all() ) # check background color alpha values self.assertTrue( - masked_images[is_background].view(-1, 4)[..., 3].eq(1).all() + masked_images[is_background].view(-1, C)[..., 3].eq(1).all() ) + + def test_compositor_background_color_rgb(self): + + N, H, W, K, C, P = 1, 15, 15, 20, 3, 225 + ptclds = torch.randn((C, P)) + alphas = torch.rand((N, K, H, W)) + pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1 + background_color = [0.5, 0, 1] + + compositor_funcs = [ + (NormWeightedCompositor, norm_weighted_sum), + (AlphaCompositor, alpha_composite), + ] + + for (compositor_class, composite_func) in compositor_funcs: + + compositor = compositor_class(background_color) + + # run the forward method to generate masked images + masked_images = compositor.forward(pix_idxs, alphas, ptclds) + + # generate unmasked images for testing purposes + images = composite_func(pix_idxs, alphas, ptclds) + + is_foreground = pix_idxs[:, 0] >= 0 + + # make sure foreground values are unchanged + self.assertClose( + torch.masked_select(masked_images, is_foreground[:, None]), + torch.masked_select(images, is_foreground[:, None]), + ) + + is_background = ~is_foreground[..., None].expand(-1, -1, -1, C) + + # permute masked_images to correctly get rgb values + masked_images = masked_images.permute(0, 2, 3, 1) + for i in range(3): + channel_color = background_color[i] + + # check if background colors are properly changed + self.assertTrue( + masked_images[is_background] + .view(-1, C)[..., i] + .eq(channel_color) + .all() + )