From 59972b121d8c7bfc0e156b5ad5fcd77c11874178 Mon Sep 17 00:00:00 2001 From: Alex Greene Date: Fri, 18 Feb 2022 07:00:01 -0800 Subject: [PATCH] flexible background color for point compositing Summary: Modified the compositor background color tests to account for either a 3rd or 4th channel. Also replaced hard coding of channel value with C. Implemented changes to alpha channel appending logic, and cleaned up extraneous warnings and checks, per task instructions. Fixes https://github.com/facebookresearch/pytorch3d/issues/1048 Reviewed By: bottler Differential Revision: D34305312 fbshipit-source-id: 2176c3bdd897d1a2ba6ff4c6fa801fea889e4f02 --- pytorch3d/renderer/points/compositor.py | 22 +++++----- tests/test_render_points.py | 54 +++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/pytorch3d/renderer/points/compositor.py b/pytorch3d/renderer/points/compositor.py index 23c73916..17483ec7 100644 --- a/pytorch3d/renderer/points/compositor.py +++ b/pytorch3d/renderer/points/compositor.py @@ -35,7 +35,7 @@ class AlphaCompositor(nn.Module): # images are of shape (N, C, H, W) # check for background color & feature size C (C=4 indicates rgba) - if background_color is not None and images.shape[1] == 4: + if background_color is not None: return _add_background_color_to_images(fragments, images, background_color) return images @@ -57,7 +57,7 @@ class NormWeightedCompositor(nn.Module): # images are of shape (N, C, H, W) # check for background color & feature size C (C=4 indicates rgba) - if background_color is not None and images.shape[1] == 4: + if background_color is not None: return _add_background_color_to_images(fragments, images, background_color) return images @@ -85,22 +85,22 @@ def _add_background_color_to_images(pix_idxs, images, background_color): if not torch.is_tensor(background_color): background_color = images.new_tensor(background_color) - background_shape = background_color.shape - - if len(background_shape) != 1 or background_shape[0] not in (3, 4): - warnings.warn( - "Background color should be size (3) or (4), but is size %s instead" - % (background_shape,) - ) - return images + if len(background_color.shape) != 1: + raise ValueError("Wrong shape of background_color") background_color = background_color.to(images) # add alpha channel - if background_shape[0] == 3: + if background_color.shape[0] == 3 and images.shape[1] == 4: + # special case to allow giving RGB background for RGBA alpha = images.new_ones(1) background_color = torch.cat([background_color, alpha]) + if images.shape[1] != background_color.shape[0]: + raise ValueError( + f"background color has {background_color.shape[0] } channels not {images.shape[1]}" + ) + num_background_pixels = background_mask.sum() # permute so that features are the last dimension for masked_scatter to work diff --git a/tests/test_render_points.py b/tests/test_render_points.py index dfe9f431..f2ff9bb3 100644 --- a/tests/test_render_points.py +++ b/tests/test_render_points.py @@ -326,7 +326,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): ) self.assertClose(rgb, image_ref) - def test_compositor_background_color(self): + def test_compositor_background_color_rgba(self): N, H, W, K, C, P = 1, 15, 15, 20, 4, 225 ptclds = torch.randn((C, P)) @@ -357,7 +357,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): torch.masked_select(images, is_foreground[:, None]), ) - is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4) + is_background = ~is_foreground[..., None].expand(-1, -1, -1, C) # permute masked_images to correctly get rgb values masked_images = masked_images.permute(0, 2, 3, 1) @@ -367,12 +367,58 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): # check if background colors are properly changed self.assertTrue( masked_images[is_background] - .view(-1, 4)[..., i] + .view(-1, C)[..., i] .eq(channel_color) .all() ) # check background color alpha values self.assertTrue( - masked_images[is_background].view(-1, 4)[..., 3].eq(1).all() + masked_images[is_background].view(-1, C)[..., 3].eq(1).all() ) + + def test_compositor_background_color_rgb(self): + + N, H, W, K, C, P = 1, 15, 15, 20, 3, 225 + ptclds = torch.randn((C, P)) + alphas = torch.rand((N, K, H, W)) + pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1 + background_color = [0.5, 0, 1] + + compositor_funcs = [ + (NormWeightedCompositor, norm_weighted_sum), + (AlphaCompositor, alpha_composite), + ] + + for (compositor_class, composite_func) in compositor_funcs: + + compositor = compositor_class(background_color) + + # run the forward method to generate masked images + masked_images = compositor.forward(pix_idxs, alphas, ptclds) + + # generate unmasked images for testing purposes + images = composite_func(pix_idxs, alphas, ptclds) + + is_foreground = pix_idxs[:, 0] >= 0 + + # make sure foreground values are unchanged + self.assertClose( + torch.masked_select(masked_images, is_foreground[:, None]), + torch.masked_select(images, is_foreground[:, None]), + ) + + is_background = ~is_foreground[..., None].expand(-1, -1, -1, C) + + # permute masked_images to correctly get rgb values + masked_images = masked_images.permute(0, 2, 3, 1) + for i in range(3): + channel_color = background_color[i] + + # check if background colors are properly changed + self.assertTrue( + masked_images[is_background] + .view(-1, C)[..., i] + .eq(channel_color) + .all() + )