mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
flexible background color for point compositing
Summary: Modified the compositor background color tests to cover both 3-channel (RGB) and 4-channel (RGBA) images. Also replaced the hard-coded channel count in the tests with the variable C. Changed the alpha channel appending logic so that an alpha value is appended only when an RGB background color is given for RGBA images, and cleaned up extraneous warnings and checks, per task instructions. Fixes https://github.com/facebookresearch/pytorch3d/issues/1048 Reviewed By: bottler Differential Revision: D34305312 fbshipit-source-id: 2176c3bdd897d1a2ba6ff4c6fa801fea889e4f02
This commit is contained in:
parent
c8f3d6bc0b
commit
59972b121d
@ -35,7 +35,7 @@ class AlphaCompositor(nn.Module):
|
||||
|
||||
# images are of shape (N, C, H, W)
|
||||
# check for background color & feature size C (C=4 indicates rgba)
|
||||
if background_color is not None and images.shape[1] == 4:
|
||||
if background_color is not None:
|
||||
return _add_background_color_to_images(fragments, images, background_color)
|
||||
return images
|
||||
|
||||
@ -57,7 +57,7 @@ class NormWeightedCompositor(nn.Module):
|
||||
|
||||
# images are of shape (N, C, H, W)
|
||||
# check for background color & feature size C (C=4 indicates rgba)
|
||||
if background_color is not None and images.shape[1] == 4:
|
||||
if background_color is not None:
|
||||
return _add_background_color_to_images(fragments, images, background_color)
|
||||
return images
|
||||
|
||||
@ -85,22 +85,22 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
|
||||
if not torch.is_tensor(background_color):
|
||||
background_color = images.new_tensor(background_color)
|
||||
|
||||
background_shape = background_color.shape
|
||||
|
||||
if len(background_shape) != 1 or background_shape[0] not in (3, 4):
|
||||
warnings.warn(
|
||||
"Background color should be size (3) or (4), but is size %s instead"
|
||||
% (background_shape,)
|
||||
)
|
||||
return images
|
||||
if len(background_color.shape) != 1:
|
||||
raise ValueError("Wrong shape of background_color")
|
||||
|
||||
background_color = background_color.to(images)
|
||||
|
||||
# add alpha channel
|
||||
if background_shape[0] == 3:
|
||||
if background_color.shape[0] == 3 and images.shape[1] == 4:
|
||||
# special case to allow giving RGB background for RGBA
|
||||
alpha = images.new_ones(1)
|
||||
background_color = torch.cat([background_color, alpha])
|
||||
|
||||
if images.shape[1] != background_color.shape[0]:
|
||||
raise ValueError(
|
||||
f"background color has {background_color.shape[0] } channels not {images.shape[1]}"
|
||||
)
|
||||
|
||||
num_background_pixels = background_mask.sum()
|
||||
|
||||
# permute so that features are the last dimension for masked_scatter to work
|
||||
|
@ -326,7 +326,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertClose(rgb, image_ref)
|
||||
|
||||
def test_compositor_background_color(self):
|
||||
def test_compositor_background_color_rgba(self):
|
||||
|
||||
N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
|
||||
ptclds = torch.randn((C, P))
|
||||
@ -357,7 +357,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
||||
torch.masked_select(images, is_foreground[:, None]),
|
||||
)
|
||||
|
||||
is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
|
||||
is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)
|
||||
|
||||
# permute masked_images to correctly get rgb values
|
||||
masked_images = masked_images.permute(0, 2, 3, 1)
|
||||
@ -367,12 +367,58 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
||||
# check if background colors are properly changed
|
||||
self.assertTrue(
|
||||
masked_images[is_background]
|
||||
.view(-1, 4)[..., i]
|
||||
.view(-1, C)[..., i]
|
||||
.eq(channel_color)
|
||||
.all()
|
||||
)
|
||||
|
||||
# check background color alpha values
|
||||
self.assertTrue(
|
||||
masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
|
||||
masked_images[is_background].view(-1, C)[..., 3].eq(1).all()
|
||||
)
|
||||
|
||||
def test_compositor_background_color_rgb(self):
|
||||
|
||||
N, H, W, K, C, P = 1, 15, 15, 20, 3, 225
|
||||
ptclds = torch.randn((C, P))
|
||||
alphas = torch.rand((N, K, H, W))
|
||||
pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1
|
||||
background_color = [0.5, 0, 1]
|
||||
|
||||
compositor_funcs = [
|
||||
(NormWeightedCompositor, norm_weighted_sum),
|
||||
(AlphaCompositor, alpha_composite),
|
||||
]
|
||||
|
||||
for (compositor_class, composite_func) in compositor_funcs:
|
||||
|
||||
compositor = compositor_class(background_color)
|
||||
|
||||
# run the forward method to generate masked images
|
||||
masked_images = compositor.forward(pix_idxs, alphas, ptclds)
|
||||
|
||||
# generate unmasked images for testing purposes
|
||||
images = composite_func(pix_idxs, alphas, ptclds)
|
||||
|
||||
is_foreground = pix_idxs[:, 0] >= 0
|
||||
|
||||
# make sure foreground values are unchanged
|
||||
self.assertClose(
|
||||
torch.masked_select(masked_images, is_foreground[:, None]),
|
||||
torch.masked_select(images, is_foreground[:, None]),
|
||||
)
|
||||
|
||||
is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)
|
||||
|
||||
# permute masked_images to correctly get rgb values
|
||||
masked_images = masked_images.permute(0, 2, 3, 1)
|
||||
for i in range(3):
|
||||
channel_color = background_color[i]
|
||||
|
||||
# check if background colors are properly changed
|
||||
self.assertTrue(
|
||||
masked_images[is_background]
|
||||
.view(-1, C)[..., i]
|
||||
.eq(channel_color)
|
||||
.all()
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user