mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Add background color support to compositors
Summary: Support rendering different color backgrounds for pointclouds for both compositors Reviewed By: nikhilaravi Differential Revision: D23611043 fbshipit-source-id: ab029650d51349340372c5bd66700e6577d48851
This commit is contained in:
parent
dc40adfa24
commit
872ff8c796
@ -1,5 +1,8 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -16,11 +19,20 @@ class AlphaCompositor(nn.Module):
|
|||||||
Accumulate points using alpha compositing.
|
Accumulate points using alpha compositing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.background_color = background_color
|
||||||
|
|
||||||
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
|
||||||
|
background_color = kwargs.get("background_color", self.background_color)
|
||||||
images = alpha_composite(fragments, alphas, ptclds)
|
images = alpha_composite(fragments, alphas, ptclds)
|
||||||
|
|
||||||
|
# images are of shape (N, C, H, W)
|
||||||
|
# check for background color & feature size C (C=4 indicates rgba)
|
||||||
|
if background_color is not None and images.shape[1] == 4:
|
||||||
|
return _add_background_color_to_images(fragments, images, background_color)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
@ -29,9 +41,68 @@ class NormWeightedCompositor(nn.Module):
|
|||||||
Accumulate points using a normalized weighted sum.
|
Accumulate points using a normalized weighted sum.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.background_color = background_color
|
||||||
|
|
||||||
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
|
||||||
|
background_color = kwargs.get("background_color", self.background_color)
|
||||||
images = norm_weighted_sum(fragments, alphas, ptclds)
|
images = norm_weighted_sum(fragments, alphas, ptclds)
|
||||||
|
|
||||||
|
# images are of shape (N, C, H, W)
|
||||||
|
# check for background color & feature size C (C=4 indicates rgba)
|
||||||
|
if background_color is not None and images.shape[1] == 4:
|
||||||
|
return _add_background_color_to_images(fragments, images, background_color)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def _add_background_color_to_images(pix_idxs, images, background_color):
|
||||||
|
"""
|
||||||
|
Mask pixels in images without corresponding points with a given background_color.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pix_idxs: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||||
|
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||||
|
images: Tensor of shape (N, 4, image_size, image_size) giving the
|
||||||
|
accumulated features at each point, where 4 refers to a rgba feature.
|
||||||
|
background_color: Tensor, list, or tuple with 3 or 4 values indicating the rgb/rgba
|
||||||
|
value for the new background. Values should be in the interval [0,1].
|
||||||
|
Returns:
|
||||||
|
images: Tensor of shape (N, 4, image_size, image_size), where pixels with
|
||||||
|
no nearest points have features set to the background color, and other
|
||||||
|
pixels with accumulated features have unchanged values.
|
||||||
|
"""
|
||||||
|
# Initialize background mask
|
||||||
|
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
|
||||||
|
|
||||||
|
# Convert background_color to an appropriate tensor and check shape
|
||||||
|
if not torch.is_tensor(background_color):
|
||||||
|
background_color = images.new_tensor(background_color)
|
||||||
|
|
||||||
|
background_shape = background_color.shape
|
||||||
|
|
||||||
|
if len(background_shape) != 1 or background_shape[0] not in (3, 4):
|
||||||
|
warnings.warn(
|
||||||
|
"Background color should be size (3) or (4), but is size %s instead"
|
||||||
|
% (background_shape,)
|
||||||
|
)
|
||||||
|
return images
|
||||||
|
|
||||||
|
background_color = background_color.to(images)
|
||||||
|
|
||||||
|
# add alpha channel
|
||||||
|
if background_shape[0] == 3:
|
||||||
|
alpha = images.new_ones(1)
|
||||||
|
background_color = torch.cat([background_color, alpha])
|
||||||
|
|
||||||
|
num_background_pixels = background_mask.sum()
|
||||||
|
|
||||||
|
# permute so that features are the last dimension for masked_scatter to work
|
||||||
|
masked_images = images.permute(0, 2, 3, 1)[..., :4].masked_scatter(
|
||||||
|
background_mask[..., None],
|
||||||
|
background_color[None, :].expand(num_background_pixels, -1),
|
||||||
|
)
|
||||||
|
|
||||||
|
return masked_images.permute(0, 3, 1, 2)
|
||||||
|
@ -18,6 +18,7 @@ from pytorch3d.renderer.cameras import (
|
|||||||
FoVPerspectiveCameras,
|
FoVPerspectiveCameras,
|
||||||
look_at_view_transform,
|
look_at_view_transform,
|
||||||
)
|
)
|
||||||
|
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
|
||||||
from pytorch3d.renderer.points import (
|
from pytorch3d.renderer.points import (
|
||||||
AlphaCompositor,
|
AlphaCompositor,
|
||||||
NormWeightedCompositor,
|
NormWeightedCompositor,
|
||||||
@ -171,3 +172,54 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
|||||||
DATA_DIR / filename
|
DATA_DIR / filename
|
||||||
)
|
)
|
||||||
self.assertClose(rgb, image_ref)
|
self.assertClose(rgb, image_ref)
|
||||||
|
|
||||||
|
def test_compositor_background_color(self):
|
||||||
|
|
||||||
|
N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
|
||||||
|
ptclds = torch.randn((C, P))
|
||||||
|
alphas = torch.rand((N, K, H, W))
|
||||||
|
pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1
|
||||||
|
background_color = [0.5, 0, 1]
|
||||||
|
|
||||||
|
compositor_funcs = [
|
||||||
|
(NormWeightedCompositor, norm_weighted_sum),
|
||||||
|
(AlphaCompositor, alpha_composite),
|
||||||
|
]
|
||||||
|
|
||||||
|
for (compositor_class, composite_func) in compositor_funcs:
|
||||||
|
|
||||||
|
compositor = compositor_class(background_color)
|
||||||
|
|
||||||
|
# run the forward method to generate masked images
|
||||||
|
masked_images = compositor.forward(pix_idxs, alphas, ptclds)
|
||||||
|
|
||||||
|
# generate unmasked images for testing purposes
|
||||||
|
images = composite_func(pix_idxs, alphas, ptclds)
|
||||||
|
|
||||||
|
is_foreground = pix_idxs[:, 0] >= 0
|
||||||
|
|
||||||
|
# make sure foreground values are unchanged
|
||||||
|
self.assertClose(
|
||||||
|
torch.masked_select(masked_images, is_foreground[:, None]),
|
||||||
|
torch.masked_select(images, is_foreground[:, None]),
|
||||||
|
)
|
||||||
|
|
||||||
|
is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
|
||||||
|
|
||||||
|
# permute masked_images to correctly get rgb values
|
||||||
|
masked_images = masked_images.permute(0, 2, 3, 1)
|
||||||
|
for i in range(3):
|
||||||
|
channel_color = background_color[i]
|
||||||
|
|
||||||
|
# check if background colors are properly changed
|
||||||
|
self.assertTrue(
|
||||||
|
masked_images[is_background]
|
||||||
|
.view(-1, 4)[..., i]
|
||||||
|
.eq(channel_color)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# check background color alpha values
|
||||||
|
self.assertTrue(
|
||||||
|
masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user