Add background color support to compositors

Summary: Support rendering point clouds over a configurable background color in both compositors (AlphaCompositor and NormWeightedCompositor).

Reviewed By: nikhilaravi

Differential Revision: D23611043

fbshipit-source-id: ab029650d51349340372c5bd66700e6577d48851
This commit is contained in:
Amitav Baruah 2020-09-14 10:36:56 -07:00 committed by Facebook GitHub Bot
parent dc40adfa24
commit 872ff8c796
2 changed files with 125 additions and 2 deletions

View File

@ -1,5 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
class AlphaCompositor(nn.Module):
    """
    Accumulate points using alpha compositing.

    Args:
        background_color: optional tuple, list, or Tensor stored on the module
            and used by `forward` as the default background color. With the
            default of None no background is applied.
    """

    def __init__(
        self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
    ):
        super().__init__()
        self.background_color = background_color

    def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
        """
        Composite the points and, when requested, color the empty pixels.

        A `background_color` keyword argument, if given, takes precedence over
        the color stored at construction time for this call.
        """
        background_color = kwargs.get("background_color", self.background_color)
        images = alpha_composite(fragments, alphas, ptclds)

        # images are of shape (N, C, H, W)
        # check for background color & feature size C (C=4 indicates rgba);
        # for any other feature size the composited images are returned as-is.
        if background_color is not None and images.shape[1] == 4:
            return _add_background_color_to_images(fragments, images, background_color)
        return images
class NormWeightedCompositor(nn.Module):
    """
    Accumulate points using a normalized weighted sum.

    Args:
        background_color: optional tuple, list, or Tensor stored on the module
            and used by `forward` as the default background color. With the
            default of None no background is applied.
    """

    def __init__(
        self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
    ):
        super().__init__()
        self.background_color = background_color

    def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
        """
        Composite the points and, when requested, color the empty pixels.

        A `background_color` keyword argument, if given, takes precedence over
        the color stored at construction time for this call.
        """
        background_color = kwargs.get("background_color", self.background_color)
        images = norm_weighted_sum(fragments, alphas, ptclds)

        # images are of shape (N, C, H, W)
        # check for background color & feature size C (C=4 indicates rgba);
        # for any other feature size the composited images are returned as-is.
        if background_color is not None and images.shape[1] == 4:
            return _add_background_color_to_images(fragments, images, background_color)
        return images
def _add_background_color_to_images(pix_idxs, images, background_color):
"""
Mask pixels in images without corresponding points with a given background_color.
Args:
pix_idxs: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
images: Tensor of shape (N, 4, image_size, image_size) giving the
accumulated features at each point, where 4 refers to a rgba feature.
background_color: Tensor, list, or tuple with 3 or 4 values indicating the rgb/rgba
value for the new background. Values should be in the interval [0,1].
Returns:
images: Tensor of shape (N, 4, image_size, image_size), where pixels with
no nearest points have features set to the background color, and other
pixels with accumulated features have unchanged values.
"""
# Initialize background mask
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
# Convert background_color to an appropriate tensor and check shape
if not torch.is_tensor(background_color):
background_color = images.new_tensor(background_color)
background_shape = background_color.shape
if len(background_shape) != 1 or background_shape[0] not in (3, 4):
warnings.warn(
"Background color should be size (3) or (4), but is size %s instead"
% (background_shape,)
)
return images
background_color = background_color.to(images)
# add alpha channel
if background_shape[0] == 3:
alpha = images.new_ones(1)
background_color = torch.cat([background_color, alpha])
num_background_pixels = background_mask.sum()
# permute so that features are the last dimension for masked_scatter to work
masked_images = images.permute(0, 2, 3, 1)[..., :4].masked_scatter(
background_mask[..., None],
background_color[None, :].expand(num_background_pixels, -1),
)
return masked_images.permute(0, 3, 1, 2)

View File

@ -18,6 +18,7 @@ from pytorch3d.renderer.cameras import (
FoVPerspectiveCameras, FoVPerspectiveCameras,
look_at_view_transform, look_at_view_transform,
) )
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
from pytorch3d.renderer.points import ( from pytorch3d.renderer.points import (
AlphaCompositor, AlphaCompositor,
NormWeightedCompositor, NormWeightedCompositor,
@ -171,3 +172,54 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
DATA_DIR / filename DATA_DIR / filename
) )
self.assertClose(rgb, image_ref) self.assertClose(rgb, image_ref)
def test_compositor_background_color(self):
    """
    Both compositors must leave foreground pixels untouched and paint
    pixels without points with the requested rgb color and an alpha of 1.
    """
    N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
    ptclds = torch.randn((C, P))
    alphas = torch.rand((N, K, H, W))
    # indices are drawn from [-1, 20): all of them are < P, and -1 marks
    # a slot without a point
    pix_idxs = torch.randint(-1, 20, (N, K, H, W))
    background_color = [0.5, 0, 1]

    cases = (
        (NormWeightedCompositor, norm_weighted_sum),
        (AlphaCompositor, alpha_composite),
    )
    for compositor_class, composite_func in cases:
        compositor = compositor_class(background_color)

        # output of the compositor module, with the background applied
        composited = compositor.forward(pix_idxs, alphas, ptclds)
        # raw output of the functional compositor, used as the reference
        reference = composite_func(pix_idxs, alphas, ptclds)

        # a pixel is foreground when its nearest slot holds a point
        foreground = pix_idxs[:, 0] >= 0
        foreground_nchw = foreground[:, None]

        # foreground values must be unchanged
        self.assertClose(
            torch.masked_select(composited, foreground_nchw),
            torch.masked_select(reference, foreground_nchw),
        )

        background = ~foreground[..., None].expand(-1, -1, -1, 4)

        # move the channels last so each selected row is one rgba value
        composited = composited.permute(0, 2, 3, 1)
        background_rgba = composited[background].view(-1, 4)

        # every background pixel carries the requested rgb color ...
        for channel, expected in enumerate(background_color):
            self.assertTrue(background_rgba[..., channel].eq(expected).all())

        # ... and an alpha value of 1
        self.assertTrue(background_rgba[..., 3].eq(1).all())