flexible background color for point compositing

Summary:
Modified the compositor background-color tests to account for feature tensors with either 3 or 4 channels, and replaced the hard-coded channel value with C.

Reworked the alpha-channel appending logic and cleaned up the now-extraneous warnings and checks, per task instructions.
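
To make the relaxed check concrete, here is a minimal usage sketch (not part of this commit, assuming only the public AlphaCompositor interface exercised by the tests below): a 3-channel, RGB-only feature tensor can now be composited onto a background color of matching length.

# Minimal sketch, not from this diff: the background color only has to match
# the feature dimension C, so plain RGB features (C=3) now work.
import torch
from pytorch3d.renderer import AlphaCompositor

N, K, H, W, C, P = 1, 20, 15, 15, 3, 225       # C=3: RGB features, no alpha
ptclds = torch.randn((C, P))                   # per-point features
alphas = torch.rand((N, K, H, W))              # per-point blend weights
pix_idxs = torch.randint(-1, P, (N, K, H, W))  # -1 marks background pixels

# Before this change the compositor required images.shape[1] == 4 here.
compositor = AlphaCompositor(background_color=[0.5, 0.0, 1.0])
images = compositor(pix_idxs, alphas, ptclds)  # (N, C, H, W)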

Fixes https://github.com/facebookresearch/pytorch3d/issues/1048

Reviewed By: bottler

Differential Revision: D34305312

fbshipit-source-id: 2176c3bdd897d1a2ba6ff4c6fa801fea889e4f02
Alex Greene 2022-02-18 07:00:01 -08:00 committed by Facebook GitHub Bot
parent c8f3d6bc0b
commit 59972b121d
2 changed files with 61 additions and 15 deletions


@@ -35,7 +35,7 @@ class AlphaCompositor(nn.Module):
         # images are of shape (N, C, H, W)
         # check for background color & feature size C (C=4 indicates rgba)
-        if background_color is not None and images.shape[1] == 4:
+        if background_color is not None:
             return _add_background_color_to_images(fragments, images, background_color)
         return images
@@ -57,7 +57,7 @@ class NormWeightedCompositor(nn.Module):
         # images are of shape (N, C, H, W)
         # check for background color & feature size C (C=4 indicates rgba)
-        if background_color is not None and images.shape[1] == 4:
+        if background_color is not None:
             return _add_background_color_to_images(fragments, images, background_color)
         return images
@@ -85,22 +85,22 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
     if not torch.is_tensor(background_color):
         background_color = images.new_tensor(background_color)
-    background_shape = background_color.shape
-    if len(background_shape) != 1 or background_shape[0] not in (3, 4):
-        warnings.warn(
-            "Background color should be size (3) or (4), but is size %s instead"
-            % (background_shape,)
-        )
-        return images
+    if len(background_color.shape) != 1:
+        raise ValueError("Wrong shape of background_color")
     background_color = background_color.to(images)
-    # add alpha channel
-    if background_shape[0] == 3:
+    if background_color.shape[0] == 3 and images.shape[1] == 4:
+        # special case to allow giving RGB background for RGBA
         alpha = images.new_ones(1)
         background_color = torch.cat([background_color, alpha])
+    if images.shape[1] != background_color.shape[0]:
+        raise ValueError(
+            f"background color has {background_color.shape[0] } channels not {images.shape[1]}"
+        )
     num_background_pixels = background_mask.sum()
     # permute so that features are the last dimension for masked_scatter to work
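
As a quick, assumed illustration (not part of the commit) of the two paths through the reworked checks above: an RGB background given with RGBA features gets an alpha of 1 appended, while a genuine channel mismatch now raises a ValueError instead of warning and returning the images unfilled.

# Assumed illustration of the reworked checks, driven through the public API.
import torch
from pytorch3d.renderer import AlphaCompositor

N, K, H, W, P = 1, 20, 15, 15, 225
alphas = torch.rand((N, K, H, W))
pix_idxs = torch.randint(-1, P, (N, K, H, W))

# RGBA features (C=4) with an RGB background: alpha=1 is appended internally.
rgba = AlphaCompositor([0.5, 0.0, 1.0])(pix_idxs, alphas, torch.randn((4, P)))

# Mismatched channel counts (2-channel background, 3-channel features) now
# raise instead of emitting a warning and skipping the background fill.
try:
    AlphaCompositor([0.5, 0.0])(pix_idxs, alphas, torch.randn((3, P)))
except ValueError as err:
    print(err)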


@@ -326,7 +326,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
         )
         self.assertClose(rgb, image_ref)
-    def test_compositor_background_color(self):
+    def test_compositor_background_color_rgba(self):
         N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
         ptclds = torch.randn((C, P))
@@ -357,7 +357,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
                 torch.masked_select(images, is_foreground[:, None]),
             )
-            is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
+            is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)
             # permute masked_images to correctly get rgb values
             masked_images = masked_images.permute(0, 2, 3, 1)
@@ -367,12 +367,58 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
                 # check if background colors are properly changed
                 self.assertTrue(
                     masked_images[is_background]
-                    .view(-1, 4)[..., i]
+                    .view(-1, C)[..., i]
                     .eq(channel_color)
                     .all()
                 )
             # check background color alpha values
             self.assertTrue(
-                masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
+                masked_images[is_background].view(-1, C)[..., 3].eq(1).all()
             )
+
+    def test_compositor_background_color_rgb(self):
+        N, H, W, K, C, P = 1, 15, 15, 20, 3, 225
+
+        ptclds = torch.randn((C, P))
+        alphas = torch.rand((N, K, H, W))
+        pix_idxs = torch.randint(-1, 20, (N, K, H, W))  # 20 < P, large amount of -1
+        background_color = [0.5, 0, 1]
+
+        compositor_funcs = [
+            (NormWeightedCompositor, norm_weighted_sum),
+            (AlphaCompositor, alpha_composite),
+        ]
+
+        for (compositor_class, composite_func) in compositor_funcs:
+            compositor = compositor_class(background_color)
+
+            # run the forward method to generate masked images
+            masked_images = compositor.forward(pix_idxs, alphas, ptclds)
+
+            # generate unmasked images for testing purposes
+            images = composite_func(pix_idxs, alphas, ptclds)
+
+            is_foreground = pix_idxs[:, 0] >= 0
+
+            # make sure foreground values are unchanged
+            self.assertClose(
+                torch.masked_select(masked_images, is_foreground[:, None]),
+                torch.masked_select(images, is_foreground[:, None]),
+            )
+
+            is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)
+
+            # permute masked_images to correctly get rgb values
+            masked_images = masked_images.permute(0, 2, 3, 1)
+
+            for i in range(3):
+                channel_color = background_color[i]
+
+                # check if background colors are properly changed
+                self.assertTrue(
+                    masked_images[is_background]
+                    .view(-1, C)[..., i]
+                    .eq(channel_color)
+                    .all()
+                )