mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
SplatterBlender
Summary: Splatting shader. See code comments for details. Same API as SoftPhongShader. Reviewed By: jcjohnson Differential Revision: D36354301 fbshipit-source-id: 71ee37f7ff6bb9ce028ba42a65741424a427a92d
This commit is contained in:
parent
1702c85bec
commit
c5a83f46ef
@ -57,6 +57,7 @@ from .mesh import (
|
||||
SoftGouraudShader,
|
||||
SoftPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
SplatterPhongShader,
|
||||
Textures,
|
||||
TexturesAtlas,
|
||||
TexturesUV,
|
||||
@ -71,6 +72,7 @@ from .points import (
|
||||
PulsarPointsRenderer,
|
||||
rasterize_points,
|
||||
)
|
||||
from .splatter_blend import SplatterBlender
|
||||
from .utils import (
|
||||
convert_to_tensors_and_broadcast,
|
||||
ndc_grid_sample,
|
||||
|
@ -4,14 +4,12 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import NamedTuple, Sequence, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.common.datatypes import Device
|
||||
|
||||
|
||||
# Example functions for blending the top K colors per pixel using the outputs
|
||||
# from rasterization.
|
||||
# NOTE: All blending function should return an RGBA image per batch element
|
||||
@ -22,10 +20,12 @@ class BlendParams(NamedTuple):
|
||||
Data class to store blending params with defaults
|
||||
|
||||
Members:
|
||||
sigma (float): Controls the width of the sigmoid function used to
|
||||
calculate the 2D distance based probability. Determines the
|
||||
sharpness of the edges of the shape.
|
||||
Higher => faces have less defined edges.
|
||||
sigma (float): For SoftmaxPhong, controls the width of the sigmoid
|
||||
function used to calculate the 2D distance based probability. Determines
|
||||
the sharpness of the edges of the shape. Higher => faces have less defined
|
||||
edges. For SplatterPhong, this is the standard deviation of the Gaussian
|
||||
kernel. Higher => splats have a stronger effect and the rendered image is
|
||||
more blurry.
|
||||
gamma (float): Controls the scaling of the exponential function used
|
||||
to set the opacity of the color.
|
||||
Higher => faces are more transparent.
|
||||
@ -36,6 +36,7 @@ class BlendParams(NamedTuple):
|
||||
sigma: float = 1e-4
|
||||
gamma: float = 1e-4
|
||||
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
|
||||
background_alpha: float = 0.0
|
||||
|
||||
|
||||
def _get_background_color(
|
||||
|
@ -22,6 +22,7 @@ from .shader import ( # DEPRECATED
|
||||
SoftGouraudShader,
|
||||
SoftPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
SplatterPhongShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from .shading import gouraud_shading, phong_shading
|
||||
|
@ -20,9 +20,15 @@ from ..blending import (
|
||||
)
|
||||
from ..lighting import PointLights
|
||||
from ..materials import Materials
|
||||
from ..splatter_blend import SplatterBlender
|
||||
from ..utils import TensorProperties
|
||||
from .rasterizer import Fragments
|
||||
from .shading import flat_shading, gouraud_shading, phong_shading
|
||||
from .shading import (
|
||||
_phong_shading_with_pixels,
|
||||
flat_shading,
|
||||
gouraud_shading,
|
||||
phong_shading,
|
||||
)
|
||||
|
||||
|
||||
# A Shader should take as input fragments from the output of rasterization
|
||||
@ -308,3 +314,64 @@ class SoftSilhouetteShader(nn.Module):
|
||||
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||
images = sigmoid_alpha_blend(colors, fragments, blend_params)
|
||||
return images
|
||||
|
||||
|
||||
class SplatterPhongShader(ShaderBase):
    """
    Per pixel lighting - the lighting model is applied using the interpolated
    coordinates and normals for each pixel. The blending function returns the
    color aggregated using splats from surrounding pixels (see [0]).

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = SplatterPhongShader(device=torch.device("cuda:0"))

    The forward pass always shades `fragments.detach()`, i.e. gradients never
    flow through the rasterizer outputs. This simulates the pipeline in [0],
    which uses a non-differentiable OpenGL rasterizer.

    `forward` accepts the optional keyword arguments `cameras`, `lights`,
    `materials` and `blend_params`; each one overrides the value stored on the
    shader for that call only.

    [0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
        Sampling".
    """

    def __init__(self, **kwargs):
        # Built lazily in `forward`, once the (N, H, W, K) shape is known.
        self.splatter_blender = None
        super().__init__(**kwargs)

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Shade the rasterized fragments with Phong lighting and blend them by
        splatting.

        Args:
            fragments: rasterizer output for `meshes`.
            meshes: the meshes that were rasterized.
            **kwargs: optional per-call overrides (`cameras`, `lights`,
                `materials`, `blend_params`).

        Returns:
            The images produced by the splatter blender.

        Raises:
            ValueError: if no cameras were given at init or in this call.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = (
                "Cameras must be specified either at initialization "
                "or in the forward pass of SplatterPhongShader."
            )
            raise ValueError(msg)

        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        # Honor a per-call override, like the other shaders in this module.
        blend_params = kwargs.get("blend_params", self.blend_params)

        colors, pixel_coords_cameras = _phong_shading_with_pixels(
            meshes=meshes,
            fragments=fragments.detach(),
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )

        # The blender caches constants for one (N, H, W, K) shape and device.
        # Build it on first use, and rebuild it only if the input no longer
        # matches, so that a changed batch/image size does not break blending.
        N, H, W, K, _ = colors.shape
        blender = self.splatter_blender
        if (
            blender is None
            or tuple(blender.crop_ids_w.shape[:4]) != (N, H, W, K)
            or blender.crop_ids_w.device != colors.device
        ):
            self.splatter_blender = SplatterBlender((N, H, W, K), colors.device)

        images = self.splatter_blender(
            colors,
            pixel_coords_cameras,
            cameras,
            fragments.pix_to_face < 0,  # background mask
            blend_params,
        )

        return images
|
||||
|
553
pytorch3d/renderer/splatter_blend.py
Normal file
553
pytorch3d/renderer/splatter_blend.py
Normal file
@ -0,0 +1,553 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.datatypes import Device
|
||||
from pytorch3d.renderer import BlendParams
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
|
||||
|
||||
from .blending import _get_background_color
|
||||
|
||||
|
||||
def _precompute(
    input_shape: Tuple[int, int, int, int], device: Device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the index and offset constants that depend only on the NHWK shape.

    Args:
        input_shape: Tuple (N, H, W, K): batch size, image height, image width
            and number of intersections output by the rasterizer.
        device: Device to store the tensors on.

    Returns:
        crop_ids_h: An (N, H, W+2, K, 9, 5) long tensor whose entry at
            [n, h, w, k, d, c] equals h + (d % 3). It is the index used with
            torch.gather along dim=1 to crop H+2 padded rows down to H rows,
            shifted per splatting direction d.
        crop_ids_w: An (N, H, W, K, 9, 5) long tensor whose entry at
            [n, h, w, k, d, c] equals w + (d // 3). It is used likewise along
            dim=2 to crop W+2 padded columns down to W.
        offsets: A (9, 2) long tensor containing [-1, -1], [-1, 0], [-1, 1],
            [0, -1], ..., [1, 1], i.e. the nine splatting directions.

    Both crop tensors are expanded (broadcast) views; torch.gather does not
    broadcast its index, hence the explicit tiling to the full shape.
    """
    N, H, W, K = input_shape

    def per_direction(shifts):
        # One shift per splatting direction, laid out along the "9" axis.
        return torch.tensor(shifts, device=device).view(1, 1, 1, 1, 9, 1)

    # Row index + per-direction row shift: (N, H+2, W+2, ...) -> (N, H, W+2, ...).
    row_ids = torch.arange(0, H, device=device).view(1, H, 1, 1, 1, 1)
    row_shift = per_direction([0, 1, 2, 0, 1, 2, 0, 1, 2])
    crop_ids_h = (row_ids + row_shift).expand(N, H, W + 2, K, 9, 5)

    # Column index + per-direction column shift: (N, H, W+2, ...) -> (N, H, W, ...).
    col_ids = torch.arange(0, W, device=device).view(1, 1, W, 1, 1, 1)
    col_shift = per_direction([0, 0, 0, 1, 1, 1, 2, 2, 2])
    crop_ids_w = (col_ids + col_shift).expand(N, H, W, K, 9, 5)

    directions = list(itertools.product((-1, 0, 1), repeat=2))
    offsets = torch.tensor(directions, dtype=torch.long, device=device)

    return crop_ids_h, crop_ids_w, offsets
|
||||
|
||||
|
||||
def _prepare_pixels_and_colors(
    pixel_coords_cameras: torch.Tensor,
    colors: torch.Tensor,
    cameras: FoVPerspectiveCameras,
    background_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project pixel coords into the un-inverted screen frame of reference, append
    an alpha channel to the colors, and neutralize background entries.

    Args:
        pixel_coords_cameras: (N, H, W, K, 3) float tensor.
        colors: (N, H, W, K, 3) float tensor.
        cameras: PyTorch3D cameras, for now we assume FoVPerspectiveCameras.
        background_mask: (N, H, W, K) boolean tensor, True at background entries.

    Returns:
        pixel_coords_screen: (N, H, W, K, 3) float tensor. Background entries
            have x=y=z=1.0.
        colors: (N, H, W, K, 4) tensor. Alpha is 1 for foreground entries; all
            four channels are 0 for background entries.
    """
    N, H, W, K, _ = colors.shape

    # Project without flipping the xy axes: a flip would invert the gradients
    # when the splatter works with a detached rasterizer. Values at background
    # intersections are invalid at this point and are overwritten below.
    flat_points = pixel_coords_cameras.view([N, -1, 3])
    projected = cameras.transform_points_screen(
        flat_points, image_size=(H, W), with_xyflip=False
    )
    pixel_coords_screen = projected.reshape(pixel_coords_cameras.shape)

    # RGB -> RGBA with a fully opaque alpha channel; (N, H, W, K, 4).
    opaque = torch.ones_like(colors[..., :1])
    rgba = torch.cat([colors, opaque], dim=-1)

    # Background xy values are irrelevant (their alpha becomes 0 below, and the
    # splatting kernel is multiplied by alpha), but their z must sit at max
    # depth, otherwise occlusion layer linkage could be computed incorrectly.
    pixel_coords_screen[background_mask] = 1.0

    # Zero alpha means zero splatting power, so any color would do here.
    # Neighbors can still splat onto these zero-alpha pixels, which is how
    # non-zero gradients arise at the boundary with the background.
    rgba[background_mask] = 0.0

    return pixel_coords_screen, rgba
|
||||
|
||||
|
||||
def _get_splat_kernel_normalization(
    offsets: torch.Tensor,
    sigma: float = 0.5,
):
    """
    Return the scalar that normalizes the 3x3 Gaussian splatting kernel.

    Args:
        offsets: (9, 2) tensor of splatting directions.
        sigma: standard deviation of the Gaussian kernel; must be positive.

    Returns:
        A 0-dim tensor equal to (1 + epsilon) divided by the sum of the
        unnormalized kernel values at the given offsets, with epsilon = 0.05.

    Raises:
        ValueError: if sigma is not strictly positive.
    """
    if sigma <= 0.0:
        raise ValueError("Only positive standard deviations make sense.")

    epsilon = 0.05
    squared_radius = (offsets**2).sum(dim=1)
    kernel_mass = torch.exp(-squared_radius / (2 * sigma**2)).sum()

    # The epsilon ensures the gradient travels through the normalization factor
    # of non-boundary pixels, see Sec. 3.3.1 in "Differentiable Surface
    # Rendering via Non-Differentiable Sampling", Cole et al.
    return (1 + epsilon) / kernel_mass
|
||||
|
||||
|
||||
def _compute_occlusion_layers(
    q_depth: torch.Tensor,
) -> torch.Tensor:
    """
    For each splatting pixel, decide whether it splats from a background, surface, or
    foreground depth relative to the splatted pixel. See unit tests in
    test_splatter_blend for some enlightening examples.

    Args:
        q_depth: (N, H, W, K) tensor of z-values of the splatted pixels.

    Returns:
        occlusion_layers: (N, H, W, 9) long tensor, one value per splatting
            direction ([-1, -1], [-1, 0], ..., [1, 1]). The value at nhwd is
            0 if the splat from direction d is on the same surface level as
            the top layer of the pixel at hw; negative if the splat is in the
            background (the top splatting layer matches a deeper layer of the
            splatted pixel); and positive if the top layer of the splatted
            pixel matches a deeper splatting layer, so that the splatting
            layers above it are foreground.
    """
    N, H, W, K = q_depth.shape

    # q are the "center pixels" and p the pixels splatting onto them. `unfold`
    # gathers, for every q, the depths of its 3x3 neighborhood: p_depth[nkdhw] is
    # the depth of the neighbor of pixel nhwk in direction d (d=0 is [-1, -1],
    # d=4 is q itself). Neighbors outside the image are zero-padded.
    # Concretely, if the k-th layer of a 2x2 depth image is
    #   .1 .2
    #   .3 .4
    # then
    #   p_depth[n, k, :, 0, 0] = [ 0  0  0  0 .1 .2  0 .3 .4] - neighbors of .1
    #   p_depth[n, k, :, 0, 1] = [ 0  0  0 .1 .2  0 .3 .4  0] - neighbors of .2
    #   p_depth[n, k, :, 1, 0] = [ 0 .1 .2  0 .3 .4  0  0  0] - neighbors of .3
    #   p_depth[n, k, :, 1, 1] = [.1 .2  0 .3 .4  0  0  0  0] - neighbors of .4
    depth_nkhw = q_depth.permute(0, 3, 1, 2)  # (N, K, H, W)
    p_depth = F.unfold(depth_nkhw, kernel_size=3, padding=1).view(
        N, K, 9, H, W
    )  # unfold returns (N, 3^2 * K, H * W)
    center_depth = depth_nkhw.view(N, K, 1, H, W)

    # Surface layer = top rasterization layer of q. For each direction, find
    # which of the K layers of p is closest in depth to it.
    q_top = center_depth[:, 0:1]
    qtop_gap, qtop_match = torch.abs(p_depth - q_top).min(dim=1)  # (N, 9, H, W)

    # Conversely, take the top layer of each p and find which of the K layers
    # of q is closest in depth to it.
    p_top = p_depth[:, 0:1]
    ptop_gap, ptop_match = torch.abs(p_top - center_depth).min(dim=1)  # (N, 9, H, W)

    # Decide whether each p is level with, below, or above the q it splats on
    # (see Fig. 4 in [0]). With P and Q the K-vectors of depths of p and q:
    #  * if Q[0] is closest to, say, P[2], then P[2] is assumed to be the same
    #    surface as the top of Q; P[:2] are foreground splats and P[3:] are
    #    background splats. The result is +2.
    #  * if instead Q[2] is closest to P[0], the top splatting layer is the same
    #    surface as a non-top splatted layer, so all splats are background
    #    splats. The result is -2.
    #  * if Q[0] is closest to P[0], top-level P splats onto top-level Q and
    #    P[1:] are all background splats. The result is 0.
    occlusion_offsets = torch.where(  # noqa
        ptop_gap < qtop_gap,
        -ptop_match,
        qtop_match,
    )  # (N, 9, H, W)

    return occlusion_offsets.permute((0, 2, 3, 1))  # (N, H, W, 9)
|
||||
|
||||
|
||||
def _compute_splatting_colors_and_weights(
    pixel_coords_screen: torch.Tensor,
    colors: torch.Tensor,
    sigma: float,
    offsets: torch.Tensor,
) -> torch.Tensor:
    """
    For every pixel and each of the nine splatting directions, compute the
    splatting weight and the splatting color (the pixel's RGBA color
    multiplied by that weight).

    Args:
        pixel_coords_screen: (N, H, W, K, C) tensor of pixel screen coords with
            C >= 2; only the first two (xy) channels are read.
        colors: (N, H, W, K, 4) RGBA tensor of pixel colors.
        sigma: standard deviation of the Gaussian splatting kernel.
        offsets: (9, 2) tensor computed by _precompute, indicating the nine
            splatting directions ([-1, -1], ..., [1, 1]).

    Returns:
        splat_colors_and_weights: (N, H, W, K, 9, 5) tensor.
            splat_colors_and_weights[..., :4] corresponds to the splatting
            colors, and splat_colors_and_weights[..., 4:5] to the splatting
            weights. The "9" dimension corresponds to the nine splatting
            directions.
    """
    N, H, W, K, C = colors.shape
    kernel_scale = _get_splat_kernel_normalization(offsets, sigma)

    # Displacement from each pixel's interpolated xy position to the center of
    # the pixel it falls in. On the forward pass the coordinates are expected
    # to sit exactly at pixel centers (e.g. [24.5, 31.5]), making this zero;
    # it is kept in the graph so that gradients can flow back through the
    # coordinates to the mesh vertices.
    xy = pixel_coords_screen[..., :2]
    q_to_px_center = (torch.floor(xy) - xy + 0.5).view((N, H, W, K, 1, 2))

    # Squared distance to the pixel center in each direction; (N, H, W, K, 9).
    sq_dist = torch.sum((q_to_px_center + offsets) ** 2, dim=5)
    gaussian = torch.exp(-sq_dist / (2 * sigma**2))

    # Weights are scaled by the pixel's own alpha, so zero-alpha pixels do not
    # splat at all.
    alpha = colors[..., 3:4]
    splat_weights = (alpha * kernel_scale * gaussian).unsqueeze(
        5
    )  # (N, H, W, K, 9, 1)

    # splat_colors[n, h, w, k, d] is the weighted RGBA that pixel nhwk splats
    # in direction d (d=0 is offsets[0] = [-1, -1], d=4 is self-splatting).
    splat_colors = splat_weights * colors.unsqueeze(4)  # (N, H, W, K, 9, 4)

    return torch.cat([splat_colors, splat_weights], dim=5)
|
||||
|
||||
|
||||
def _offset_splats(
    splat_colors_and_weights: torch.Tensor,
    crop_ids_h: torch.Tensor,
    crop_ids_w: torch.Tensor,
) -> torch.Tensor:
    """
    Shift each of the nine direction layers of the splatting colors and weights
    so that every entry ends up at the pixel its splat lands on.

    The tensor is zero-padded by one pixel on each side of H and W, and then
    cropped back to (H, W) by two torch.gather calls. With crop ids as built by
    _precompute, the result for direction d = 3 * i + j (i, j in {0, 1, 2}) is

        out[n, h, w, k, d, c] = in[n, h + j - 1, w + i - 1, k, d, c]

    with zeros wherever the source index falls outside the image. For example,
    if splat_colors_and_weights[n, :, :, k, 0, 0] (direction 0, i.e. offsets[0]
    = (-1, -1)) equals
        .1 .2 .3
        .4 .5 .6
        .7 .8 .9
    the output for that layer is
         0  0  0
         0 .1 .2
         0 .4 .5
    while direction 4 (offset (0, 0)) is returned unchanged.

    Args:
        splat_colors_and_weights: (N, H, W, K, 9, 5) tensor of colors and weights,
            where dim=-2 corresponds to the splatting directions/offsets.
        crop_ids_h: (N, H, W+2, K, 9, 5) precomputed index tensor used to crop
            the padded rows within torch.gather. See _precompute.
        crop_ids_w: (N, H, W, K, 9, 5) precomputed index tensor used to crop
            the padded columns within torch.gather. See _precompute.

    Returns:
        splat_colors_and_weights: (N, H, W, K, 9, 5) tensor.
    """
    # Unpacking doubles as a check that the input is 6-dimensional.
    N, H, W, K, _, _ = splat_colors_and_weights.shape

    # F.pad lists (left, right) pairs starting from the LAST dimension: leave
    # the channel, direction and K dims alone, pad W and H by one on each side,
    # and leave N alone.
    one_px_border = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
    padded = F.pad(splat_colors_and_weights, one_px_border)  # (N, H+2, W+2, K, 9, 5)

    # (N, H+2, W+2, K, 9, 5) -> (N, H, W+2, K, 9, 5)
    rows_cropped = torch.gather(padded, dim=1, index=crop_ids_h)

    # (N, H, W+2, K, 9, 5) -> (N, H, W, K, 9, 5)
    return torch.gather(rows_cropped, dim=2, index=crop_ids_w)
|
||||
|
||||
|
||||
def _compute_splatted_colors_and_weights(
    occlusion_layers: torch.Tensor,  # (N, H, W, 9)
    splat_colors_and_weights: torch.Tensor,  # (N, H, W, K, 9, 5)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Accumulate splatted colors in foreground, surface and background occlusion buffers.

    Args:
        occlusion_layers: (N, H, W, 9) tensor. See _compute_occlusion_layers.
        splat_colors_and_weights: (N, H, W, K, 9, 5) tensor. See _offset_splats.
            Any floating dtype supported by torch.bmm is accepted.

    Returns:
        splatted_colors: (N, H, W, 4, 3) tensor. Index 0 of the last dimension
            holds the foreground, index 1 the surface, and index 2 the
            background splat colors.
        splatted_weights: (N, H, W, 1, 3) tensor. The last dimension is ordered
            the same way; the weights are used for normalization.

    NOTE(review): occlusion_layers enumerates directions in F.unfold order
    (d = 3 * row_shift + col_shift), while the crop ids from _precompute shift
    rows by d % 3 and columns by d // 3. Confirm that both conventions refer to
    the same neighbor for the off-diagonal directions.
    """
    N, H, W, K, _, _ = splat_colors_and_weights.shape

    # Build a one-hot mask over (foreground, surface, background) for every
    # rasterization layer k and direction d. occlusion_layers[n, h, w, d] is
    # the layer of the splatting pixel that lies on the splatted surface:
    #   * layers with a smaller index (occlusion_layers > k) are in front of
    #     it -> foreground, mask[..., 0];
    #   * the layer with that index (occlusion_layers == k) -> surface,
    #     mask[..., 1];
    #   * layers with a larger index (occlusion_layers < k), which includes
    #     every layer when the value is negative -> background, mask[..., 2].
    layer_ids = torch.arange(K, device=splat_colors_and_weights.device).view(
        1, 1, 1, K, 1
    )
    occlusion_layers = occlusion_layers.reshape(N, H, W, 1, 9)
    occlusion_layer_mask = torch.stack(
        [
            occlusion_layers > layer_ids,  # (N, H, W, K, 9)
            occlusion_layers == layer_ids,  # (N, H, W, K, 9)
            occlusion_layers < layer_ids,  # (N, H, W, K, 9)
        ],
        dim=5,
        # torch.bmm requires both operands to share a dtype, so follow the
        # input instead of hard-coding float32.
    ).to(splat_colors_and_weights.dtype)  # (N, H, W, K, 9, 3)

    # Sum over layers and directions in one batched matmul:
    # (N * H * W, 5, 9 * K) x (N * H * W, 9 * K, 3) -> (N * H * W, 5, 3)
    splatted_colors_and_weights = torch.bmm(
        splat_colors_and_weights.permute(0, 1, 2, 5, 3, 4).reshape(
            (N * H * W, 5, K * 9)
        ),
        occlusion_layer_mask.reshape((N * H * W, K * 9, 3)),
    ).reshape((N, H, W, 5, 3))

    return (
        splatted_colors_and_weights[..., :4, :],
        splatted_colors_and_weights[..., 4:5, :],
    )
|
||||
|
||||
|
||||
def _normalize_and_compose_all_layers(
    background_color: torch.Tensor,
    splatted_colors_per_occlusion_layer: torch.Tensor,
    splatted_weights_per_occlusion_layer: torch.Tensor,
) -> torch.Tensor:
    """
    Normalize each fg/surface/bg buffer by its weight, and alpha-compose them
    over the background color.

    Args:
        background_color: (3) RGB tensor.
        splatted_colors_per_occlusion_layer: (N, H, W, 4, 3) RGBA tensor, last
            dimension corresponds to foreground, surface, and background
            splatting.
        splatted_weights_per_occlusion_layer: (N, H, W, 1, 3) weight tensor.

    Returns:
        output_colors: (N, H, W, 4) RGBA tensor.
    """
    device = splatted_colors_per_occlusion_layer.device

    # Per-buffer normalization. Weights below 1 are clamped up to 1, so a
    # buffer is only ever scaled down, never amplified. (N, H, W, 1, 3)
    unit_floor = torch.tensor([1.0], device=device)
    normalization_scales = 1.0 / (
        torch.maximum(splatted_weights_per_occlusion_layer, unit_floor)
    )
    normalized = (
        splatted_colors_per_occlusion_layer * normalization_scales
    )  # (N, H, W, 4, 3)

    # Start from the background color with zero alpha; the (4) tensor
    # broadcasts to (N, H, W, 4) on the first composition step.
    zero_alpha = torch.tensor([0.0], device=device)
    composite = torch.cat([background_color, zero_alpha])

    # Over-compose back to front: last buffer first, first buffer on top. The
    # RGBA values were already multiplied by their own alpha as part of
    # self-splatting in _compute_splatting_colors_and_weights, so the layer
    # color is not re-multiplied by alpha here.
    for buffer_id in (-1, -2, -3):
        layer_rgba = normalized[..., buffer_id]  # (N, H, W, 4)
        layer_alpha = layer_rgba[..., 3:4]  # (N, H, W, 1)
        composite = layer_rgba + (1.0 - layer_alpha) * composite

    return composite
|
||||
|
||||
|
||||
class SplatterBlender(torch.nn.Module):
    def __init__(
        self,
        input_shape: Tuple[int, int, int, int],
        device,
    ):
        """
        A splatting blender. See `forward` docs for details of the splatting mechanism.

        Args:
            input_shape: Tuple (N, H, W, K) indicating the batch size, image height,
                image width, and number of rasterized layers. Used to precompute
                constant tensors that do not change as long as this tuple is unchanged.
            device: Device on which the precomputed constants are created.
        """
        super().__init__()
        self.crop_ids_h, self.crop_ids_w, self.offsets = _precompute(
            input_shape, device
        )

    def _ensure_constants(
        self, input_shape: Tuple[int, int, int, int], device: torch.device
    ) -> None:
        """
        Recompute the cached constants if they were built for another
        (N, H, W, K) shape or live on another device than the current input.
        """
        cached_shape = tuple(self.crop_ids_w.shape[:4])
        if cached_shape != tuple(input_shape) or self.crop_ids_w.device != device:
            self.crop_ids_h, self.crop_ids_w, self.offsets = _precompute(
                input_shape, device
            )

    def forward(
        self,
        colors: torch.Tensor,
        pixel_coords_cameras: torch.Tensor,
        cameras: FoVPerspectiveCameras,
        background_mask: torch.Tensor,
        blend_params: BlendParams,
    ) -> torch.Tensor:
        """
        RGB blending using splatting, as proposed in [0].

        Args:
            colors: (N, H, W, K, 3) tensor of RGB colors at each h, w pixel location for
                K intersection layers.
            pixel_coords_cameras: (N, H, W, K, 3) tensor of pixel coordinates in the
                camera frame of reference. It is *crucial* that these are computed by
                interpolating triangle vertex positions using barycentric coordinates --
                this allows gradients to travel through pixel_coords_camera back to the
                vertex positions.
            cameras: Cameras object used to project pixel_coords_cameras screen coords.
            background_mask: (N, H, W, K) boolean tensor, True for bg pixels. A pixel
                is considered "background" if no mesh triangle projects to it. This is
                typically computed by the rasterizer.
            blend_params: BlendParams, from which we use sigma (standard deviation
                of the splatting kernel) and background_color.

        Returns:
            output_colors: (N, H, W, 4) tensor of RGBA values. The alpha layer is set to
                fully transparent in the background.

        [0] Cole, F. et al., "Differentiable Surface Rendering via Non-differentiable
            Sampling".
        """

        # Our implementation has 6 stages. In the description below, we will call each
        # pixel q and the 9 surrounding splatting pixels (including itself) p.
        # 1. Use barycentrics to compute the position of each pixel in screen
        # coordinates. These should exactly correspond to pixel centers during the
        # forward pass, but can be shifted on backwards. This step allows gradients to
        # travel to vertex coordinates, even if the rasterizer is non-differentiable.
        # 2a. For each center pixel q, take each splatting p and decide whether it
        # is on the same surface level as q, or in the background or foreground.
        # 2b. For each center pixel q, compute the splatting weight of surrounding
        # pixels p, and their splatting colors (which are just the original colors
        # weighted by the splatting weights).
        # 3. As a vectorization technicality, offset the tensors corresponding to
        # the splatting p values in the nine directions, by padding each of nine
        # splatting layers on the bottom/top, left/right.
        # 4. Do the actual splatting, by accumulating the splatting colors of the
        # surrounding p's for each pixel q. The weights get accumulated separately for
        # p's that got assigned to the background/surface/foreground in Step 2a.
        # 5. Normalize each of the splatted bg/surface/fg colors for each q, and
        # compose the resulting color maps.
        #
        # Note that it is crucial that in Step 1 we compute the pixel coordinates by
        # interpolating triangle vertices using barycentric coords from the
        # rasterizer. In our case, these pixel_coords_camera are computed by the
        # shader and passed to this function to avoid re-computation.

        # The gather indices only fit one (N, H, W, K) shape and must live on the
        # same device as the data; refresh them if the input has changed.
        self._ensure_constants(tuple(colors.shape[:4]), colors.device)

        pixel_coords_screen, colors = _prepare_pixels_and_colors(
            pixel_coords_cameras, colors, cameras, background_mask
        )  # (N, H, W, K, 3) and (N, H, W, K, 4)

        occlusion_layers = _compute_occlusion_layers(
            pixel_coords_screen[..., 2:3].squeeze(dim=-1)
        )  # (N, H, W, 9)

        splat_colors_and_weights = _compute_splatting_colors_and_weights(
            pixel_coords_screen,
            colors,
            blend_params.sigma,
            self.offsets,
        )  # (N, H, W, K, 9, 5)

        splat_colors_and_weights = _offset_splats(
            splat_colors_and_weights,
            self.crop_ids_h,
            self.crop_ids_w,
        )  # (N, H, W, K, 9, 5)

        (
            splatted_colors_per_occlusion_layer,
            splatted_weights_per_occlusion_layer,
        ) = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )  # (N, H, W, 4, 3) and (N, H, W, 1, 3)

        output_colors = _normalize_and_compose_all_layers(
            _get_background_color(blend_params, colors.device),
            splatted_colors_per_occlusion_layer,
            splatted_weights_per_occlusion_layer,
        )  # (N, H, W, 4)

        return output_colors
|
@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import (
|
||||
HardPhongShader,
|
||||
SoftPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
SplatterPhongShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from pytorch3d.structures.meshes import (
|
||||
@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
shader_tests = [
|
||||
ShaderTest(HardPhongShader, "phong", "hard_phong"),
|
||||
ShaderTest(SoftPhongShader, "phong", "soft_phong"),
|
||||
ShaderTest(SplatterPhongShader, "phong", "splatter_phong"),
|
||||
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
|
||||
ShaderTest(HardFlatShader, "flat", "hard_flat"),
|
||||
]
|
||||
|
627
tests/test_splatter_blend.py
Normal file
627
tests/test_splatter_blend.py
Normal file
@ -0,0 +1,627 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
|
||||
from pytorch3d.renderer.splatter_blend import (
|
||||
_compute_occlusion_layers,
|
||||
_compute_splatted_colors_and_weights,
|
||||
_compute_splatting_colors_and_weights,
|
||||
_get_splat_kernel_normalization,
|
||||
_normalize_and_compose_all_layers,
|
||||
_offset_splats,
|
||||
_precompute,
|
||||
_prepare_pixels_and_colors,
|
||||
)
|
||||
|
||||
# The nine splatting directions as pairs of -1/0/1 displacements, enumerated with
# the first component varying slowest. Index 4 is (0, 0), i.e. the self-splat.
offsets = torch.tensor(
    [[first, second] for first in (-1, 0, 1) for second in (-1, 0, 1)],
    device=torch.device("cpu"),
)
|
||||
|
||||
|
||||
def compute_splatting_colors_and_weights_naive(pixel_coords_screen, colors, sigma):
    """
    Loop-based reference implementation used to cross-check
    _compute_splatting_colors_and_weights.

    Args:
        pixel_coords_screen: tensor indexed as [n, h, w, k, :]; only the first two
            components of the last dimension are used.
        colors: (N, H, W, K, 4) tensor whose last channel is alpha.
        sigma: standard deviation of the Gaussian splatting kernel.

    Returns:
        (N, H, W, K, 9, 5) tensor. For each pixel, layer and splatting direction,
        the first four channels hold the weighted color and the fifth holds the
        splat weight.
    """
    # The normalization has to be computed for the same sigma that shapes the
    # kernel below; relying on the helper's default would silently give wrong
    # reference values for any other sigma.
    normalizer = float(_get_splat_kernel_normalization(offsets, sigma))
    N, H, W, K, _ = colors.shape
    num_directions = offsets.shape[0]
    splat_weights_and_colors = torch.zeros((N, H, W, K, num_directions, 5))
    for n in range(N):
        for h in range(H):
            for w in range(W):
                for k in range(K):
                    # Keep only the first two components so that coordinates
                    # carrying an extra (depth) channel still broadcast against
                    # the 2-component offsets.
                    q_xy = pixel_coords_screen[n, h, w, k, :2]
                    # Vector from the sample to the center of the pixel it falls in.
                    q_to_px_center = torch.floor(q_xy) - q_xy + 0.5
                    color = colors[n, h, w, k]
                    alpha = colors[n, h, w, k, 3:4]
                    for d, offset in enumerate(offsets):
                        dist_p_q = torch.sum((q_to_px_center + offset) ** 2)
                        splat_weight = (
                            alpha * torch.exp(-dist_p_q / (2 * sigma**2)) * normalizer
                        )
                        splat_weights_and_colors[n, h, w, k, d, :4] = (
                            splat_weight * color
                        )
                        splat_weights_and_colors[n, h, w, k, d, 4:5] = splat_weight
    return splat_weights_and_colors
|
||||
|
||||
|
||||
class TestPrecompute(TestCaseMixin, unittest.TestCase):
    """Checks the crop indices and the offsets returned by _precompute."""

    def setUp(self):
        cpu = torch.device("cpu")
        # Each result is indexed below as [0] crop ids along h, [1] crop ids
        # along w, [2] splatting offsets.
        self.results_cpu = _precompute((2, 3, 4, 5), cpu)
        self.results1_cpu = _precompute((1, 1, 1, 1), cpu)

    def test_offsets(self):
        computed_offsets = self.results_cpu[2]
        self.assertClose(computed_offsets.shape, offsets.shape, atol=0)
        self.assertClose(computed_offsets, offsets, atol=0)

        # The offsets must not depend on input_size.
        self.assertClose(computed_offsets, self.results1_cpu[2], atol=0)

    def test_crops_h(self):
        # One row per splatting direction. The five columns are the channels being
        # offset (R, G, B, A, Weight) and they all share the same crop id, which
        # cycles 0, 1, 2 over the nine directions.
        rows_per_pixel = [[direction % 3] * 5 for direction in range(9)]
        # Repeated 3 times because we're aiming at (N, H, W+2, K, 9, 5) with W=1.
        target_crops_h1 = torch.tensor(
            rows_per_pixel * 3,
            device=torch.device("cpu"),
        ).reshape(1, 1, 3, 1, 9, 5)
        self.assertClose(self.results1_cpu[0], target_crops_h1, atol=0)

        # Every successive h gets its crop ids incremented by one.
        base_h = target_crops_h1[0, 0, 0]
        target_crops_h = torch.cat([base_h + h for h in range(3)], dim=0)

        # The shape must be right and, for every n, w and k, the values along h
        # must be the same (tiled) target.
        expected_shape = torch.tensor([2, 3, 6, 5, 9, 5])
        self.assertClose(self.results_cpu[0].shape, expected_shape, atol=0)
        crops_h = self.results_cpu[0]
        for n in range(2):
            for w in range(6):
                for k in range(5):
                    self.assertClose(crops_h[n, :, w, k], target_crops_h)

    def test_crops_w(self):
        # One row per splatting direction. The five columns are the channels being
        # offset (R, G, B, A, Weight); the crop id steps 0, 1, 2 every three
        # directions.
        rows_per_pixel = [[direction // 3] * 5 for direction in range(9)]
        target_crops_w1 = torch.tensor(
            rows_per_pixel,
            device=torch.device("cpu"),
        ).reshape(1, 1, 1, 1, 9, 5)
        self.assertClose(self.results1_cpu[1], target_crops_w1)

        # Every successive w gets its crop ids incremented by one.
        base_w = target_crops_w1[0, 0, 0]
        target_crops_w = torch.cat([base_w + w for w in range(4)], dim=0)

        # The shape must be right and, for every n, h and k, the values along w
        # must be the same (tiled) target.
        expected_shape = torch.tensor([2, 3, 4, 5, 9, 5])
        self.assertClose(self.results_cpu[1].shape, expected_shape)
        crops_w = self.results_cpu[1]
        for n in range(2):
            for h in range(3):
                for k in range(5):
                    self.assertClose(crops_w[n, h, :, k], target_crops_w, atol=0)
|
||||
|
||||
|
||||
class TestPreparPixelsAndColors(TestCaseMixin, unittest.TestCase):
    """Checks how _prepare_pixels_and_colors treats background samples."""

    def setUp(self):
        self.device = torch.device("cpu")
        N, H, W, K = 2, 3, 4, 5
        sample_shape = (N, H, W, K)

        self.pixel_coords_cameras = torch.randn(
            sample_shape + (3,), device=self.device, requires_grad=True
        )
        self.colors_before = torch.rand(sample_shape + (3,), device=self.device)
        self.cameras = FoVPerspectiveCameras(device=self.device)
        # Roughly half of the samples are flagged as background.
        self.background_mask = torch.rand(sample_shape, device=self.device) < 0.5

        prepared = _prepare_pixels_and_colors(
            self.pixel_coords_cameras,
            self.colors_before,
            self.cameras,
            self.background_mask,
        )
        self.pixel_coords_screen, self.colors_after = prepared

    def test_background_z(self):
        # Every background sample must end up with a z of exactly 1.
        background_z = self.pixel_coords_screen[..., 2][self.background_mask]
        self.assertTrue(torch.all(background_z == 1.0))

    def test_background_alpha(self):
        # Every background sample must end up with an alpha of exactly 0.
        background_alpha = self.colors_after[..., 3][self.background_mask]
        self.assertTrue(torch.all(background_alpha == 0.0))
|
||||
|
||||
|
||||
class TestGetSplatKernelNormalization(TestCaseMixin, unittest.TestCase):
    """Checks _get_splat_kernel_normalization values and argument validation."""

    def test_splat_kernel_normalization(self):
        # With the default standard deviation.
        default_norm = float(_get_splat_kernel_normalization(offsets))
        self.assertAlmostEqual(default_norm, 0.6503, places=3)

        # With a very small standard deviation.
        narrow_norm = float(_get_splat_kernel_normalization(offsets, 0.01))
        self.assertAlmostEqual(narrow_norm, 1.05, places=3)

        # A standard deviation of zero is rejected.
        with self.assertRaisesRegex(ValueError, "Only positive standard deviations"):
            _get_splat_kernel_normalization(offsets, 0)
|
||||
|
||||
|
||||
class TestComputeOcclusionLayers(TestCaseMixin, unittest.TestCase):
    """Checks the per-direction occlusion layer ids from _compute_occlusion_layers."""

    def test_single_layer(self):
        # With a single layer, every splat has to land on the surface level.
        N, H, W, K = 2, 3, 4, 1
        q_depth = torch.rand(N, H, W, K)
        occlusion_layers = _compute_occlusion_layers(q_depth)
        all_surface = torch.zeros((N, H, W, 9), dtype=torch.long)
        self.assertClose(occlusion_layers, all_surface, atol=0.0)

    def test_all_equal(self):
        # With identical depths everywhere, every splat has to land on the surface
        # level as well.
        N, H, W, K = 2, 3, 4, 5
        q_depth = torch.full((N, H, W, K), 0.1234)
        occlusion_layers = _compute_occlusion_layers(q_depth)
        all_surface = torch.zeros((N, H, W, 9), dtype=torch.long)
        self.assertClose(occlusion_layers, all_surface, atol=0.0)

    def test_mid_to_top_level_splatting(self):
        # Check that occlusion buffers get accumulated as expected when the splatting
        # and splatted pixels are co-surface on different intersection layers.
        # This test makes most sense next to Fig. 4 of "Differentiable Surface
        # Rendering via Non-differentiable Sampling" by Cole et al.
        q_depths = torch.tensor([0.71, 0.8, 1.0])
        p_depths = torch.tensor([0.5, 0.7, 0.9])

        for direction, offset in enumerate(offsets):
            if direction == 4:
                continue  # Self-splatting is always co-surface; nothing to check.

            # Position of p, the splatting pixel, in the 3x3 image.
            p_h, p_w = offset[0] + 1, offset[1] + 1

            depths = torch.zeros(1, 3, 3, 3)
            # q, the pixel being splatted onto, sits in the center of the image.
            depths[0, 1, 1] = q_depths
            depths[0, p_h, p_w] = p_depths

            occlusion_layers = _compute_occlusion_layers(depths)

            # It must be the middle layer of p that is co-surface with q. (1, 1) is
            # the position of q and `direction` is the id of p's direction w.r.t. q.
            psurfaceid_onto_q = occlusion_layers[0, 1, 1, direction]
            self.assertEqual(int(psurfaceid_onto_q), 1)

            # Conversely, with p and q swapped, a top-level pixel splats onto a
            # mid-level one. `8 - direction` is the id of q's direction w.r.t. p
            # (e.g. if p is [-1, -1] w.r.t. q, then q is [1, 1] w.r.t. p; these two
            # directions mirror each other in the offsets array).
            qsurfaceid_onto_p = occlusion_layers[0, p_h, p_w, 8 - direction]
            self.assertEqual(int(qsurfaceid_onto_p), -1)
|
||||
|
||||
|
||||
class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase):
    """Checks the per-direction splat colors and weights."""

    def setUp(self):
        self.N, self.H, self.W, self.K = 2, 3, 4, 5

        # Samples sit exactly on pixel centers: integer (h, w) grid ids + 0.5,
        # repeated for every batch element and every layer.
        grid_h, grid_w = torch.meshgrid(
            torch.arange(self.H), torch.arange(self.W), indexing="ij"
        )
        pixel_ids = torch.stack((grid_h, grid_w), dim=-1).reshape(
            1, self.H, self.W, 1, 2
        )
        self.pixel_coords_screen = (
            torch.tile(pixel_ids, (self.N, 1, 1, self.K, 1)).float() + 0.5
        )
        self.colors = torch.ones((self.N, self.H, self.W, self.K, 4))

    def test_all_equal(self):
        # If all colors are equal and on a regular grid, all weights and reweighted
        # colors should be equal given a specific splatting direction.
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            self.pixel_coords_screen, self.colors * 0.2345, sigma=0.5, offsets=offsets
        )

        # Splatting directly to the top/bottom/left/right should have the same
        # strength.
        non_diag_splats = splatting_colors_and_weights[
            :, :, :, :, torch.tensor([1, 3, 5, 7])
        ]

        # Same for diagonal splats.
        diag_splats = splatting_colors_and_weights[
            :, :, :, :, torch.tensor([0, 2, 6, 8])
        ]

        # And for self-splats.
        self_splats = splatting_colors_and_weights[:, :, :, :, torch.tensor([4])]

        for splats in (non_diag_splats, diag_splats, self_splats):
            # Colors should be equal.
            self.assertTrue(torch.all(splats[..., :4] == splats[0, 0, 0, 0, 0, 0]))

            # Weights should be equal.
            self.assertTrue(torch.all(splats[..., 4] == splats[0, 0, 0, 0, 0, 4]))

        # Non-diagonal splats should be stronger than diagonal splats. This is
        # compared on the first color channel.
        self.assertGreater(
            non_diag_splats[0, 0, 0, 0, 0, 0], diag_splats[0, 0, 0, 0, 0, 0]
        )

        # Self-splats should be strongest of all (again on the first color channel).
        self.assertGreater(
            self_splats[0, 0, 0, 0, 0, 0], non_diag_splats[0, 0, 0, 0, 0, 0]
        )

        # Splatting colors should be reweighted proportionally to their splat
        # weights. The two ratios are computed along different floating-point
        # paths, so they are compared with a tolerance rather than for exact
        # equality.
        diag_self_color_ratio = (
            diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
        )
        diag_self_weight_ratio = (
            diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
        )
        self.assertClose(diag_self_color_ratio, diag_self_weight_ratio)

        non_diag_self_color_ratio = (
            non_diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
        )
        non_diag_self_weight_ratio = (
            non_diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
        )
        self.assertClose(non_diag_self_color_ratio, non_diag_self_weight_ratio)

    def test_zero_alpha_zero_weight(self):
        # Pixels with zero alpha do no splatting, but should still be splatted on.
        colors = self.colors.clone()
        colors[0, 1, 1, 0, 3] = 0.0
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            self.pixel_coords_screen, colors, sigma=0.5, offsets=offsets
        )

        # The transparent pixel should do no splatting.
        self.assertTrue(torch.all(splatting_colors_and_weights[0, 1, 1, 0] == 0.0))

        # Splatting *onto* the transparent pixel should be unaffected. The next
        # layer of the same pixel, which is still opaque, serves as the reference.
        reference_weights_colors = splatting_colors_and_weights[0, 1, 1, 1]
        for direction, offset in enumerate(offsets):
            if direction == 4:
                continue  # Ignore self-splats
            # We invert the direction to get the right (h, w, d) coordinate of each
            # pixel splatting *onto* the pixel with zero alpha.
            self.assertClose(
                splatting_colors_and_weights[
                    0, 1 + offset[0], 1 + offset[1], 0, 8 - direction
                ],
                reference_weights_colors[8 - direction],
                atol=0.001,
            )

    def test_random_inputs(self):
        # Jitter the regular grid and use random colors, then compare the
        # vectorized implementation against the loop-based reference.
        pixel_coords_screen = (
            self.pixel_coords_screen
            + torch.randn((self.N, self.H, self.W, self.K, 2)) * 0.1
        )
        colors = torch.rand((self.N, self.H, self.W, self.K, 4))
        splatting_colors_and_weights = _compute_splatting_colors_and_weights(
            pixel_coords_screen, colors, sigma=0.5, offsets=offsets
        )
        naive_colors_and_weights = compute_splatting_colors_and_weights_naive(
            pixel_coords_screen, colors, sigma=0.5
        )

        self.assertClose(
            splatting_colors_and_weights, naive_colors_and_weights, atol=0.01
        )
|
||||
|
||||
|
||||
class TestOffsetSplats(TestCaseMixin, unittest.TestCase):
    """Checks that _offset_splats shifts every direction's splats by one pixel."""

    @staticmethod
    def _shift_slices(step):
        """
        Translate one offset component into the slices needed to check a shift.

        Args:
            step: -1, 0 or 1.

        Returns:
            Tuple (destination slice, source slice, border index). The border
            index is the position along that dimension that nothing is shifted
            into and that must therefore be zero; it is None when step is 0.
        """
        if step < 0:
            return slice(1, None), slice(None, -1), 0
        if step > 0:
            return slice(None, -1), slice(1, None), -1
        return slice(None), slice(None), None

    def test_offset(self):
        # Use the GPU when one is present, but do not error out on CPU-only hosts.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        N, H, W, K = 2, 3, 4, 5
        colors_and_weights = torch.rand((N, H, W, K, 9, 5), device=device)
        crop_ids_h, crop_ids_w, _ = _precompute((N, H, W, K), device=device)
        offset_colors_and_weights = _offset_splats(
            colors_and_weights, crop_ids_h, crop_ids_w
        )

        # Check each splatting direction individually. offset_x selects the shift
        # along the W dimension and offset_y the shift along the H dimension;
        # direction 4 is the self-splat, for which nothing moves.
        for direction, (offset_x, offset_y) in enumerate(offsets.tolist()):
            dst_h, src_h, empty_h = self._shift_slices(offset_y)
            dst_w, src_w, empty_w = self._shift_slices(offset_x)
            with self.subTest(direction=direction, offset=(offset_x, offset_y)):
                # The shifted region must hold the original values.
                self.assertClose(
                    offset_colors_and_weights[:, dst_h, dst_w, :, direction],
                    colors_and_weights[:, src_h, src_w, :, direction],
                    atol=0.001,
                )
                # The borders that nothing was shifted into must be empty.
                if empty_h is not None:
                    self.assertTrue(
                        torch.all(
                            offset_colors_and_weights[:, empty_h, :, :, direction]
                            == 0.0
                        )
                    )
                if empty_w is not None:
                    self.assertTrue(
                        torch.all(
                            offset_colors_and_weights[:, :, empty_w, :, direction]
                            == 0.0
                        )
                    )
|
||||
|
||||
|
||||
class TestComputeSplattedColorsAndWeights(TestCaseMixin, unittest.TestCase):
    """
    Checks how splats are accumulated into the three occlusion layers. The last
    dimension of the splatted colors is read as 0: foreground, 1: surface,
    2: background.
    """

    def test_accumulation_background(self):
        # With occlusion_layers all -1, every splat is a background splat.
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        occlusion_layers = torch.full((1, 1, 1, 9), -1.0)
        splatted_colors, _ = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )
        no_splats = torch.zeros(4)
        splat_colors = splat_colors_and_weights[0, 0, 0, :, :, :4]

        # Foreground: nothing is accumulated.
        self.assertClose(splatted_colors[0, 0, 0, :, 0], no_splats, atol=0.001)

        # Surface: nothing is accumulated.
        self.assertClose(splatted_colors[0, 0, 0, :, 1], no_splats, atol=0.001)

        # Background: the sum over all layers and all directions.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 2],
            splat_colors.sum(dim=0).sum(dim=0),
            atol=0.001,
        )

    def test_accumulation_middle(self):
        # With occlusion_layers all 0, top splats are co-surface with the splatted
        # pixels. The top splatting layer is therefore accumulated to surface, and
        # all other layers to background.
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        occlusion_layers = torch.zeros((1, 1, 1, 9))
        splatted_colors, _ = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )
        top_layer = splat_colors_and_weights[0, 0, 0, 0, :, :4]
        lower_layers = splat_colors_and_weights[0, 0, 0, 1:, :, :4]

        # Foreground: nothing is accumulated.
        self.assertClose(splatted_colors[0, 0, 0, :, 0], torch.zeros(4), atol=0.001)

        # Surface: the top layer, summed over directions.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 1], top_layer.sum(dim=0), atol=0.001
        )

        # Background: the remaining layers, summed over layers and directions.
        self.assertClose(
            splatted_colors[0, 0, 0, :, 2],
            lower_layers.sum(dim=0).sum(dim=0),
            atol=0.001,
        )

    def test_accumulation_foreground(self):
        # With occlusion_layers all 1, the top splatter is a foreground splatter,
        # the mid splatter is surface, and the bottom splatter is background.
        splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
        occlusion_layers = torch.ones((1, 1, 1, 9))
        splatted_colors, _ = _compute_splatted_colors_and_weights(
            occlusion_layers, splat_colors_and_weights
        )

        # Splatting layer k must end up, summed over directions, in output layer k:
        # 0 -> foreground, 1 -> surface, 2 -> background.
        for layer in range(3):
            one_layer = splat_colors_and_weights[0, 0, 0, layer : layer + 1, :, :4]
            self.assertClose(
                splatted_colors[0, 0, 0, :, layer],
                one_layer.sum(dim=0).sum(dim=0),
                atol=0.001,
            )
|
||||
|
||||
|
||||
class TestNormalizeAndComposeAllLayers(TestCaseMixin, unittest.TestCase):
    """Checks the final normalization and compositing of the occlusion layers."""

    def test_background_color(self):
        # Background should always have alpha=0, and the chosen RGB.
        N, H, W = 2, 3, 4
        # Mark the zeroth column (w == 0, every h) of the first image as background.
        bg_mask = torch.zeros([N, H, W, 1, 1])
        bg_mask[0, :, 0] = 1
        keep = 1 - bg_mask

        bg_color = torch.tensor([0.2, 0.3, 0.4])

        # Background pixels receive neither color nor weight in any layer.
        color_layers = torch.rand((N, H, W, 4, 3)) * keep
        color_weights = torch.rand((N, H, W, 1, 3)) * keep

        colors = _normalize_and_compose_all_layers(
            bg_color, color_layers, color_weights
        )

        # Each of the H=3 background pixels must be RGB (.2, .3, .4) with alpha 0.
        background_pixels = torch.masked_select(colors, bg_mask.bool()[..., 0])
        expected_pixels = torch.tensor(
            [0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0.0]
        )
        self.assertClose(background_pixels, expected_pixels, atol=0.001)

    def test_compositing_opaque(self):
        # When all colors are opaque, only the foreground layer should be visible.
        N, H, W = 2, 3, 4
        color_layers = torch.rand((N, H, W, 4, 3))
        color_layers[..., 3, :] = 1.0  # alpha of every layer
        color_weights = torch.ones((N, H, W, 1, 3))
        black = torch.tensor([0.0, 0.0, 0.0])

        out_colors = _normalize_and_compose_all_layers(
            black, color_layers, color_weights
        )
        self.assertClose(out_colors, color_layers[..., 0], atol=0.001)

    def test_compositing_transparencies(self):
        # With semi-transparent foreground (alpha .1), surface (alpha .2) and
        # background (alpha .3) layers, each layer is attenuated by the
        # transparency of the layers in front of it.
        N, H, W = 2, 3, 4
        color_layers = torch.rand((N, H, W, 4, 3))
        for layer, alpha in enumerate((0.1, 0.2, 0.3)):  # fg, surface, bg
            color_layers[..., 3, layer] = alpha
        color_weights = torch.ones((N, H, W, 1, 3))
        black = torch.tensor([0.0, 0.0, 0.0])

        out_colors = _normalize_and_compose_all_layers(
            black, color_layers, color_weights
        )

        foreground = color_layers[..., 0]
        surface = color_layers[..., 1]
        background = color_layers[..., 2]
        self.assertClose(
            out_colors,
            foreground + 0.9 * (surface + 0.8 * background),
        )
|
Loading…
x
Reference in New Issue
Block a user