SplatterBlender

Summary: Splatting shader. See code comments for details. Same API as SoftPhongShader.
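Example usage (a sketch; the renderer plumbing below is the standard PyTorch3D setup and, apart from SplatterPhongShader itself, not part of this diff):

    import torch
    from pytorch3d.renderer import (
        FoVPerspectiveCameras,
        MeshRasterizer,
        MeshRenderer,
        PointLights,
        RasterizationSettings,
        SplatterPhongShader,
    )

    device = torch.device("cuda:0")
    cameras = FoVPerspectiveCameras(device=device)
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(image_size=256),
        ),
        shader=SplatterPhongShader(
            device=device, cameras=cameras, lights=PointLights(device=device)
        ),
    )
    # images = renderer(meshes)  # (N, H, W, 4) RGBA, as with SoftPhongShader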

Reviewed By: jcjohnson

Differential Revision: D36354301

fbshipit-source-id: 71ee37f7ff6bb9ce028ba42a65741424a427a92d
Author: Krzysztof Chalupka (committed by Facebook GitHub Bot)
Date: 2022-05-24 21:04:11 -07:00
Parent: 1702c85bec
Commit: c5a83f46ef
7 changed files with 1260 additions and 7 deletions

pytorch3d/renderer/__init__.py

@@ -57,6 +57,7 @@ from .mesh import (
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
SplatterPhongShader,
Textures,
TexturesAtlas,
TexturesUV,
@@ -71,6 +72,7 @@ from .points import (
PulsarPointsRenderer,
rasterize_points,
)
from .splatter_blend import SplatterBlender
from .utils import (
convert_to_tensors_and_broadcast,
ndc_grid_sample,

pytorch3d/renderer/blending.py

@@ -4,14 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import NamedTuple, Sequence, Union
import torch
from pytorch3d import _C
from pytorch3d.common.datatypes import Device
# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return an RGBA image per batch element
@@ -22,10 +20,12 @@ class BlendParams(NamedTuple):
Data class to store blending params with defaults
Members:
sigma (float): Controls the width of the sigmoid function used to
calculate the 2D distance based probability. Determines the
sharpness of the edges of the shape.
Higher => faces have less defined edges.
sigma (float): For SoftmaxPhong, controls the width of the sigmoid
function used to calculate the 2D distance based probability. Determines
the sharpness of the edges of the shape. Higher => faces have less defined
edges. For SplatterPhong, this is the standard deviation of the Gaussian
kernel. Higher => splats have a stronger effect and the rendered image is
more blurry.
gamma (float): Controls the scaling of the exponential function used
to set the opacity of the color.
Higher => faces are more transparent.
@@ -36,6 +36,7 @@ class BlendParams(NamedTuple):
sigma: float = 1e-4
gamma: float = 1e-4
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
background_alpha: float = 0.0
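# Example (illustration only): for splatter blending, sigma is the standard
# deviation of the splat kernel, e.g.
#
#   >>> params = BlendParams(sigma=0.5, background_color=(0.0, 0.0, 0.0))
#
# 0.5 is the default sigma used by _get_splat_kernel_normalization in
# splatter_blend.py.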
def _get_background_color(

pytorch3d/renderer/mesh/__init__.py

@@ -22,6 +22,7 @@ from .shader import ( # DEPRECATED
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
SplatterPhongShader,
TexturedSoftPhongShader,
)
from .shading import gouraud_shading, phong_shading

pytorch3d/renderer/mesh/shader.py

@@ -20,9 +20,15 @@ from ..blending import (
)
from ..lighting import PointLights
from ..materials import Materials
from ..splatter_blend import SplatterBlender
from ..utils import TensorProperties
from .rasterizer import Fragments
from .shading import flat_shading, gouraud_shading, phong_shading
from .shading import (
_phong_shading_with_pixels,
flat_shading,
gouraud_shading,
phong_shading,
)
# A Shader should take as input fragments from the output of rasterization
@@ -308,3 +314,64 @@ class SoftSilhouetteShader(nn.Module):
blend_params = kwargs.get("blend_params", self.blend_params)
images = sigmoid_alpha_blend(colors, fragments, blend_params)
return images
class SplatterPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
color aggregated using splats from surrounding pixels (see [0]).
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = SplatterPhongShader(device=torch.device("cuda:0"))
Args:
detach_rasterizer: If True, stop gradients from flowing through the rasterizer.
This simulates the pipeline in [0] which uses a non-differentiable OpenGL
rasterizer.
[0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
Sampling".
"""
def __init__(self, **kwargs):
self.splatter_blender = None
super().__init__(**kwargs)
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SplatterPhongShader."
raise ValueError(msg)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors, pixel_coords_cameras = _phong_shading_with_pixels(
meshes=meshes,
fragments=fragments.detach(),
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
if not self.splatter_blender:
# Init only once, to avoid re-computing constants.
N, H, W, K, _ = colors.shape
self.splatter_blender = SplatterBlender((N, H, W, K), colors.device)
images = self.splatter_blender(
colors,
pixel_coords_cameras,
cameras,
fragments.pix_to_face < 0,
self.blend_params,
)
return images

pytorch3d/renderer/splatter_blend.py

@@ -0,0 +1,553 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from typing import Tuple
import torch
import torch.nn.functional as F
from pytorch3d.common.datatypes import Device
from pytorch3d.renderer import BlendParams
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
from .blending import _get_background_color
def _precompute(
input_shape: Tuple[int, int, int, int], device: Device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Precompute padding and offset constants that won't change for a given NHWK shape.
Args:
input_shape: Tuple indicating N (batch size), H, W (image size) and K (number of
intersections) output by the rasterizer.
device: Device to store the tensors on.
Returns:
crop_ids_h: An (N, H, W+2, K, 9, 5) tensor, used during splatting to offset the
p-pixels (splatting pixels) in one of the 9 splatting directions within a
call to torch.gather. See comments and _offset_splats for details.
crop_ids_w: An (N, H, W, K, 9, 5) tensor, used similarly to crop_ids_h.
offsets: A (9, 2) tensor containing the values [-1, -1], [-1, 0], [-1, 1],
[0, -1], ..., [1, 1], which correspond to the nine splatting directions.
"""
N, H, W, K = input_shape
# (N, H, W+2, K, 9, 5) tensor, used to reduce a tensor from (N, H+2, W+2...) to
# (N, H, W+2, ...) in torch.gather. If only torch.gather broadcasted, we wouldn't
# need the tiling. But it doesn't.
crop_ids_h = (
torch.arange(0, H, device=device).view(1, H, 1, 1, 1, 1)
+ torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], device=device).view(
1, 1, 1, 1, 9, 1
)
).expand(N, H, W + 2, K, 9, 5)
# (N, H, W, K, 9, 5) tensor, used to reduce a tensor from (N, H, W+2, ...) to
# (N, H, W, ...) in torch.gather.
crop_ids_w = (
torch.arange(0, W, device=device).view(1, 1, W, 1, 1, 1)
+ torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device=device).view(
1, 1, 1, 1, 9, 1
)
).expand(N, H, W, K, 9, 5)
offsets = torch.tensor(
list(itertools.product((-1, 0, 1), repeat=2)),
dtype=torch.long,
device=device,
)
return crop_ids_h, crop_ids_w, offsets
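# Sanity check (runnable standalone): the nine splatting offsets generated above.
#
#   >>> import itertools, torch
#   >>> torch.tensor(list(itertools.product((-1, 0, 1), repeat=2)))
#   tensor([[-1, -1],
#           [-1,  0],
#           [-1,  1],
#           [ 0, -1],
#           [ 0,  0],
#           [ 0,  1],
#           [ 1, -1],
#           [ 1,  0],
#           [ 1,  1]])
#
# Index 4 is the (0, 0) self-splat direction, as assumed throughout this file.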
def _prepare_pixels_and_colors(
pixel_coords_cameras: torch.Tensor,
colors: torch.Tensor,
cameras: FoVPerspectiveCameras,
background_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Project pixel coords into the un-inverted screen frame of reference, and set
background pixel z-values to 1.0 and alphas to 0.0.
Args:
pixel_coords_cameras: (N, H, W, K, 3) float tensor.
colors: (N, H, W, K, 3) float tensor.
cameras: PyTorch3D cameras, for now we assume FoVPerspectiveCameras.
background_mask: (N, H, W, K) boolean tensor.
Returns:
pixel_coords_screen: (N, H, W, K, 3) float tensor. Background pixels have
x=y=z=1.0.
colors: (N, H, W, K, 4). Alpha is set to 1 for foreground pixels and 0 for
background pixels.
"""
N, H, W, K, C = colors.shape
# pixel_coords_screen will contain invalid values at background
# intersections, and pixel-center values of the form [h + 0.5, w + 0.5, z] at
# valid intersections. It is important to not flip the xy axes, otherwise the
# gradients will be inverted when the splatter works with a detached rasterizer.
pixel_coords_screen = cameras.transform_points_screen(
pixel_coords_cameras.view([N, -1, 3]), image_size=(H, W), with_xyflip=False
).reshape(pixel_coords_cameras.shape)
# Set colors' alpha to 1 and background to 0.
colors = torch.cat(
[colors, torch.ones_like(colors[..., :1])], dim=-1
) # (N, H, W, K, 4)
# The hw values of background don't matter because their alpha is set
# to 0 in the next step (which means that no matter what their splatting kernel
# value is, they will not splat as the kernel is multiplied by alpha). However,
# their z-values need to be at max depth. Otherwise, we could incorrectly compute
# occlusion layer linkage.
pixel_coords_screen[background_mask] = 1.0
# Any background color value with alpha=0 will do, as anything with
# alpha=0 will have a zero-weight splatting power. Note that neighbors can still
# splat on zero-alpha pixels: that's the way we get non-zero gradients at the
# boundary with the background.
colors[background_mask] = 0.0
return pixel_coords_screen, colors
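# Shape illustration (runnable) of the alpha concatenation above:
#
#   >>> import torch
#   >>> colors = torch.rand(1, 2, 2, 1, 3)  # (N, H, W, K, 3)
#   >>> torch.cat([colors, torch.ones_like(colors[..., :1])], dim=-1).shape
#   torch.Size([1, 2, 2, 1, 4])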
def _get_splat_kernel_normalization(
offsets: torch.Tensor,
sigma: float = 0.5,
):
if sigma <= 0.0:
raise ValueError("Only positive standard deviations make sense.")
epsilon = 0.05
normalization_constant = torch.exp(
-(offsets**2).sum(dim=1) / (2 * sigma**2)
).sum()
# We add an epsilon to the normalization constant to ensure the gradient will travel
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in
# "Differentiable Surface Rendering via Non-Differentiable Sampling", Cole et al.
return (1 + epsilon) / normalization_constant
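# Worked example: with the nine offsets and the default sigma=0.5, the squared
# offset norms are 0 (once), 1 (four times), and 2 (four times), so
#   normalization = (1 + 0.05) / (exp(0) + 4 * exp(-2) + 4 * exp(-4))
#                 = 1.05 / 1.6146 ~= 0.6503,
# matching TestGetSplatKernelNormalization in tests/test_splatter_blend.py.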
def _compute_occlusion_layers(
q_depth: torch.Tensor,
) -> torch.Tensor:
"""
For each splatting pixel, decide whether it splats from a background, surface, or
foreground depth relative to the splatted pixel. See unit tests in
test_splatter_blend for some enlightening examples.
Args:
q_depth: (N, H, W, K) tensor of z-values of the splatted pixels.
Returns:
occlusion_layers: (N, H, W, 9) long tensor. Each of the 9 values corresponds to
one of the nine splatting directions ([-1, -1], [-1, 0], ..., [1,
1]). The value at nhwd (where d is the splatting direction) is 0 if
the splat in direction d is on the same surface level as the pixel at
hw. The value is negative if the splat is in the background (occluded
by another splat above it that is at the same surface level as the
pixel splatted on), and the value is positive if the splat is in the
foreground.
"""
N, H, W, K = q_depth.shape
# q are the "center pixels" and p the pixels splatting onto them. Use `unfold` to
# create `p_depth`, a tensor with 9 layers, each of which corresponds to the
# depth of a neighbor of q in one of the 9 directions. For example, p_depth[nk0hw]
# is the depth of the pixel splatting onto pixel nhwk from the [-1, -1] direction,
# and p_depth[nk4hw] the depth of q (self-splatting onto itself).
# More concretely, imagine the pixel depths in a 2x2 image's k-th layer are
# .1 .2
# .3 .4
# Then (remembering that we pad with zeros when a pixel has fewer than 9 neighbors):
#
# p_depth[n, k, :, 0, 0] = [ 0 0 0 0 .1 .2 0 .3 .4] - neighbors of .1
# p_depth[n, k, :, 0, 1] = [ 0 0 0 .1 .2 0 .3 .4 0] - neighbors of .2
# p_depth[n, k, :, 1, 0] = [ 0 .1 .2 0 .3 .4 0 0 0] - neighbors of .3
# p_depth[n, k, :, 1, 1] = [.1 .2 0 .3 .4 0 0 0 0] - neighbors of .4
q_depth = q_depth.permute(0, 3, 1, 2) # (N, K, H, W)
p_depth = F.unfold(q_depth, kernel_size=3, padding=1) # (N, 3^2 * K, H * W)
q_depth = q_depth.view(N, K, 1, H, W)
p_depth = p_depth.view(N, K, 9, H, W)
# Take the center pixel q's top rasterization layer. This is the "surface layer"
# that we're splatting on. For each of the nine splatting directions p, find which
# of the K splatting rasterization layers is closest in depth to the surface
# splatted layer.
qtop_to_p_zdist = torch.abs(p_depth - q_depth[:, 0:1]) # (N, K, 9, H, W)
qtop_to_p_closest_zdist, qtop_to_p_closest_id = qtop_to_p_zdist.min(dim=1)
# For each of the nine splatting directions p, take the top of the K rasterization
# layers. Check which of the K q-layers (that the given direction is splatting on)
# is closest in depth to the top splatting layer.
ptop_to_q_zdist = torch.abs(p_depth[:, 0:1] - q_depth) # (N, K, 9, H, W)
ptop_to_q_closest_zdist, ptop_to_q_closest_id = ptop_to_q_zdist.min(dim=1)
# Decide whether each p is on the same level, below, or above the q it is splatting
# on. See Fig. 4 in [0] for an illustration. Briefly: say we're interested in pixel
# p_{h, w} = [10, 32] splatting onto its neighbor q_{h, w} = [11, 33]. The splat is
# coming from direction [-1, -1], which has index 0 in our enumeration of splatting
# directions. Hence, we are interested in
#
# P = p_depth[n, :, d=0, 11, 33] - a vector of K depth values, and
# Q = q_depth.squeeze()[n, :, 11, 33] - a vector of K depth values.
#
# If Q[0] is closest, say, to P[2], then we assume the 0th surface layer of Q is
# the same surface as P[2] that's splatting onto it, and P[:2] are foreground splats
# and P[3:] are background splats.
#
# If instead say Q[2] is closest to P[0], then all the splats are background splats,
# because the top splatting layer is the same surface as a non-top splatted layer.
#
# Finally, if Q[0] is closest to P[0], then the top-level P is splatting onto top-
# level Q, and P[1:] are all background splats.
occlusion_offsets = torch.where( # noqa
ptop_to_q_closest_zdist < qtop_to_p_closest_zdist,
-ptop_to_q_closest_id,
qtop_to_p_closest_id,
) # (N, 9, H, W)
occlusion_layers = occlusion_offsets.permute((0, 2, 3, 1)) # (N, H, W, 9)
return occlusion_layers
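# Sanity check (runnable) of the unfold-based neighbor extraction, using the 2x2
# example from the comment above:
#
#   >>> import torch
#   >>> import torch.nn.functional as F
#   >>> d = torch.tensor([[0.1, 0.2], [0.3, 0.4]]).view(1, 1, 2, 2)
#   >>> F.unfold(d, kernel_size=3, padding=1).view(1, 1, 9, 2, 2)[0, 0, :, 0, 0]
#   tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.1000, 0.2000, 0.0000, 0.3000,
#           0.4000])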
def _compute_splatting_colors_and_weights(
pixel_coords_screen: torch.Tensor,
colors: torch.Tensor,
sigma: float,
offsets: torch.Tensor,
) -> torch.Tensor:
"""
For each center pixel q, compute the splatting weights of its surrounding nine
splatting pixels p, as well as their splatting colors (which are just their
colors reweighted by the splatting weights).
Args:
pixel_coords_screen: (N, H, W, K, 3) tensor of pixel screen coords; only the
first two (xy) channels are used.
colors: (N, H, W, K, 4) RGBA tensor of pixel colors.
sigma: standard deviation of the Gaussian splatting kernel.
offsets: (9, 2) tensor computed by _precompute, indicating the nine
splatting directions ([-1, -1], ..., [1, 1]).
Returns:
splat_colors_and_weights: (N, H, W, K, 9, 5) tensor.
splat_colors_and_weights[..., :4] corresponds to the splatting colors, and
splat_colors_and_weights[..., 4:5] to the splatting weights. The "9"
dimension corresponds to the nine splatting directions.
"""
N, H, W, K, C = colors.shape
splat_kernel_normalization = _get_splat_kernel_normalization(offsets, sigma)
# Distance of each barycentric-interpolated pixel position from its "ideal"
# pixel-center location. pixel_coords_screen are in screen coordinates, and
# should be at the "ideal" locations on the forward pass -- e.g.
# pixel_coords_screen[n, 24, 31, k] = [24.5, 31.5]. For this reason, q_to_px_center
# should equal torch.zeros during the forward pass. On the backward pass, these
# coordinates will be adjusted and non-zero, allowing the gradients to flow back
# to the mesh vertex coordinates.
q_to_px_center = (
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
).view((N, H, W, K, 1, 2))
dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9)
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
alpha = colors[..., 3:4]
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
5
) # (N, H, W, K, 9, 1)
# splat_colors[n, h, w, direction, :] contains the splatting color (weighted by the
# splatting weight) that pixel h, w will splat in one of the nine possible
# directions (e.g. nhw0 corresponds to splatting in the [-1, -1] direction, and
# nhw4 is self-splatting).
splat_colors = splat_weights * colors.unsqueeze(4) # (N, H, W, K, 9, 4)
return torch.cat([splat_colors, splat_weights], dim=5)
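# Worked example: an exactly-on-grid pixel (q_to_px_center == 0) with alpha=1 and
# sigma=0.5 gets unnormalized splat weights exp(0) = 1 toward itself,
# exp(-1 / (2 * 0.25)) ~= 0.135 toward its four axis-aligned neighbors, and
# exp(-2 / (2 * 0.25)) ~= 0.018 toward its four diagonal neighbors, matching the
# ordering asserted in TestComputeSplattingColorsAndWeights.test_all_equal.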
def _offset_splats(
splat_colors_and_weights: torch.Tensor,
crop_ids_h: torch.Tensor,
crop_ids_w: torch.Tensor,
) -> torch.Tensor:
"""
Pad splatting colors and weights so that tensor locations/coordinates are aligned
with the splatting directions. For example, say we have an example input Red channel
splat_colors_and_weights[n, :, :, k, direction=0, channel=0] equal to
.1 .2 .3
.4 .5 .6
.7 .8 .9
the (h, w) entry indicates that pixel n, h, w, k splats the given color in direction
equal to 0, which corresponds to offsets[0] = (-1, -1). Note that this is the x-y
direction, not h-w. This function pads and crops this array to
0 0 0
.2 .3 0
.5 .6 0
which indicates, for example, that:
* There is no pixel splatting in direction (-1, -1) whose splat lands on pixel
h=w=0.
* There is a pixel splatting in direction (-1, -1) whose splat lands on the pi-
xel h=1, w=0, and that pixel's splatting color is .2.
* There is a pixel splatting in direction (-1, -1) whose splat lands on the pi-
xel h=2, w=1, and that pixel's splatting color is .6.
Args:
*splat_colors_and_weights*: (N, H, W, K, 9, 5) tensor of colors and weights,
where dim=-2 corresponds to the splatting directions/offsets.
*crop_ids_h*: (N, H, W+2, K, 9, 5) precomputed tensor used for padding within
torch.gather. See _precompute for more info.
*crop_ids_w*: (N, H, W, K, 9, 5) precomputed tensor used for padding within
torch.gather. See _precompute for more info.
Returns:
*splat_colors_and_weights*: (N, H, W, K, 9, 5) tensor.
"""
N, H, W, K, _, _ = splat_colors_and_weights.shape
# Transform splat_colors_and_weights such that each of the 9 layers (corresponding
# to the 9 splat offsets) is zero-padded and shifted in the appropriate direction.
# E.g. layer 0 corresponds to the (-1, -1) offset, so it gets a row of zeros at the
# top and a column of zeros on the left, with one row clipped at the bottom and one
# column clipped on the right; layer 4 corresponds to offset (0, 0) and remains
# unchanged.
splat_colors_and_weights = F.pad(
splat_colors_and_weights, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
) # (N, H+2, W+2, K, 9, 5)
# (N, H+2, W+2, K, 9, 5) -> (N, H, W+2, K, 9, 5)
splat_colors_and_weights = torch.gather(
splat_colors_and_weights, dim=1, index=crop_ids_h
)
# (N, H, W+2, K, 9, 5) -> (N, H, W, K, 9, 5)
splat_colors_and_weights = torch.gather(
splat_colors_and_weights, dim=2, index=crop_ids_w
)
return splat_colors_and_weights
def _compute_splatted_colors_and_weights(
occlusion_layers: torch.Tensor, # (N, H, W, 9)
splat_colors_and_weights: torch.Tensor, # (N, H, W, K, 9, 5)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Accumulate splatted colors in background, surface and foreground occlusion buffers.
Args:
occlusion_layers: (N, H, W, 9) tensor. See _compute_occlusion_layers.
splat_colors_and_weights: (N, H, W, K, 9, 5) tensor. See _offset_splats.
Returns:
splatted_colors: (N, H, W, 4, 3) tensor. The last dimension corresponds to
foreground, surface, and background splat colors.
splatted_weights: (N, H, W, 1, 3) tensor. The last dimension corresponds to
foreground, surface, and background splat weights and is used for normalization.
"""
N, H, W, K, _, _ = splat_colors_and_weights.shape
# Create an occlusion mask, with the last dimension of length 3, corresponding to
# foreground/surface/background splatting. E.g. occlusion_layer_mask[n,h,w,k,d,0] is
# 1 if the pixel at hw is splatted from direction d such that the k-th splatting
# layer lies in front of the splatted pixel q's surface (in the foreground);
# otherwise, the value is 0. occlusion_layer_mask[n,h,w,k,d,1] is 1 if the k-th
# splatting layer is at the same surface level as the splatted pixel q, and
# occlusion_layer_mask[n,h,w,k,d,2] is 1 only if it is behind the surface (in the
# background).
layer_ids = torch.arange(K, device=splat_colors_and_weights.device).view(
1, 1, 1, K, 1
)
occlusion_layers = occlusion_layers.view(N, H, W, 1, 9)
occlusion_layer_mask = torch.stack(
[
occlusion_layers > layer_ids, # (N, H, W, K, 9)
occlusion_layers == layer_ids, # (N, H, W, K, 9)
occlusion_layers < layer_ids, # (N, H, W, K, 9)
],
dim=5,
).float() # (N, H, W, K, 9, 3)
# (N * H * W, 5, 9 * K) x (N * H * W, 9 * K, 3) -> (N * H * W, 5, 3)
splatted_colors_and_weights = torch.bmm(
splat_colors_and_weights.permute(0, 1, 2, 5, 3, 4).reshape(
(N * H * W, 5, K * 9)
),
occlusion_layer_mask.reshape((N * H * W, K * 9, 3)),
).reshape((N, H, W, 5, 3))
return (
splatted_colors_and_weights[..., :4, :],
splatted_colors_and_weights[..., 4:5, :],
)
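# For readability: the bmm above is equivalent to the einsum below, contracting
# over the K intersection layers and the 9 directions (bmm is kept in the code;
# this is an illustration only):
#
#   splatted = torch.einsum(
#       "nhwkdc,nhwkdl->nhwcl", splat_colors_and_weights, occlusion_layer_mask
#   )  # (N, H, W, 5, 3)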
def _normalize_and_compose_all_layers(
background_color: torch.Tensor,
splatted_colors_per_occlusion_layer: torch.Tensor,
splatted_weights_per_occlusion_layer: torch.Tensor,
) -> torch.Tensor:
"""
Normalize each bg/surface/fg buffer by its weight, and compose.
Args:
background_color: (3) RGB tensor.
splatted_colors_per_occlusion_layer: (N, H, W, 4, 3) RGBA tensor; the last
dimension corresponds to foreground, surface, and background splatting.
splatted_weights_per_occlusion_layer: (N, H, W, 1, 3) weight tensor.
Returns:
output_colors: (N, H, W, 4) RGBA tensor.
"""
device = splatted_colors_per_occlusion_layer.device
# Normalize each of bg/surface/fg splat layers separately.
normalization_scales = 1.0 / (
torch.maximum(
splatted_weights_per_occlusion_layer,
torch.tensor([1.0], device=device),
)
) # (N, H, W, 1, 3)
normalized_splatted_colors = (
splatted_colors_per_occlusion_layer * normalization_scales
) # (N, H, W, 4, 3)
# Use alpha-compositing to compose the splat layers.
output_colors = torch.cat(
[background_color, torch.tensor([0.0], device=device)]
) # (4), will broadcast to (N, H, W, 4) below.
for occlusion_layer_id in (-1, -2, -3):
# Over-compose the bg, surface, and fg occlusion layers. Note that we already
# multiplied each pixel's RGBA by its own alpha as part of self-splatting in
# _compute_splatting_colors_and_weights, so we don't re-multiply by alpha here.
alpha = normalized_splatted_colors[..., 3:4, occlusion_layer_id] # (N, H, W, 1)
output_colors = (
normalized_splatted_colors[..., occlusion_layer_id]
+ (1.0 - alpha) * output_colors
)
return output_colors
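# The loop above is the standard "over" operator on premultiplied-alpha colors,
# compositing background first, then surface, then foreground. A minimal
# standalone sketch with made-up values:
#
#   >>> import torch
#   >>> fg = torch.tensor([0.2, 0.0, 0.0, 0.4])  # premultiplied RGBA
#   >>> bg = torch.tensor([0.0, 0.5, 0.0, 1.0])
#   >>> fg + (1.0 - fg[3]) * bg
#   tensor([0.2000, 0.3000, 0.0000, 1.0000])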
class SplatterBlender(torch.nn.Module):
def __init__(
self,
input_shape: Tuple[int, int, int, int],
device,
):
"""
A splatting blender. See `forward` docs for details of the splatting mechanism.
Args:
input_shape: Tuple (N, H, W, K) indicating the batch size, image height,
image width, and number of rasterized layers. Used to precompute
constant tensors that do not change as long as this tuple is unchanged.
device: Device on which to store the precomputed tensors.
"""
super().__init__()
self.crop_ids_h, self.crop_ids_w, self.offsets = _precompute(
input_shape, device
)
def forward(
self,
colors: torch.Tensor,
pixel_coords_cameras: torch.Tensor,
cameras: FoVPerspectiveCameras,
background_mask: torch.Tensor,
blend_params: BlendParams,
) -> torch.Tensor:
"""
RGB blending using splatting, as proposed in [0].
Args:
colors: (N, H, W, K, 3) tensor of RGB colors at each h, w pixel location for
K intersection layers.
pixel_coords_cameras: (N, H, W, K, 3) tensor of pixel coordinates in the
camera frame of reference. It is *crucial* that these are computed by
interpolating triangle vertex positions using barycentric coordinates --
this allows gradients to travel through pixel_coords_camera back to the
vertex positions.
cameras: Cameras object used to project pixel_coords_cameras screen coords.
background_mask: (N, H, W, K) boolean tensor, True for bg pixels. A pixel
is considered "background" if no mesh triangle projects to it. This is
typically computed by the rasterizer.
blend_params: BlendParams, from which we use sigma (the splatting kernel's
standard deviation) and background_color.
Returns:
output_colors: (N, H, W, 4) tensor of RGBA values. The alpha layer is set to
fully transparent in the background.
[0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
Sampling".
"""
# Our implementation has 6 stages. In the description below, we will call each
# pixel q and the 9 surrounding splatting pixels (including itself) p.
# 1. Use barycentrics to compute the position of each pixel in screen
# coordinates. These should exactly correspond to pixel centers during the
# forward pass, but can be shifted on the backward pass. This step allows
# gradients to travel to vertex coordinates, even if the rasterizer is
# non-differentiable.
# 2a. For each center pixel q, take each splatting p and decide whether it
# is on the same surface level as q, or in the background or foreground.
# 2b. For each center pixel q, compute the splatting weight of surrounding
# pixels p, and their splatting colors (which are just the original colors
# weighted by the splatting weights).
# 3. As a vectorization technicality, offset the tensors corresponding to
# the splatting p values in the nine directions, by padding each of nine
# splatting layers on the bottom/top, left/right.
# 4. Do the actual splatting, by accumulating the splatting colors of the
# surrounding p's for each pixel q. The weights get accumulated separately for
# p's that got assigned to the background/surface/foreground in Step 2a.
# 5. Normalize each of the splatted bg/surface/fg colors for each q, and
# compose the resulting color maps.
#
# Note that it is crucial that in Step 1 we compute the pixel coordinates by in-
# terpolating triangle vertices using barycentric coords from the rasterizer. In
# our case, these pixel_coords_camera are computed by the shader and passed to
# this function to avoid re-computation.
pixel_coords_screen, colors = _prepare_pixels_and_colors(
pixel_coords_cameras, colors, cameras, background_mask
) # (N, H, W, K, 3) and (N, H, W, K, 4)
occlusion_layers = _compute_occlusion_layers(
pixel_coords_screen[..., 2:3].squeeze(dim=-1)
) # (N, H, W, 9)
splat_colors_and_weights = _compute_splatting_colors_and_weights(
pixel_coords_screen,
colors,
blend_params.sigma,
self.offsets,
) # (N, H, W, K, 9, 5)
splat_colors_and_weights = _offset_splats(
splat_colors_and_weights,
self.crop_ids_h,
self.crop_ids_w,
) # (N, H, W, K, 9, 5)
(
splatted_colors_per_occlusion_layer,
splatted_weights_per_occlusion_layer,
) = _compute_splatted_colors_and_weights(
occlusion_layers, splat_colors_and_weights
) # (N, H, W, 4, 3) and (N, H, W, 1, 3)
output_colors = _normalize_and_compose_all_layers(
_get_background_color(blend_params, colors.device),
splatted_colors_per_occlusion_layer,
splatted_weights_per_occlusion_layer,
) # (N, H, W, 4)
return output_colors
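# Standalone usage sketch (shapes only; in practice colors and
# pixel_coords_cameras come from a shader, as in SplatterPhongShader):
#
#   >>> import torch
#   >>> from pytorch3d.renderer import BlendParams
#   >>> from pytorch3d.renderer.cameras import FoVPerspectiveCameras
#   >>> N, H, W, K = 1, 8, 8, 3
#   >>> blender = SplatterBlender((N, H, W, K), torch.device("cpu"))
#   >>> images = blender(
#   ...     torch.rand(N, H, W, K, 3),  # colors
#   ...     torch.randn(N, H, W, K, 3),  # pixel_coords_cameras
#   ...     FoVPerspectiveCameras(),
#   ...     torch.zeros(N, H, W, K, dtype=torch.bool),  # background_mask
#   ...     BlendParams(sigma=0.5),
#   ... )
#   >>> images.shape
#   torch.Size([1, 8, 8, 4])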

tests/test_render_meshes.py

@@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import (
HardPhongShader,
SoftPhongShader,
SoftSilhouetteShader,
SplatterPhongShader,
TexturedSoftPhongShader,
)
from pytorch3d.structures.meshes import (
@@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(SoftPhongShader, "phong", "soft_phong"),
ShaderTest(SplatterPhongShader, "phong", "splatter_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"),
]

tests/test_splatter_blend.py

@@ -0,0 +1,627 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
from pytorch3d.renderer.splatter_blend import (
_compute_occlusion_layers,
_compute_splatted_colors_and_weights,
_compute_splatting_colors_and_weights,
_get_splat_kernel_normalization,
_normalize_and_compose_all_layers,
_offset_splats,
_precompute,
_prepare_pixels_and_colors,
)
offsets = torch.tensor(
[
[-1, -1],
[-1, 0],
[-1, 1],
[0, -1],
[0, 0],
[0, 1],
[1, -1],
[1, 0],
[1, 1],
],
device=torch.device("cpu"),
)
def compute_splatting_colors_and_weights_naive(pixel_coords_screen, colors, sigma):
normalizer = float(_get_splat_kernel_normalization(offsets))
N, H, W, K, _ = colors.shape
splat_weights_and_colors = torch.zeros((N, H, W, K, 9, 5))
for n in range(N):
for h in range(H):
for w in range(W):
for k in range(K):
q_xy = pixel_coords_screen[n, h, w, k]
q_to_px_center = torch.floor(q_xy) - q_xy + 0.5
color = colors[n, h, w, k]
alpha = colors[n, h, w, k, 3:4]
for d in range(9):
dist_p_q = torch.sum((q_to_px_center + offsets[d]) ** 2)
splat_weight = (
alpha * torch.exp(-dist_p_q / (2 * sigma**2)) * normalizer
)
splat_color = splat_weight * color
splat_weights_and_colors[n, h, w, k, d, :4] = splat_color
splat_weights_and_colors[n, h, w, k, d, 4:5] = splat_weight
return splat_weights_and_colors
class TestPrecompute(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.results_cpu = _precompute((2, 3, 4, 5), torch.device("cpu"))
self.results1_cpu = _precompute((1, 1, 1, 1), torch.device("cpu"))
def test_offsets(self):
self.assertClose(self.results_cpu[2].shape, offsets.shape, atol=0)
self.assertClose(self.results_cpu[2], offsets, atol=0)
# Offsets should be independent of input_size.
self.assertClose(self.results_cpu[2], self.results1_cpu[2], atol=0)
def test_crops_h(self):
target_crops_h1 = torch.tensor(
[
# channels being offset:
# R G B A W(eight)
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
]
* 3, # 3 because we're aiming at (N, H, W+2, K, 9, 5) with W=1.
device=torch.device("cpu"),
).reshape(1, 1, 3, 1, 9, 5)
self.assertClose(self.results1_cpu[0], target_crops_h1, atol=0)
target_crops_h_base = target_crops_h1[0, 0, 0]
target_crops_h = torch.cat(
[target_crops_h_base, target_crops_h_base + 1, target_crops_h_base + 2],
dim=0,
)
# Check that we have the right shape, and (after broadcasting) it has the right
# values. These should be repeated (tiled) for each n and k.
self.assertClose(
self.results_cpu[0].shape, torch.tensor([2, 3, 6, 5, 9, 5]), atol=0
)
for n in range(2):
for w in range(6):
for k in range(5):
self.assertClose(
self.results_cpu[0][n, :, w, k],
target_crops_h,
)
def test_crops_w(self):
target_crops_w1 = torch.tensor(
[
# channels being offset:
# R G B A W(eight)
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[2, 2, 2, 2, 2],
[2, 2, 2, 2, 2],
],
device=torch.device("cpu"),
).reshape(1, 1, 1, 1, 9, 5)
self.assertClose(self.results1_cpu[1], target_crops_w1)
target_crops_w_base = target_crops_w1[0, 0, 0]
target_crops_w = torch.cat(
[
target_crops_w_base,
target_crops_w_base + 1,
target_crops_w_base + 2,
target_crops_w_base + 3,
],
dim=0,
) # Each w value needs an increment.
# Check that we have the right shape, and (after broadcasting) it has the right
# values. These should be repeated (tiled) for each n and k.
self.assertClose(self.results_cpu[1].shape, torch.tensor([2, 3, 4, 5, 9, 5]))
for n in range(2):
for h in range(3):
for k in range(5):
self.assertClose(
self.results_cpu[1][n, h, :, k],
target_crops_w,
atol=0,
)
class TestPreparePixelsAndColors(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.device = torch.device("cpu")
N, H, W, K = 2, 3, 4, 5
self.pixel_coords_cameras = torch.randn(
(N, H, W, K, 3), device=self.device, requires_grad=True
)
self.colors_before = torch.rand((N, H, W, K, 3), device=self.device)
self.cameras = FoVPerspectiveCameras(device=self.device)
self.background_mask = torch.rand((N, H, W, K), device=self.device) < 0.5
self.pixel_coords_screen, self.colors_after = _prepare_pixels_and_colors(
self.pixel_coords_cameras,
self.colors_before,
self.cameras,
self.background_mask,
)
def test_background_z(self):
self.assertTrue(
torch.all(self.pixel_coords_screen[..., 2][self.background_mask] == 1.0)
)
def test_background_alpha(self):
self.assertTrue(
torch.all(self.colors_after[..., 3][self.background_mask] == 0.0)
)
class TestGetSplatKernelNormalization(TestCaseMixin, unittest.TestCase):
def test_splat_kernel_normalization(self):
self.assertAlmostEqual(
float(_get_splat_kernel_normalization(offsets)), 0.6503, places=3
)
self.assertAlmostEqual(
float(_get_splat_kernel_normalization(offsets, 0.01)), 1.05, places=3
)
with self.assertRaisesRegex(ValueError, "Only positive standard deviations"):
_get_splat_kernel_normalization(offsets, 0)
class TestComputeOcclusionLayers(TestCaseMixin, unittest.TestCase):
def test_single_layer(self):
# If there's only one layer, all splats must be on the surface level.
N, H, W, K = 2, 3, 4, 1
q_depth = torch.rand(N, H, W, K)
occlusion_layers = _compute_occlusion_layers(q_depth)
self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)
def test_all_equal(self):
# If all q-vals are equal, then all splats must be on the surface level.
N, H, W, K = 2, 3, 4, 5
q_depth = torch.ones((N, H, W, K)) * 0.1234
occlusion_layers = _compute_occlusion_layers(q_depth)
self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)
def test_mid_to_top_level_splatting(self):
# Check that occlusion buffers get accumulated as expected when the splatting
# and splatted pixels are co-surface on different intersection layers.
# This test makes the most sense alongside Fig. 4 of "Differentiable
# Surface Rendering via Non-differentiable Sampling" by Cole et al.
for direction, offset in enumerate(offsets):
if direction == 4:
continue # Skip self-splatting which is always co-surface.
depths = torch.zeros(1, 3, 3, 3)
# This is our q, the pixel splatted onto, in the center of the image.
depths[0, 1, 1] = torch.tensor([0.71, 0.8, 1.0])
# This is our p, the splatting pixel.
depths[0, offset[0] + 1, offset[1] + 1] = torch.tensor([0.5, 0.7, 0.9])
occlusion_layers = _compute_occlusion_layers(depths)
# Check that we computed that it is the middle layer of p that is co-
# surface with q. (1, 1) is the id of q in the depth array, and offset_id
# is the id of p's direction w.r.t. q.
psurfaceid_onto_q = occlusion_layers[0, 1, 1, direction]
self.assertEqual(int(psurfaceid_onto_q), 1)
# Conversely, if we swap p and q, we have a top-level splatting onto
# mid-level. (offset[0] + 1, offset[1] + 1) is the location of p, and
# 8 - direction is the index of q's direction w.r.t. p (e.g. if p is
# [-1, -1] w.r.t. q, then q is [1, 1] w.r.t. p; we use the indices of
# these two directions in the offsets array).
qsurfaceid_onto_p = occlusion_layers[
0, offset[0] + 1, offset[1] + 1, 8 - direction
]
self.assertEqual(int(qsurfaceid_onto_p), -1)
class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.N, self.H, self.W, self.K = 2, 3, 4, 5
self.pixel_coords_screen = (
torch.tile(
torch.stack(
torch.meshgrid(
torch.arange(self.H), torch.arange(self.W), indexing="ij"
),
dim=-1,
).reshape(1, self.H, self.W, 1, 2),
(self.N, 1, 1, self.K, 1),
).float()
+ 0.5
)
self.colors = torch.ones((self.N, self.H, self.W, self.K, 4))
def test_all_equal(self):
# If all colors are equal and on a regular grid, all weights and reweighted
# colors should be equal given a specific splatting direction.
splatting_colors_and_weights = _compute_splatting_colors_and_weights(
self.pixel_coords_screen, self.colors * 0.2345, sigma=0.5, offsets=offsets
)
# Splatting directly to the top/bottom/left/right should have the same strength.
non_diag_splats = splatting_colors_and_weights[
:, :, :, :, torch.tensor([1, 3, 5, 7])
]
# Same for diagonal splats.
diag_splats = splatting_colors_and_weights[
:, :, :, :, torch.tensor([0, 2, 6, 8])
]
# And for self-splats.
self_splats = splatting_colors_and_weights[:, :, :, :, torch.tensor([4])]
for splats in non_diag_splats, diag_splats, self_splats:
# Colors should be equal.
self.assertTrue(torch.all(splats[..., :4] == splats[0, 0, 0, 0, 0, 0]))
# Weights should be equal.
self.assertTrue(torch.all(splats[..., 4] == splats[0, 0, 0, 0, 0, 4]))
# Non-diagonal weights should be greater than diagonal weights.
self.assertGreater(
non_diag_splats[0, 0, 0, 0, 0, 0], diag_splats[0, 0, 0, 0, 0, 0]
)
# Self-splats should be strongest of all.
self.assertGreater(
self_splats[0, 0, 0, 0, 0, 0], non_diag_splats[0, 0, 0, 0, 0, 0]
)
# Splatting colors should be reweighted proportionally to their splat weights.
diag_self_color_ratio = (
diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
)
diag_self_weight_ratio = (
diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
)
self.assertEqual(diag_self_color_ratio, diag_self_weight_ratio)
non_diag_self_color_ratio = (
non_diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
)
non_diag_self_weight_ratio = (
non_diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
)
self.assertEqual(non_diag_self_color_ratio, non_diag_self_weight_ratio)
def test_zero_alpha_zero_weight(self):
# Pixels with zero alpha do no splatting, but should still be splatted on.
colors = self.colors.clone()
colors[0, 1, 1, 0, 3] = 0.0
splatting_colors_and_weights = _compute_splatting_colors_and_weights(
self.pixel_coords_screen, colors, sigma=0.5, offsets=offsets
)
# The transparent pixel should do no splatting.
self.assertTrue(torch.all(splatting_colors_and_weights[0, 1, 1, 0] == 0.0))
# Splatting *onto* the transparent pixel should be unaffected.
reference_weights_colors = splatting_colors_and_weights[0, 1, 1, 1]
for direction, offset in enumerate(offsets):
if direction == 4:
continue # Ignore self-splats
# We invert the direction to get the right (h, w, d) coordinate of each
# pixel splatting *onto* the pixel with zero alpha.
self.assertClose(
splatting_colors_and_weights[
0, 1 + offset[0], 1 + offset[1], 0, 8 - direction
],
reference_weights_colors[8 - direction],
atol=0.001,
)
def test_random_inputs(self):
pixel_coords_screen = (
self.pixel_coords_screen
+ torch.randn((self.N, self.H, self.W, self.K, 2)) * 0.1
)
colors = torch.rand((self.N, self.H, self.W, self.K, 4))
splatting_colors_and_weights = _compute_splatting_colors_and_weights(
pixel_coords_screen, colors, sigma=0.5, offsets=offsets
)
naive_colors_and_weights = compute_splatting_colors_and_weights_naive(
pixel_coords_screen, colors, sigma=0.5
)
self.assertClose(
splatting_colors_and_weights, naive_colors_and_weights, atol=0.01
)
class TestOffsetSplats(TestCaseMixin, unittest.TestCase):
def test_offset(self):
device = torch.device("cuda:0")
N, H, W, K = 2, 3, 4, 5
colors_and_weights = torch.rand((N, H, W, K, 9, 5), device=device)
crop_ids_h, crop_ids_w, _ = _precompute((N, H, W, K), device=device)
offset_colors_and_weights = _offset_splats(
colors_and_weights, crop_ids_h, crop_ids_w
)
# Check each splatting direction individually, for clarity.
# offset_x, offset_y = (-1, -1)
direction = 0
self.assertClose(
offset_colors_and_weights[:, 1:, 1:, :, direction],
colors_and_weights[:, :-1, :-1, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
)
# offset_x, offset_y = (-1, 0)
direction = 1
self.assertClose(
offset_colors_and_weights[:, :, 1:, :, direction],
colors_and_weights[:, :, :-1, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
)
# offset_x, offset_y = (-1, 1)
direction = 2
self.assertClose(
offset_colors_and_weights[:, :-1, 1:, :, direction],
colors_and_weights[:, 1:, :-1, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
)
# offset_x, offset_y = (0, -1)
direction = 3
self.assertClose(
offset_colors_and_weights[:, 1:, :, :, direction],
colors_and_weights[:, :-1, :, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
)
# self-splat
direction = 4
self.assertClose(
offset_colors_and_weights[..., direction, :],
colors_and_weights[..., direction, :],
atol=0.001,
)
# offset_x, offset_y = (0, 1)
direction = 5
self.assertClose(
offset_colors_and_weights[:, :-1, :, :, direction],
colors_and_weights[:, 1:, :, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
)
# offset_x, offset_y = (1, -1)
direction = 6
self.assertClose(
offset_colors_and_weights[:, 1:, :-1, :, direction],
colors_and_weights[:, :-1, 1:, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
)
# offset_x, offset_y = (1, 0)
direction = 7
self.assertClose(
offset_colors_and_weights[:, :, :-1, :, direction],
colors_and_weights[:, :, 1:, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
)
# offset_x, offset_y = (1, 1)
direction = 8
self.assertClose(
offset_colors_and_weights[:, :-1, :-1, :, direction],
colors_and_weights[:, 1:, 1:, :, direction],
atol=0.001,
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
)
self.assertTrue(
torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
)
class TestComputeSplattedColorsAndWeights(TestCaseMixin, unittest.TestCase):
def test_accumulation_background(self):
# Set occlusion_layers to all -1, so all splats are background splats.
splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
occlusion_layers = torch.zeros((1, 1, 1, 9)) - 1
splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
occlusion_layers, splat_colors_and_weights
)
# Foreground splats (there are none).
self.assertClose(
splatted_colors[0, 0, 0, :, 0],
torch.zeros((4)),
atol=0.001,
)
# Surface splats (there are none).
self.assertClose(
splatted_colors[0, 0, 0, :, 1],
torch.zeros((4)),
atol=0.001,
)
# Background splats.
self.assertClose(
splatted_colors[0, 0, 0, :, 2],
splat_colors_and_weights[0, 0, 0, :, :, :4].sum(dim=0).sum(dim=0),
atol=0.001,
)
def test_accumulation_middle(self):
# Set occlusion_layers to all 0, so top splats are co-surface with splatted
# pixels. Thus, the top splatting layer should be accumulated to surface, and
# all other layers to background.
splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
occlusion_layers = torch.zeros((1, 1, 1, 9))
splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
occlusion_layers, splat_colors_and_weights
)
# Foreground splats (there are none).
self.assertClose(
splatted_colors[0, 0, 0, :, 0],
torch.zeros((4)),
atol=0.001,
)
# Surface splats
self.assertClose(
splatted_colors[0, 0, 0, :, 1],
splat_colors_and_weights[0, 0, 0, 0, :, :4].sum(dim=0),
atol=0.001,
)
# Background splats
self.assertClose(
splatted_colors[0, 0, 0, :, 2],
splat_colors_and_weights[0, 0, 0, 1:, :, :4].sum(dim=0).sum(dim=0),
atol=0.001,
)
def test_accumulation_foreground(self):
# Set occlusion_layers to all 1. Then the top splatter is a foreground
# splatter, mid splatter is surface, and bottom splatter is background.
splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
occlusion_layers = torch.zeros((1, 1, 1, 9)) + 1
splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
occlusion_layers, splat_colors_and_weights
)
# Foreground splats
self.assertClose(
splatted_colors[0, 0, 0, :, 0],
splat_colors_and_weights[0, 0, 0, 0:1, :, :4].sum(dim=0).sum(dim=0),
atol=0.001,
)
# Surface splats
self.assertClose(
splatted_colors[0, 0, 0, :, 1],
splat_colors_and_weights[0, 0, 0, 1:2, :, :4].sum(dim=0).sum(dim=0),
atol=0.001,
)
# Background splats
self.assertClose(
splatted_colors[0, 0, 0, :, 2],
splat_colors_and_weights[0, 0, 0, 2:3, :, :4].sum(dim=0).sum(dim=0),
atol=0.001,
)
class TestNormalizeAndComposeAllLayers(TestCaseMixin, unittest.TestCase):
def test_background_color(self):
# Background should always have alpha=0, and the chosen RGB.
N, H, W = 2, 3, 4
# Make a mask with background in the zeroth column of the first image.
bg_mask = torch.zeros([N, H, W, 1, 1])
bg_mask[0, :, 0] = 1
bg_color = torch.tensor([0.2, 0.3, 0.4])
color_layers = torch.rand((N, H, W, 4, 3)) * (1 - bg_mask)
color_weights = torch.rand((N, H, W, 1, 3)) * (1 - bg_mask)
colors = _normalize_and_compose_all_layers(
bg_color, color_layers, color_weights
)
# Background RGB should be .2, .3, .4, and alpha should be 0.
self.assertClose(
torch.masked_select(colors, bg_mask.bool()[..., 0]),
torch.tensor([0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0.0]),
atol=0.001,
)
def test_compositing_opaque(self):
# When all colors are opaque, only the foreground layer should be visible.
N, H, W = 2, 3, 4
color_layers = torch.rand((N, H, W, 4, 3))
color_layers[..., 3, :] = 1.0
color_weights = torch.ones((N, H, W, 1, 3))
out_colors = _normalize_and_compose_all_layers(
torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
)
self.assertClose(out_colors, color_layers[..., 0], atol=0.001)
def test_compositing_transparencies(self):
# When foreground layer is transparent and surface and bg are semi-transparent,
# we should return a mix of the two latter.
N, H, W = 2, 3, 4
color_layers = torch.rand((N, H, W, 4, 3))
color_layers[..., 3, 0] = 0.1 # fg
color_layers[..., 3, 1] = 0.2 # surface
color_layers[..., 3, 2] = 0.3 # bg
color_weights = torch.ones((N, H, W, 1, 3))
out_colors = _normalize_and_compose_all_layers(
torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
)
self.assertClose(
out_colors,
color_layers[..., 0]
+ 0.9 * (color_layers[..., 1] + 0.8 * color_layers[..., 2]),
)