From c5a83f46efd53aaf5fbb77a0383cb6f3cdc98545 Mon Sep 17 00:00:00 2001 From: Krzysztof Chalupka Date: Tue, 24 May 2022 21:04:11 -0700 Subject: [PATCH] SplatterBlender Summary: Splatting shader. See code comments for details. Same API as SoftPhongShader. Reviewed By: jcjohnson Differential Revision: D36354301 fbshipit-source-id: 71ee37f7ff6bb9ce028ba42a65741424a427a92d --- pytorch3d/renderer/__init__.py | 2 + pytorch3d/renderer/blending.py | 13 +- pytorch3d/renderer/mesh/__init__.py | 1 + pytorch3d/renderer/mesh/shader.py | 69 ++- pytorch3d/renderer/splatter_blend.py | 553 +++++++++++++++++++++++ tests/test_render_meshes.py | 2 + tests/test_splatter_blend.py | 627 +++++++++++++++++++++++++++ 7 files changed, 1260 insertions(+), 7 deletions(-) create mode 100644 pytorch3d/renderer/splatter_blend.py create mode 100644 tests/test_splatter_blend.py diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index ef3daf0e..4e566f5e 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -57,6 +57,7 @@ from .mesh import ( SoftGouraudShader, SoftPhongShader, SoftSilhouetteShader, + SplatterPhongShader, Textures, TexturesAtlas, TexturesUV, @@ -71,6 +72,7 @@ from .points import ( PulsarPointsRenderer, rasterize_points, ) +from .splatter_blend import SplatterBlender from .utils import ( convert_to_tensors_and_broadcast, ndc_grid_sample, diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index bfdae6c9..69318ca1 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -4,14 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from typing import NamedTuple, Sequence, Union import torch from pytorch3d import _C from pytorch3d.common.datatypes import Device - # Example functions for blending the top K colors per pixel using the outputs # from rasterization. # NOTE: All blending function should return an RGBA image per batch element @@ -22,10 +20,12 @@ class BlendParams(NamedTuple): Data class to store blending params with defaults Members: - sigma (float): Controls the width of the sigmoid function used to - calculate the 2D distance based probability. Determines the - sharpness of the edges of the shape. - Higher => faces have less defined edges. + sigma (float): For SoftmaxPhong, controls the width of the sigmoid + function used to calculate the 2D distance based probability. Determines + the sharpness of the edges of the shape. Higher => faces have less defined + edges. For SplatterPhong, this is the standard deviation of the Gaussian + kernel. Higher => splats have a stronger effect and the rendered image is + more blurry. gamma (float): Controls the scaling of the exponential function used to set the opacity of the color. Higher => faces are more transparent. @@ -36,6 +36,7 @@ class BlendParams(NamedTuple): sigma: float = 1e-4 gamma: float = 1e-4 background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0) + background_alpha: float = 0.0 def _get_background_color( diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index e9e12cdc..46cd791a 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -22,6 +22,7 @@ from .shader import ( # DEPRECATED SoftGouraudShader, SoftPhongShader, SoftSilhouetteShader, + SplatterPhongShader, TexturedSoftPhongShader, ) from .shading import gouraud_shading, phong_shading diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 797efbcc..815dd85e 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -20,9 +20,15 @@ from ..blending import ( ) from ..lighting import PointLights from ..materials import Materials +from ..splatter_blend import SplatterBlender from ..utils import TensorProperties from .rasterizer import Fragments -from .shading import flat_shading, gouraud_shading, phong_shading +from .shading import ( + _phong_shading_with_pixels, + flat_shading, + gouraud_shading, + phong_shading, +) # A Shader should take as input fragments from the output of rasterization @@ -308,3 +314,64 @@ class SoftSilhouetteShader(nn.Module): blend_params = kwargs.get("blend_params", self.blend_params) images = sigmoid_alpha_blend(colors, fragments, blend_params) return images + + +class SplatterPhongShader(ShaderBase): + """ + Per pixel lighting - the lighting model is applied using the interpolated + coordinates and normals for each pixel. The blending function returns the + color aggregated using splats from surrounding pixels (see [0]). + + To use the default values, simply initialize the shader with the desired + device e.g. + + .. code-block:: + + shader = SplatterPhongShader(device=torch.device("cuda:0")) + + Args: + detach_rasterizer: If True, stop gradients from flowing through the rasterizer. + This simulates the pipeline in [0] which uses a non-differentiable OpenGL + rasterizer. + + [0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable + Sampling". + """ + + def __init__(self, **kwargs): + self.splatter_blender = None + super().__init__(**kwargs) + + def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor: + cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of SplatterPhongShader." + raise ValueError(msg) + texels = meshes.sample_textures(fragments) + lights = kwargs.get("lights", self.lights) + materials = kwargs.get("materials", self.materials) + + colors, pixel_coords_cameras = _phong_shading_with_pixels( + meshes=meshes, + fragments=fragments.detach(), + texels=texels, + lights=lights, + cameras=cameras, + materials=materials, + ) + + if not self.splatter_blender: + # Init only once, to avoid re-computing constants. + N, H, W, K, _ = colors.shape + self.splatter_blender = SplatterBlender((N, H, W, K), colors.device) + + images = self.splatter_blender( + colors, + pixel_coords_cameras, + cameras, + fragments.pix_to_face < 0, + self.blend_params, + ) + + return images diff --git a/pytorch3d/renderer/splatter_blend.py b/pytorch3d/renderer/splatter_blend.py new file mode 100644 index 00000000..4bdb0a73 --- /dev/null +++ b/pytorch3d/renderer/splatter_blend.py @@ -0,0 +1,553 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +from typing import Tuple + +import torch +import torch.nn.functional as F +from pytorch3d.common.datatypes import Device +from pytorch3d.renderer import BlendParams +from pytorch3d.renderer.cameras import FoVPerspectiveCameras + +from .blending import _get_background_color + + +def _precompute( + input_shape: Tuple[int, int, int, int], device: Device +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Precompute padding and offset constants that won't change for a given NHWK shape. + + Args: + input_shape: Tuple indicating N (batch size), H, W (image size) and K (number of + intersections) output by the rasterizer. + device: Device to store the tensors on. + + returns: + crop_ids_h: An (N, H, W+2, K, 9, 5) tensor, used during splatting to offset the + p-pixels (splatting pixels) in one of the 9 splatting directions within a + call to torch.gather. See comments and offset_splats for details. + crop_ids_w: An (N, H, W, K, 9, 5) tensor, used similarly to crop_ids_h. + offsets: A (1, 1, 1, 1, 9, 2) tensor (shaped so for broadcasting) containing va- + lues [-1, -1], [-1, 0], [-1, 1], [0, -1], ..., [1, 1] which correspond to + the nine splatting directions. + """ + N, H, W, K = input_shape + + # (N, H, W+2, K, 9, 5) tensor, used to reduce a tensor from (N, H+2, W+2...) to + # (N, H, W+2, ...) in torch.gather. If only torch.gather broadcasted, we wouldn't + # need the tiling. But it doesn't. + crop_ids_h = ( + torch.arange(0, H, device=device).view(1, H, 1, 1, 1, 1) + + torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], device=device).view( + 1, 1, 1, 1, 9, 1 + ) + ).expand(N, H, W + 2, K, 9, 5) + + # (N, H, W, K, 9, 5) tensor, used to reduce a tensor from (N, H, W+2, ...) to + # (N, H, W, ...) in torch.gather. + crop_ids_w = ( + torch.arange(0, W, device=device).view(1, 1, W, 1, 1, 1) + + torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device=device).view( + 1, 1, 1, 1, 9, 1 + ) + ).expand(N, H, W, K, 9, 5) + + offsets = torch.tensor( + list(itertools.product((-1, 0, 1), repeat=2)), + dtype=torch.long, + device=device, + ) + + return crop_ids_h, crop_ids_w, offsets + + +def _prepare_pixels_and_colors( + pixel_coords_cameras: torch.Tensor, + colors: torch.Tensor, + cameras: FoVPerspectiveCameras, + background_mask: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project pixel coords into the un-inverted screen frame of reference, and set + background pixel z-values to 1.0 and alphas to 0.0. + + Args: + pixel_coords_cameras: (N, H, W, K, 3) float tensor. + colors: (N, H, W, K, 3) float tensor. + cameras: PyTorch3D cameras, for now we assume FoVPerspectiveCameras. + background_mask: (N, H, W, K) boolean tensor. + + Returns: + pixel_coords_screen: (N, H, W, K, 3) float tensor. Background pixels have + x=y=z=1.0. + colors: (N, H, W, K, 4). Alpha is set to 1 for foreground pixels and 0 for back- + ground pixels. + """ + + N, H, W, K, C = colors.shape + # pixel_coords_screen will contain invalid values at background + # intersections, and [H+0.5, W+0.5, z] at valid intersections. It is important + # to not flip the xy axes, otherwise the gradients will be inverted when the + # splatter works with a detached rasterizer. + pixel_coords_screen = cameras.transform_points_screen( + pixel_coords_cameras.view([N, -1, 3]), image_size=(H, W), with_xyflip=False + ).reshape(pixel_coords_cameras.shape) + + # Set colors' alpha to 1 and background to 0. + colors = torch.cat( + [colors, torch.ones_like(colors[..., :1])], dim=-1 + ) # (N, H, W, K, 4) + + # The hw values of background don't matter because their alpha is set + # to 0 in the next step (which means that no matter what their splatting kernel + # value is, they will not splat as the kernel is multiplied by alpha). However, + # their z-values need to be at max depth. Otherwise, we could incorrectly compute + # occlusion layer linkage. + pixel_coords_screen[background_mask] = 1.0 + + # Any background color value value with alpha=0 will do, as anything with + # alpha=0 will have a zero-weight splatting power. Note that neighbors can still + # splat on zero-alpha pixels: that's the way we get non-zero gradients at the + # boundary with the background. + colors[background_mask] = 0.0 + + return pixel_coords_screen, colors + + +def _get_splat_kernel_normalization( + offsets: torch.Tensor, + sigma: float = 0.5, +): + if sigma <= 0.0: + raise ValueError("Only positive standard deviations make sense.") + + epsilon = 0.05 + normalization_constant = torch.exp( + -(offsets**2).sum(dim=1) / (2 * sigma**2) + ).sum() + + # We add an epsilon to the normalization constant to ensure the gradient will travel + # through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia- + # ble Surface Rendering via Non-Differentiable Sampling", Cole et al. + return (1 + epsilon) / normalization_constant + + +def _compute_occlusion_layers( + q_depth: torch.Tensor, +) -> torch.Tensor: + """ + For each splatting pixel, decide whether it splats from a background, surface, or + foreground depth relative to the splatted pixel. See unit tests in + test_splatter_blend for some enlightening examples. + + Args: + q_depth: (N, H, W, K) tensor of z-values of the splatted pixels. + + Returns: + occlusion_layers: (N, H, W, 9) long tensor. Each of the 9 values corresponds to + one of the nine splatting directions ([-1, -1], [-1, 0], ..., [1, + 1]). The value at nhwd (where d is the splatting direction) is 0 if + the splat in direction d is on the same surface level as the pixel at + hw. The value is negative if the splat is in the background (occluded + by another splat above it that is at the same surface level as the + pixel splatted on), and the value is positive if the splat is in the + foreground. + """ + N, H, W, K = q_depth.shape + + # q are the "center pixels" and p the pixels splatting onto them. Use `unfold` to + # create `p_depth`, a tensor with 9 layers, each of which corresponds to the + # depth of a neighbor of q in one of the 9 directions. For example, p_depth[nk0hw] + # is the depth of the pixel splatting onto pixel nhwk from the [-1, -1] direction, + # and p_depth[nk4hw] the depth of q (self-splatting onto itself). + # More concretely, imagine the pixel depths in a 2x2 image's k-th layer are + # .1 .2 + # .3 .4 + # Then (remembering that we pad with zeros when a pixel has fewer than 9 neighbors): + # + # p_depth[n, k, :, 0, 0] = [ 0 0 0 0 .1 .2 0 .3 .4] - neighbors of .1 + # p_depth[n, k, :, 0, 1] = [ 0 0 0 .1 .2 0 .3 .4 0] - neighbors of .2 + # p_depth[n, k, :, 1, 0] = [ 0 .1 .2 0 .3 .4 0 0 0] - neighbors of .3 + # p_depth[n, k, :, 0, 1] = [.1 .2 0 .3 .4 0 0 0 0] - neighbors of .4 + q_depth = q_depth.permute(0, 3, 1, 2) # (N, K, H, W) + p_depth = F.unfold(q_depth, kernel_size=3, padding=1) # (N, 3^2 * K, H * W) + q_depth = q_depth.view(N, K, 1, H, W) + p_depth = p_depth.view(N, K, 9, H, W) + + # Take the center pixel q's top rasterization layer. This is the "surface layer" + # that we're splatting on. For each of the nine splatting directions p, find which + # of the K splatting rasterization layers is closest in depth to the surface + # splatted layer. + qtop_to_p_zdist = torch.abs(p_depth - q_depth[:, 0:1]) # (N, K, 9, H, W) + qtop_to_p_closest_zdist, qtop_to_p_closest_id = qtop_to_p_zdist.min(dim=1) + + # For each of the nine splatting directions p, take the top of the K rasterization + # layers. Check which of the K q-layers (that the given direction is splatting on) + # is closest in depth to the top splatting layer. + ptop_to_q_zdist = torch.abs(p_depth[:, 0:1] - q_depth) # (N, K, 9, H, W) + ptop_to_q_closest_zdist, ptop_to_q_closest_id = ptop_to_q_zdist.min(dim=1) + + # Decide whether each p is on the same level, below, or above the q it is splatting + # on. See Fig. 4 in [0] for an illustration. Briefly: say we're interested in pixel + # p_{h, w} = [10, 32] splatting onto its neighbor q_{h, w} = [11, 33]. The splat is + # coming from direction [-1, -1], which has index 0 in our enumeration of splatting + # directions. Hence, we are interested in + # + # P = p_depth[n, :, d=0, 11, 33] - a vector of K depth values, and + # Q = q_depth.squeeze()[n, :, 11, 33] - a vector of K depth values. + # + # If Q[0] is closest, say, to P[2], then we assume the 0th surface layer of Q is + # the same surface as P[2] that's splatting onto it, and P[:2] are foreground splats + # and P[3:] are background splats. + # + # If instead say Q[2] is closest to P[0], then all the splats are background splats, + # because the top splatting layer is the same surface as a non-top splatted layer. + # + # Finally, if Q[0] is closest to P[0], then the top-level P is splatting onto top- + # level Q, and P[1:] are all background splats. + occlusion_offsets = torch.where( # noqa + ptop_to_q_closest_zdist < qtop_to_p_closest_zdist, + -ptop_to_q_closest_id, + qtop_to_p_closest_id, + ) # (N, 9, H, W) + + occlusion_layers = occlusion_offsets.permute((0, 2, 3, 1)) # (N, H, W, 9) + return occlusion_layers + + +def _compute_splatting_colors_and_weights( + pixel_coords_screen: torch.Tensor, + colors: torch.Tensor, + sigma: float, + offsets: torch.Tensor, +) -> torch.Tensor: + """ + For each center pixel q, compute the splatting weights of its surrounding nine spla- + tting pixels p, as well as their splatting colors (which are just their colors re- + weighted by the splatting weights). + + Args: + pixel_coords_screen: (N, H, W, K, 2) tensor of pixel screen coords. + colors: (N, H, W, K, 4) RGBA tensor of pixel colors. + sigma: splatting kernel variance. + offsets: (9, 2) tensor computed by _precompute, indicating the nine + splatting directions ([-1, -1], ..., [1, 1]). + + Returns: + splat_colors_and_weights: (N, H, W, K, 9, 5) tensor. + splat_colors_and_weights[..., :4] corresponds to the splatting colors, and + splat_colors_and_weights[..., 4:5] to the splatting weights. The "9" di- + mension corresponds to the nine splatting directions. + """ + N, H, W, K, C = colors.shape + splat_kernel_normalization = _get_splat_kernel_normalization(offsets, sigma) + + # Distance from each barycentric-interpolated triangle vertices' triplet from its + # "ideal" pixel-center location. pixel_coords_screen are in screen coordinates, and + # should be at the "ideal" locations on the forward pass -- e.g. + # pixel_coords_screen[n, 24, 31, k] = [24.5, 31.5]. For this reason, q_to_px_center + # should equal torch.zeros during the forward pass. On the backwards pass, these + # coordinates will be adjusted and non-zero, allowing the gradients to flow back + # to the mesh vertex coordinates. + q_to_px_center = ( + torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5 + ).view((N, H, W, K, 1, 2)) + + dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9) + splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2)) + alpha = colors[..., 3:4] + splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze( + 5 + ) # (N, H, W, K, 9, 1) + + # splat_colors[n, h, w, direction, :] contains the splatting color (weighted by the + # splatting weight) that pixel h, w will splat in one of the nine possible + # directions (e.g. nhw0 corresponds to splatting in [-1, 1] direciton, nhw4 is + # self-splatting). + splat_colors = splat_weights * colors.unsqueeze(4) # (N, H, W, K, 9, 4) + + return torch.cat([splat_colors, splat_weights], dim=5) + + +def _offset_splats( + splat_colors_and_weights: torch.Tensor, + crop_ids_h: torch.Tensor, + crop_ids_w: torch.Tensor, +) -> torch.Tensor: + """ + Pad splatting colors and weights so that tensor locations/coordinates are aligned + with the splatting directions. For example, say we have an example input Red channel + splat_colors_and_weights[n, :, :, k, direction=0, channel=0] equal to + .1 .2 .3 + .4 .5 .6 + .7 .8 .9 + the (h, w) entry indicates that pixel n, h, w, k splats the given color in direction + equal to 0, which corresponds to offsets[0] = (-1, -1). Note that this is the x-y + direction, not h-w. This function pads and crops this array to + 0 0 0 + .2 .3 0 + .5 .6 0 + which indicates, for example, that: + * There is no pixel splatting in direction (-1, -1) whose splat lands on pixel + h=w=0. + * There is a pixel splatting in direction (-1, -1) whose splat lands on the pi- + xel h=1, w=0, and that pixel's splatting color is .2. + * There is a pixel splatting in direction (-1, -1) whose splat lands on the pi- + xel h=2, w=1, and that pixel's splatting color is .6. + + Args: + *splat_colors_and_weights*: (N, H, W, K, 9, 5) tensor of colors and weights, + where dim=-2 corresponds to the splatting directions/offsets. + *crop_ids_h*: (N, H, W+2, K, 9, 5) precomputed tensor used for padding within + torch.gather. See _precompute for more info. + *crop_ids_w*: (N, H, W, K, 9, 5) precomputed tensor used for padding within + torch.gather. See _precompute for more info. + + + Returns: + *splat_colors_and_weights*: (N, H, W, K, 9, 5) tensor. + """ + N, H, W, K, _, _ = splat_colors_and_weights.shape + # Transform splat_colors such that each of the 9 layers (corresponding to + # the 9 splat offsets) is padded with 1 and shifted in the appropriate + # direction. E.g. splat_colors[n, :, :, 0] corresponds to the (-1, -1) + # offset, so will be padded with one rows of 1 on the right and have a + # single row clipped at the bottom, and splat_colors[n, :, :, 4] corrsponds + # to offset (0, 0) and will remain unchanged. + splat_colors_and_weights = F.pad( + splat_colors_and_weights, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0] + ) # N, H+2, W+2, 9, 5 + + # (N, H+2, W+2, K, 9, 5) -> (N, H, W+2, K, 9, 5) + splat_colors_and_weights = torch.gather( + splat_colors_and_weights, dim=1, index=crop_ids_h + ) + + # (N, H, W+2, K, 9, 5) -> (N, H, W, K, 9, 5) + splat_colors_and_weights = torch.gather( + splat_colors_and_weights, dim=2, index=crop_ids_w + ) + + return splat_colors_and_weights + + +def _compute_splatted_colors_and_weights( + occlusion_layers: torch.Tensor, # (N, H, W, 9) + splat_colors_and_weights: torch.Tensor, # (N, H, W, K, 9, 5) +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Accumulate splatted colors in background, surface and foreground occlusion buffers. + + Args: + occlusion_layers: (N, H, W, 9) tensor. See _compute_occlusion_layers. + splat_colors_and_weights: (N, H, W, K, 9, 5) tensor. See _offset_splats. + + Returns: + splatted_colors: (N, H, W, 4, 3) tensor. Last dimension corresponds to back- + ground, surface, and foreground splat colors. + splatted_weights: (N, H, W, 1, 3) tensor. Last dimension corresponds to back- + ground, surface, and foreground splat weights and is used for normalization. + + """ + N, H, W, K, _, _ = splat_colors_and_weights.shape + + # Create an occlusion mask, with the last dimension of length 3, corresponding to + # background/surface/foreground splatting. E.g. occlusion_layer_mask[n,h,w,k,d,0] is + # 1 if the pixel at hw is splatted from direction d such that the splatting pixel p + # is below the splatted pixel q (in the background); otherwise, the value is 0. + # occlusion_layer_mask[n,h,w,k,d,1] is 1 if the splatting pixel is at the same + # surface level as the splatted pixel q, and occlusion_layer_mask[n,h,w,k,d,2] is + # 1 only if the splatting pixel is in the foreground. + layer_ids = torch.arange(K, device=splat_colors_and_weights.device).view( + 1, 1, 1, K, 1 + ) + occlusion_layers = occlusion_layers.view(N, H, W, 1, 9) + occlusion_layer_mask = torch.stack( + [ + occlusion_layers > layer_ids, # (N, H, W, K, 9) + occlusion_layers == layer_ids, # (N, H, W, K, 9) + occlusion_layers < layer_ids, # (N, H, W, K, 9) + ], + dim=5, + ).float() # (N, H, W, K, 9, 3) + + # (N * H * W, 5, 9 * K) x (N * H * W, 9 * K, 3) -> (N * H * W, 5, 3) + splatted_colors_and_weights = torch.bmm( + splat_colors_and_weights.permute(0, 1, 2, 5, 3, 4).reshape( + (N * H * W, 5, K * 9) + ), + occlusion_layer_mask.reshape((N * H * W, K * 9, 3)), + ).reshape((N, H, W, 5, 3)) + + return ( + splatted_colors_and_weights[..., :4, :], + splatted_colors_and_weights[..., 4:5, :], + ) + + +def _normalize_and_compose_all_layers( + background_color: torch.Tensor, + splatted_colors_per_occlusion_layer: torch.Tensor, + splatted_weights_per_occlusion_layer: torch.Tensor, +) -> torch.Tensor: + """ + Normalize each bg/surface/fg buffer by its weight, and compose. + + Args: + background_color: (3) RGB tensor. + splatter_colors_per_occlusion_layer: (N, H, W, 4, 3) RGBA tensor, last dimension + corresponds to foreground, surface, and background splatting. + splatted_weights_per_occlusion_layer: (N, H, W, 1, 3) weight tensor. + + Returns: + output_colors: (N, H, W, 4) RGBA tensor. + """ + device = splatted_colors_per_occlusion_layer.device + + # Normalize each of bg/surface/fg splat layers separately. + normalization_scales = 1.0 / ( + torch.maximum( + splatted_weights_per_occlusion_layer, + torch.tensor([1.0], device=device), + ) + ) # (N, H, W, 1, 3) + + normalized_splatted_colors = ( + splatted_colors_per_occlusion_layer * normalization_scales + ) # (N, H, W, 4, 3) + + # Use alpha-compositing to compose the splat layers. + output_colors = torch.cat( + [background_color, torch.tensor([0.0], device=device)] + ) # (4), will broadcast to (N, H, W, 4) below. + + for occlusion_layer_id in (-1, -2, -3): + # Over-compose the bg, surface, and fg occlusion layers. Note that we already + # multiplied each pixel's RGBA by its own alpha as part of self-splatting in + # _compute_splatting_colors_and_weights, so we don't re-multiply by alpha here. + alpha = normalized_splatted_colors[..., 3:4, occlusion_layer_id] # (N, H, W, 1) + output_colors = ( + normalized_splatted_colors[..., occlusion_layer_id] + + (1.0 - alpha) * output_colors + ) + return output_colors + + +class SplatterBlender(torch.nn.Module): + def __init__( + self, + input_shape: Tuple[int, int, int, int], + device, + ): + """ + A splatting blender. See `forward` docs for details of the splatting mechanism. + + Args: + input_shape: Tuple (N, H, W, K) indicating the batch size, image height, + image width, and number of rasterized layers. Used to precompute + constant tensors that do not change as long as this tuple is unchanged. + """ + super().__init__() + self.crop_ids_h, self.crop_ids_w, self.offsets = _precompute( + input_shape, device + ) + + def forward( + self, + colors: torch.Tensor, + pixel_coords_cameras: torch.Tensor, + cameras: FoVPerspectiveCameras, + background_mask: torch.Tensor, + blend_params: BlendParams, + ) -> torch.Tensor: + """ + RGB blending using splatting, as proposed in [0]. + + Args: + colors: (N, H, W, K, 3) tensor of RGB colors at each h, w pixel location for + K intersection layers. + pixel_coords_cameras: (N, H, W, K, 3) tensor of pixel coordinates in the + camera frame of reference. It is *crucial* that these are computed by + interpolating triangle vertex positions using barycentric coordinates -- + this allows gradients to travel through pixel_coords_camera back to the + vertex positions. + cameras: Cameras object used to project pixel_coords_cameras screen coords. + background_mask: (N, H, W, K, 3) boolean tensor, True for bg pixels. A pixel + is considered "background" if no mesh triangle projects to it. This is + typically computed by the rasterizer. + blend_params: BlendParams, from which we use sigma (splatting kernel + variance) and background_color. + + Returns: + output_colors: (N, H, W, 4) tensor of RGBA values. The alpha layer is set to + fully transparent in the background. + + [0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable + Sampling". + """ + + # Our implementation has 6 stages. In the description below, we will call each + # pixel q and the 9 surrounding splatting pixels (including itself) p. + # 1. Use barycentrics to compute the position of each pixel in screen + # coordinates. These should exactly correspond to pixel centers during the + # forward pass, but can be shifted on backwards. This step allows gradients to + # travel to vertex coordinates, even if the rasterizer is non-differentiable. + # 2a. For each center pixel q, take each splatting p and decide whether it + # is on the same surface level as q, or in the background or foreground. + # 2b. For each center pixel q, compute the splatting weight of surrounding + # pixels p, and their splatting colors (which are just the original colors + # weighted by the splatting weights). + # 3. As a vectorization technicality, offset the tensors corresponding to + # the splatting p values in the nine directions, by padding each of nine + # splatting layers on the bottom/top, left/right. + # 4. Do the actual splatting, by accumulating the splatting colors of the + # surrounding p's for each pixel q. The weights get accumulated separately for + # p's that got assigned to the background/surface/foreground in Step 2a. + # 5. Normalize each the splatted bg/surface/fg colors for each q, and + # compose the resulting color maps. + # + # Note that it is crucial that in Step 1 we compute the pixel coordinates by in- + # terpolating triangle vertices using barycentric coords from the rasterizer. In + # our case, these pixel_coords_camera are computed by the shader and passed to + # this function to avoid re-computation. + + pixel_coords_screen, colors = _prepare_pixels_and_colors( + pixel_coords_cameras, colors, cameras, background_mask + ) # (N, H, W, K, 3) and (N, H, W, K, 4) + + occlusion_layers = _compute_occlusion_layers( + pixel_coords_screen[..., 2:3].squeeze(dim=-1) + ) # (N, H, W, 9) + + splat_colors_and_weights = _compute_splatting_colors_and_weights( + pixel_coords_screen, + colors, + blend_params.sigma, + self.offsets, + ) # (N, H, W, K, 9, 5) + + splat_colors_and_weights = _offset_splats( + splat_colors_and_weights, + self.crop_ids_h, + self.crop_ids_w, + ) # (N, H, W, K, 9, 5) + + ( + splatted_colors_per_occlusion_layer, + splatted_weights_per_occlusion_layer, + ) = _compute_splatted_colors_and_weights( + occlusion_layers, splat_colors_and_weights + ) # (N, H, W, 4, 3) and (N, H, W, 1, 3) + + output_colors = _normalize_and_compose_all_layers( + _get_background_color(blend_params, colors.device), + splatted_colors_per_occlusion_layer, + splatted_weights_per_occlusion_layer, + ) # (N, H, W, 4) + + return output_colors diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 5b033831..fdca0fdb 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import ( HardPhongShader, SoftPhongShader, SoftSilhouetteShader, + SplatterPhongShader, TexturedSoftPhongShader, ) from pytorch3d.structures.meshes import ( @@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): shader_tests = [ ShaderTest(HardPhongShader, "phong", "hard_phong"), ShaderTest(SoftPhongShader, "phong", "soft_phong"), + ShaderTest(SplatterPhongShader, "phong", "splatter_phong"), ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"), ShaderTest(HardFlatShader, "flat", "hard_flat"), ] diff --git a/tests/test_splatter_blend.py b/tests/test_splatter_blend.py new file mode 100644 index 00000000..caaa5d5f --- /dev/null +++ b/tests/test_splatter_blend.py @@ -0,0 +1,627 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.renderer.cameras import FoVPerspectiveCameras +from pytorch3d.renderer.splatter_blend import ( + _compute_occlusion_layers, + _compute_splatted_colors_and_weights, + _compute_splatting_colors_and_weights, + _get_splat_kernel_normalization, + _normalize_and_compose_all_layers, + _offset_splats, + _precompute, + _prepare_pixels_and_colors, +) + +offsets = torch.tensor( + [ + [-1, -1], + [-1, 0], + [-1, 1], + [0, -1], + [0, 0], + [0, 1], + [1, -1], + [1, 0], + [1, 1], + ], + device=torch.device("cpu"), +) + + +def compute_splatting_colors_and_weights_naive(pixel_coords_screen, colors, sigma): + normalizer = float(_get_splat_kernel_normalization(offsets)) + N, H, W, K, _ = colors.shape + splat_weights_and_colors = torch.zeros((N, H, W, K, 9, 5)) + for n in range(N): + for h in range(H): + for w in range(W): + for k in range(K): + q_xy = pixel_coords_screen[n, h, w, k] + q_to_px_center = torch.floor(q_xy) - q_xy + 0.5 + color = colors[n, h, w, k] + alpha = colors[n, h, w, k, 3:4] + for d in range(9): + dist_p_q = torch.sum((q_to_px_center + offsets[d]) ** 2) + splat_weight = ( + alpha * torch.exp(-dist_p_q / (2 * sigma**2)) * normalizer + ) + splat_color = splat_weight * color + splat_weights_and_colors[n, h, w, k, d, :4] = splat_color + splat_weights_and_colors[n, h, w, k, d, 4:5] = splat_weight + return splat_weights_and_colors + + +class TestPrecompute(TestCaseMixin, unittest.TestCase): + def setUp(self): + self.results_cpu = _precompute((2, 3, 4, 5), torch.device("cpu")) + self.results1_cpu = _precompute((1, 1, 1, 1), torch.device("cpu")) + + def test_offsets(self): + self.assertClose(self.results_cpu[2].shape, offsets.shape, atol=0) + self.assertClose(self.results_cpu[2], offsets, atol=0) + + # Offsets should be independent of input_size. + self.assertClose(self.results_cpu[2], self.results1_cpu[2], atol=0) + + def test_crops_h(self): + target_crops_h1 = torch.tensor( + [ + # chennels being offset: + # R G B A W(eight) + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + ] + * 3, # 3 because we're aiming at (N, H, W+2, K, 9, 5) with W=1. + device=torch.device("cpu"), + ).reshape(1, 1, 3, 1, 9, 5) + self.assertClose(self.results1_cpu[0], target_crops_h1, atol=0) + + target_crops_h_base = target_crops_h1[0, 0, 0] + target_crops_h = torch.cat( + [target_crops_h_base, target_crops_h_base + 1, target_crops_h_base + 2], + dim=0, + ) + + # Check that we have the right shape, and (after broadcasting) it has the right + # values. These should be repeated (tiled) for each n and k. + self.assertClose( + self.results_cpu[0].shape, torch.tensor([2, 3, 6, 5, 9, 5]), atol=0 + ) + for n in range(2): + for w in range(6): + for k in range(5): + self.assertClose( + self.results_cpu[0][n, :, w, k], + target_crops_h, + ) + + def test_crops_w(self): + target_crops_w1 = torch.tensor( + [ + # chennels being offset: + # R G B A W(eight) + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + [2, 2, 2, 2, 2], + [2, 2, 2, 2, 2], + ], + device=torch.device("cpu"), + ).reshape(1, 1, 1, 1, 9, 5) + self.assertClose(self.results1_cpu[1], target_crops_w1) + + target_crops_w_base = target_crops_w1[0, 0, 0] + target_crops_w = torch.cat( + [ + target_crops_w_base, + target_crops_w_base + 1, + target_crops_w_base + 2, + target_crops_w_base + 3, + ], + dim=0, + ) # Each w value needs an increment. + + # Check that we have the right shape, and (after broadcasting) it has the right + # values. These should be repeated (tiled) for each n and k. + self.assertClose(self.results_cpu[1].shape, torch.tensor([2, 3, 4, 5, 9, 5])) + for n in range(2): + for h in range(3): + for k in range(5): + self.assertClose( + self.results_cpu[1][n, h, :, k], + target_crops_w, + atol=0, + ) + + +class TestPreparPixelsAndColors(TestCaseMixin, unittest.TestCase): + def setUp(self): + self.device = torch.device("cpu") + N, H, W, K = 2, 3, 4, 5 + self.pixel_coords_cameras = torch.randn( + (N, H, W, K, 3), device=self.device, requires_grad=True + ) + self.colors_before = torch.rand((N, H, W, K, 3), device=self.device) + self.cameras = FoVPerspectiveCameras(device=self.device) + self.background_mask = torch.rand((N, H, W, K), device=self.device) < 0.5 + self.pixel_coords_screen, self.colors_after = _prepare_pixels_and_colors( + self.pixel_coords_cameras, + self.colors_before, + self.cameras, + self.background_mask, + ) + + def test_background_z(self): + self.assertTrue( + torch.all(self.pixel_coords_screen[..., 2][self.background_mask] == 1.0) + ) + + def test_background_alpha(self): + self.assertTrue( + torch.all(self.colors_after[..., 3][self.background_mask] == 0.0) + ) + + +class TestGetSplatKernelNormalization(TestCaseMixin, unittest.TestCase): + def test_splat_kernel_normalization(self): + self.assertAlmostEqual( + float(_get_splat_kernel_normalization(offsets)), 0.6503, places=3 + ) + self.assertAlmostEqual( + float(_get_splat_kernel_normalization(offsets, 0.01)), 1.05, places=3 + ) + with self.assertRaisesRegex(ValueError, "Only positive standard deviations"): + _get_splat_kernel_normalization(offsets, 0) + + +class TestComputeOcclusionLayers(TestCaseMixin, unittest.TestCase): + def test_single_layer(self): + # If there's only one layer, all splats must be on the surface level. + N, H, W, K = 2, 3, 4, 1 + q_depth = torch.rand(N, H, W, K) + occlusion_layers = _compute_occlusion_layers(q_depth) + self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0) + + def test_all_equal(self): + # If all q-vals are equal, then all splats must be on the surface level. + N, H, W, K = 2, 3, 4, 5 + q_depth = torch.ones((N, H, W, K)) * 0.1234 + occlusion_layers = _compute_occlusion_layers(q_depth) + self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0) + + def test_mid_to_top_level_splatting(self): + # Check that occlusion buffers get accumulated as expected when the splatting + # and splatted pixels are co-surface on different intersection layers. + # This test will make best sense with accompanying Fig. 4 from "Differentiable + # Surface Rendering via Non-differentiable Sampling" by Cole et al. + for direction, offset in enumerate(offsets): + if direction == 4: + continue # Skip self-splatting which is always co-surface. + + depths = torch.zeros(1, 3, 3, 3) + + # This is our q, the pixel splatted onto, in the center of the image. + depths[0, 1, 1] = torch.tensor([0.71, 0.8, 1.0]) + + # This is our p, the splatting pixel. + depths[0, offset[0] + 1, offset[1] + 1] = torch.tensor([0.5, 0.7, 0.9]) + + occlusion_layers = _compute_occlusion_layers(depths) + + # Check that we computed that it is the middle layer of p that is co- + # surface with q. (1, 1) is the id of q in the depth array, and offset_id + # is the id of p's direction w.r.t. q. + psurfaceid_onto_q = occlusion_layers[0, 1, 1, direction] + self.assertEqual(int(psurfaceid_onto_q), 1) + + # Conversely, if we swap p and q, we have a top-level splatting onto + # mid-level. offset + 1 is the id of p, and 8-offset_id is the id of + # q's direction w.r.t. p (e.g. if p is [-1, -1] w.r.t. q, then q is + # [1, 1] w.r.t. p; we use the ids of these two directions in the offsets + # array). + qsurfaceid_onto_p = occlusion_layers[ + 0, offset[0] + 1, offset[1] + 1, 8 - direction + ] + self.assertEqual(int(qsurfaceid_onto_p), -1) + + +class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase): + def setUp(self): + self.N, self.H, self.W, self.K = 2, 3, 4, 5 + self.pixel_coords_screen = ( + torch.tile( + torch.stack( + torch.meshgrid( + torch.arange(self.H), torch.arange(self.W), indexing="ij" + ), + dim=-1, + ).reshape(1, self.H, self.W, 1, 2), + (self.N, 1, 1, self.K, 1), + ).float() + + 0.5 + ) + self.colors = torch.ones((self.N, self.H, self.W, self.K, 4)) + + def test_all_equal(self): + # If all colors are equal and on a regular grid, all weights and reweighted + # colors should be equal given a specific splatting direction. + splatting_colors_and_weights = _compute_splatting_colors_and_weights( + self.pixel_coords_screen, self.colors * 0.2345, sigma=0.5, offsets=offsets + ) + + # Splatting directly to the top/bottom/left/right should have the same strenght. + non_diag_splats = splatting_colors_and_weights[ + :, :, :, :, torch.tensor([1, 3, 5, 7]) + ] + + # Same for diagonal splats. + diag_splats = splatting_colors_and_weights[ + :, :, :, :, torch.tensor([0, 2, 6, 8]) + ] + + # And for self-splats. + self_splats = splatting_colors_and_weights[:, :, :, :, torch.tensor([4])] + + for splats in non_diag_splats, diag_splats, self_splats: + # Colors should be equal. + self.assertTrue(torch.all(splats[..., :4] == splats[0, 0, 0, 0, 0, 0])) + + # Weights should be equal. + self.assertTrue(torch.all(splats[..., 4] == splats[0, 0, 0, 0, 0, 4])) + + # Non-diagonal weights should be greater than diagonal weights. + self.assertGreater( + non_diag_splats[0, 0, 0, 0, 0, 0], diag_splats[0, 0, 0, 0, 0, 0] + ) + + # Self-splats should be strongest of all. + self.assertGreater( + self_splats[0, 0, 0, 0, 0, 0], non_diag_splats[0, 0, 0, 0, 0, 0] + ) + + # Splatting colors should be reweighted proportionally to their splat weights. + diag_self_color_ratio = ( + diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0] + ) + diag_self_weight_ratio = ( + diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4] + ) + self.assertEqual(diag_self_color_ratio, diag_self_weight_ratio) + + non_diag_self_color_ratio = ( + non_diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0] + ) + non_diag_self_weight_ratio = ( + non_diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4] + ) + self.assertEqual(non_diag_self_color_ratio, non_diag_self_weight_ratio) + + def test_zero_alpha_zero_weight(self): + # Pixels with zero alpha do no splatting, but should still be splatted on. + colors = self.colors.clone() + colors[0, 1, 1, 0, 3] = 0.0 + splatting_colors_and_weights = _compute_splatting_colors_and_weights( + self.pixel_coords_screen, colors, sigma=0.5, offsets=offsets + ) + + # The transparent pixel should do no splatting. + self.assertTrue(torch.all(splatting_colors_and_weights[0, 1, 1, 0] == 0.0)) + + # Splatting *onto* the transparent pixel should be unaffected. + reference_weights_colors = splatting_colors_and_weights[0, 1, 1, 1] + for direction, offset in enumerate(offsets): + if direction == 4: + continue # Ignore self-splats + # We invert the direction to get the right (h, w, d) coordinate of each + # pixel splatting *onto* the pixel with zero alpha. + self.assertClose( + splatting_colors_and_weights[ + 0, 1 + offset[0], 1 + offset[1], 0, 8 - direction + ], + reference_weights_colors[8 - direction], + atol=0.001, + ) + + def test_random_inputs(self): + pixel_coords_screen = ( + self.pixel_coords_screen + + torch.randn((self.N, self.H, self.W, self.K, 2)) * 0.1 + ) + colors = torch.rand((self.N, self.H, self.W, self.K, 4)) + splatting_colors_and_weights = _compute_splatting_colors_and_weights( + pixel_coords_screen, colors, sigma=0.5, offsets=offsets + ) + naive_colors_and_weights = compute_splatting_colors_and_weights_naive( + pixel_coords_screen, colors, sigma=0.5 + ) + + self.assertClose( + splatting_colors_and_weights, naive_colors_and_weights, atol=0.01 + ) + + +class TestOffsetSplats(TestCaseMixin, unittest.TestCase): + def test_offset(self): + device = torch.device("cuda:0") + N, H, W, K = 2, 3, 4, 5 + colors_and_weights = torch.rand((N, H, W, K, 9, 5), device=device) + crop_ids_h, crop_ids_w, _ = _precompute((N, H, W, K), device=device) + offset_colors_and_weights = _offset_splats( + colors_and_weights, crop_ids_h, crop_ids_w + ) + + # Check each splatting direction individually, for clarity. + # offset_x, offset_y = (-1, -1) + direction = 0 + self.assertClose( + offset_colors_and_weights[:, 1:, 1:, :, direction], + colors_and_weights[:, :-1, :-1, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0) + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0) + ) + + # offset_x, offset_y = (-1, 0) + direction = 1 + self.assertClose( + offset_colors_and_weights[:, :, 1:, :, direction], + colors_and_weights[:, :, :-1, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0) + ) + + # offset_x, offset_y = (-1, 1) + direction = 2 + self.assertClose( + offset_colors_and_weights[:, :-1, 1:, :, direction], + colors_and_weights[:, 1:, :-1, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0) + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0) + ) + + # offset_x, offset_y = (0, -1) + direction = 3 + self.assertClose( + offset_colors_and_weights[:, 1:, :, :, direction], + colors_and_weights[:, :-1, :, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0) + ) + + # self-splat + direction = 4 + self.assertClose( + offset_colors_and_weights[..., direction, :], + colors_and_weights[..., direction, :], + atol=0.001, + ) + + # offset_x, offset_y = (0, 1) + direction = 5 + self.assertClose( + offset_colors_and_weights[:, :-1, :, :, direction], + colors_and_weights[:, 1:, :, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0) + ) + + # offset_x, offset_y = (1, -1) + direction = 6 + self.assertClose( + offset_colors_and_weights[:, 1:, :-1, :, direction], + colors_and_weights[:, :-1, 1:, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0) + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0) + ) + + # offset_x, offset_y = (1, 0) + direction = 7 + self.assertClose( + offset_colors_and_weights[:, :, :-1, :, direction], + colors_and_weights[:, :, 1:, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0) + ) + + # offset_x, offset_y = (1, 1) + direction = 8 + self.assertClose( + offset_colors_and_weights[:, :-1, :-1, :, direction], + colors_and_weights[:, 1:, 1:, :, direction], + atol=0.001, + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0) + ) + self.assertTrue( + torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0) + ) + + +class TestComputeSplattedColorsAndWeights(TestCaseMixin, unittest.TestCase): + def test_accumulation_background(self): + # Set occlusion_layers to all -1, so all splats are background splats. + splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5)) + occlusion_layers = torch.zeros((1, 1, 1, 9)) - 1 + splatted_colors, splatted_weights = _compute_splatted_colors_and_weights( + occlusion_layers, splat_colors_and_weights + ) + + # Foreground splats (there are none). + self.assertClose( + splatted_colors[0, 0, 0, :, 0], + torch.zeros((4)), + atol=0.001, + ) + + # Surface splats (there are none). + self.assertClose( + splatted_colors[0, 0, 0, :, 1], + torch.zeros((4)), + atol=0.001, + ) + + # Background splats. + self.assertClose( + splatted_colors[0, 0, 0, :, 2], + splat_colors_and_weights[0, 0, 0, :, :, :4].sum(dim=0).sum(dim=0), + atol=0.001, + ) + + def test_accumulation_middle(self): + # Set occlusion_layers to all 0, so top splats are co-surface with splatted + # pixels. Thus, the top splatting layer should be accumulated to surface, and + # all other layers to background. + splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5)) + occlusion_layers = torch.zeros((1, 1, 1, 9)) + splatted_colors, splatted_weights = _compute_splatted_colors_and_weights( + occlusion_layers, splat_colors_and_weights + ) + + # Foreground splats (there are none). + self.assertClose( + splatted_colors[0, 0, 0, :, 0], + torch.zeros((4)), + atol=0.001, + ) + + # Surface splats + self.assertClose( + splatted_colors[0, 0, 0, :, 1], + splat_colors_and_weights[0, 0, 0, 0, :, :4].sum(dim=0), + atol=0.001, + ) + + # Background splats + self.assertClose( + splatted_colors[0, 0, 0, :, 2], + splat_colors_and_weights[0, 0, 0, 1:, :, :4].sum(dim=0).sum(dim=0), + atol=0.001, + ) + + def test_accumulation_foreground(self): + # Set occlusion_layers to all 1. Then the top splatter is a foreground + # splatter, mid splatter is surface, and bottom splatter is background. + splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5)) + occlusion_layers = torch.zeros((1, 1, 1, 9)) + 1 + splatted_colors, splatted_weights = _compute_splatted_colors_and_weights( + occlusion_layers, splat_colors_and_weights + ) + + # Foreground splats + self.assertClose( + splatted_colors[0, 0, 0, :, 0], + splat_colors_and_weights[0, 0, 0, 0:1, :, :4].sum(dim=0).sum(dim=0), + atol=0.001, + ) + + # Surface splats + self.assertClose( + splatted_colors[0, 0, 0, :, 1], + splat_colors_and_weights[0, 0, 0, 1:2, :, :4].sum(dim=0).sum(dim=0), + atol=0.001, + ) + + # Background splats + self.assertClose( + splatted_colors[0, 0, 0, :, 2], + splat_colors_and_weights[0, 0, 0, 2:3, :, :4].sum(dim=0).sum(dim=0), + atol=0.001, + ) + + +class TestNormalizeAndComposeAllLayers(TestCaseMixin, unittest.TestCase): + def test_background_color(self): + # Background should always have alpha=0, and the chosen RGB. + N, H, W = 2, 3, 4 + # Make a mask with background in the zeroth row of the first image. + bg_mask = torch.zeros([N, H, W, 1, 1]) + bg_mask[0, :, 0] = 1 + + bg_color = torch.tensor([0.2, 0.3, 0.4]) + + color_layers = torch.rand((N, H, W, 4, 3)) * (1 - bg_mask) + color_weights = torch.rand((N, H, W, 1, 3)) * (1 - bg_mask) + + colors = _normalize_and_compose_all_layers( + bg_color, color_layers, color_weights + ) + + # Background RGB should be .2, .3, .4, and alpha should be 0. + self.assertClose( + torch.masked_select(colors, bg_mask.bool()[..., 0]), + torch.tensor([0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0.0]), + atol=0.001, + ) + + def test_compositing_opaque(self): + # When all colors are opaque, only the foreground layer should be visible. + N, H, W = 2, 3, 4 + color_layers = torch.rand((N, H, W, 4, 3)) + color_layers[..., 3, :] = 1.0 + color_weights = torch.ones((N, H, W, 1, 3)) + + out_colors = _normalize_and_compose_all_layers( + torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights + ) + self.assertClose(out_colors, color_layers[..., 0], atol=0.001) + + def test_compositing_transparencies(self): + # When foreground layer is transparent and surface and bg are semi-transparent, + # we should return a mix of the two latter. + N, H, W = 2, 3, 4 + color_layers = torch.rand((N, H, W, 4, 3)) + color_layers[..., 3, 0] = 0.1 # fg + color_layers[..., 3, 1] = 0.2 # surface + color_layers[..., 3, 2] = 0.3 # bg + color_weights = torch.ones((N, H, W, 1, 3)) + + out_colors = _normalize_and_compose_all_layers( + torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights + ) + self.assertClose( + out_colors, + color_layers[..., 0] + + 0.9 * (color_layers[..., 1] + 0.8 * color_layers[..., 2]), + )