mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
SplatterBlender
Summary: Splatting shader. See code comments for details. Same API as SoftPhongShader. Reviewed By: jcjohnson Differential Revision: D36354301 fbshipit-source-id: 71ee37f7ff6bb9ce028ba42a65741424a427a92d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1702c85bec
commit
c5a83f46ef
@@ -41,6 +41,7 @@ from pytorch3d.renderer.mesh.shader import (
|
||||
HardPhongShader,
|
||||
SoftPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
SplatterPhongShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from pytorch3d.structures.meshes import (
|
||||
@@ -325,6 +326,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
shader_tests = [
|
||||
ShaderTest(HardPhongShader, "phong", "hard_phong"),
|
||||
ShaderTest(SoftPhongShader, "phong", "soft_phong"),
|
||||
ShaderTest(SplatterPhongShader, "phong", "splatter_phong"),
|
||||
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
|
||||
ShaderTest(HardFlatShader, "flat", "hard_flat"),
|
||||
]
|
||||
|
||||
627
tests/test_splatter_blend.py
Normal file
627
tests/test_splatter_blend.py
Normal file
@@ -0,0 +1,627 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras
|
||||
from pytorch3d.renderer.splatter_blend import (
|
||||
_compute_occlusion_layers,
|
||||
_compute_splatted_colors_and_weights,
|
||||
_compute_splatting_colors_and_weights,
|
||||
_get_splat_kernel_normalization,
|
||||
_normalize_and_compose_all_layers,
|
||||
_offset_splats,
|
||||
_precompute,
|
||||
_prepare_pixels_and_colors,
|
||||
)
|
||||
|
||||
offsets = torch.tensor(
|
||||
[
|
||||
[-1, -1],
|
||||
[-1, 0],
|
||||
[-1, 1],
|
||||
[0, -1],
|
||||
[0, 0],
|
||||
[0, 1],
|
||||
[1, -1],
|
||||
[1, 0],
|
||||
[1, 1],
|
||||
],
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
|
||||
def compute_splatting_colors_and_weights_naive(pixel_coords_screen, colors, sigma):
|
||||
normalizer = float(_get_splat_kernel_normalization(offsets))
|
||||
N, H, W, K, _ = colors.shape
|
||||
splat_weights_and_colors = torch.zeros((N, H, W, K, 9, 5))
|
||||
for n in range(N):
|
||||
for h in range(H):
|
||||
for w in range(W):
|
||||
for k in range(K):
|
||||
q_xy = pixel_coords_screen[n, h, w, k]
|
||||
q_to_px_center = torch.floor(q_xy) - q_xy + 0.5
|
||||
color = colors[n, h, w, k]
|
||||
alpha = colors[n, h, w, k, 3:4]
|
||||
for d in range(9):
|
||||
dist_p_q = torch.sum((q_to_px_center + offsets[d]) ** 2)
|
||||
splat_weight = (
|
||||
alpha * torch.exp(-dist_p_q / (2 * sigma**2)) * normalizer
|
||||
)
|
||||
splat_color = splat_weight * color
|
||||
splat_weights_and_colors[n, h, w, k, d, :4] = splat_color
|
||||
splat_weights_and_colors[n, h, w, k, d, 4:5] = splat_weight
|
||||
return splat_weights_and_colors
|
||||
|
||||
|
||||
class TestPrecompute(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.results_cpu = _precompute((2, 3, 4, 5), torch.device("cpu"))
|
||||
self.results1_cpu = _precompute((1, 1, 1, 1), torch.device("cpu"))
|
||||
|
||||
def test_offsets(self):
|
||||
self.assertClose(self.results_cpu[2].shape, offsets.shape, atol=0)
|
||||
self.assertClose(self.results_cpu[2], offsets, atol=0)
|
||||
|
||||
# Offsets should be independent of input_size.
|
||||
self.assertClose(self.results_cpu[2], self.results1_cpu[2], atol=0)
|
||||
|
||||
def test_crops_h(self):
|
||||
target_crops_h1 = torch.tensor(
|
||||
[
|
||||
# chennels being offset:
|
||||
# R G B A W(eight)
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
]
|
||||
* 3, # 3 because we're aiming at (N, H, W+2, K, 9, 5) with W=1.
|
||||
device=torch.device("cpu"),
|
||||
).reshape(1, 1, 3, 1, 9, 5)
|
||||
self.assertClose(self.results1_cpu[0], target_crops_h1, atol=0)
|
||||
|
||||
target_crops_h_base = target_crops_h1[0, 0, 0]
|
||||
target_crops_h = torch.cat(
|
||||
[target_crops_h_base, target_crops_h_base + 1, target_crops_h_base + 2],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Check that we have the right shape, and (after broadcasting) it has the right
|
||||
# values. These should be repeated (tiled) for each n and k.
|
||||
self.assertClose(
|
||||
self.results_cpu[0].shape, torch.tensor([2, 3, 6, 5, 9, 5]), atol=0
|
||||
)
|
||||
for n in range(2):
|
||||
for w in range(6):
|
||||
for k in range(5):
|
||||
self.assertClose(
|
||||
self.results_cpu[0][n, :, w, k],
|
||||
target_crops_h,
|
||||
)
|
||||
|
||||
def test_crops_w(self):
|
||||
target_crops_w1 = torch.tensor(
|
||||
[
|
||||
# chennels being offset:
|
||||
# R G B A W(eight)
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[2, 2, 2, 2, 2],
|
||||
[2, 2, 2, 2, 2],
|
||||
],
|
||||
device=torch.device("cpu"),
|
||||
).reshape(1, 1, 1, 1, 9, 5)
|
||||
self.assertClose(self.results1_cpu[1], target_crops_w1)
|
||||
|
||||
target_crops_w_base = target_crops_w1[0, 0, 0]
|
||||
target_crops_w = torch.cat(
|
||||
[
|
||||
target_crops_w_base,
|
||||
target_crops_w_base + 1,
|
||||
target_crops_w_base + 2,
|
||||
target_crops_w_base + 3,
|
||||
],
|
||||
dim=0,
|
||||
) # Each w value needs an increment.
|
||||
|
||||
# Check that we have the right shape, and (after broadcasting) it has the right
|
||||
# values. These should be repeated (tiled) for each n and k.
|
||||
self.assertClose(self.results_cpu[1].shape, torch.tensor([2, 3, 4, 5, 9, 5]))
|
||||
for n in range(2):
|
||||
for h in range(3):
|
||||
for k in range(5):
|
||||
self.assertClose(
|
||||
self.results_cpu[1][n, h, :, k],
|
||||
target_crops_w,
|
||||
atol=0,
|
||||
)
|
||||
|
||||
|
||||
class TestPreparPixelsAndColors(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.device = torch.device("cpu")
|
||||
N, H, W, K = 2, 3, 4, 5
|
||||
self.pixel_coords_cameras = torch.randn(
|
||||
(N, H, W, K, 3), device=self.device, requires_grad=True
|
||||
)
|
||||
self.colors_before = torch.rand((N, H, W, K, 3), device=self.device)
|
||||
self.cameras = FoVPerspectiveCameras(device=self.device)
|
||||
self.background_mask = torch.rand((N, H, W, K), device=self.device) < 0.5
|
||||
self.pixel_coords_screen, self.colors_after = _prepare_pixels_and_colors(
|
||||
self.pixel_coords_cameras,
|
||||
self.colors_before,
|
||||
self.cameras,
|
||||
self.background_mask,
|
||||
)
|
||||
|
||||
def test_background_z(self):
|
||||
self.assertTrue(
|
||||
torch.all(self.pixel_coords_screen[..., 2][self.background_mask] == 1.0)
|
||||
)
|
||||
|
||||
def test_background_alpha(self):
|
||||
self.assertTrue(
|
||||
torch.all(self.colors_after[..., 3][self.background_mask] == 0.0)
|
||||
)
|
||||
|
||||
|
||||
class TestGetSplatKernelNormalization(TestCaseMixin, unittest.TestCase):
|
||||
def test_splat_kernel_normalization(self):
|
||||
self.assertAlmostEqual(
|
||||
float(_get_splat_kernel_normalization(offsets)), 0.6503, places=3
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
float(_get_splat_kernel_normalization(offsets, 0.01)), 1.05, places=3
|
||||
)
|
||||
with self.assertRaisesRegex(ValueError, "Only positive standard deviations"):
|
||||
_get_splat_kernel_normalization(offsets, 0)
|
||||
|
||||
|
||||
class TestComputeOcclusionLayers(TestCaseMixin, unittest.TestCase):
|
||||
def test_single_layer(self):
|
||||
# If there's only one layer, all splats must be on the surface level.
|
||||
N, H, W, K = 2, 3, 4, 1
|
||||
q_depth = torch.rand(N, H, W, K)
|
||||
occlusion_layers = _compute_occlusion_layers(q_depth)
|
||||
self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)
|
||||
|
||||
def test_all_equal(self):
|
||||
# If all q-vals are equal, then all splats must be on the surface level.
|
||||
N, H, W, K = 2, 3, 4, 5
|
||||
q_depth = torch.ones((N, H, W, K)) * 0.1234
|
||||
occlusion_layers = _compute_occlusion_layers(q_depth)
|
||||
self.assertClose(occlusion_layers, torch.zeros(N, H, W, 9).long(), atol=0.0)
|
||||
|
||||
def test_mid_to_top_level_splatting(self):
|
||||
# Check that occlusion buffers get accumulated as expected when the splatting
|
||||
# and splatted pixels are co-surface on different intersection layers.
|
||||
# This test will make best sense with accompanying Fig. 4 from "Differentiable
|
||||
# Surface Rendering via Non-differentiable Sampling" by Cole et al.
|
||||
for direction, offset in enumerate(offsets):
|
||||
if direction == 4:
|
||||
continue # Skip self-splatting which is always co-surface.
|
||||
|
||||
depths = torch.zeros(1, 3, 3, 3)
|
||||
|
||||
# This is our q, the pixel splatted onto, in the center of the image.
|
||||
depths[0, 1, 1] = torch.tensor([0.71, 0.8, 1.0])
|
||||
|
||||
# This is our p, the splatting pixel.
|
||||
depths[0, offset[0] + 1, offset[1] + 1] = torch.tensor([0.5, 0.7, 0.9])
|
||||
|
||||
occlusion_layers = _compute_occlusion_layers(depths)
|
||||
|
||||
# Check that we computed that it is the middle layer of p that is co-
|
||||
# surface with q. (1, 1) is the id of q in the depth array, and offset_id
|
||||
# is the id of p's direction w.r.t. q.
|
||||
psurfaceid_onto_q = occlusion_layers[0, 1, 1, direction]
|
||||
self.assertEqual(int(psurfaceid_onto_q), 1)
|
||||
|
||||
# Conversely, if we swap p and q, we have a top-level splatting onto
|
||||
# mid-level. offset + 1 is the id of p, and 8-offset_id is the id of
|
||||
# q's direction w.r.t. p (e.g. if p is [-1, -1] w.r.t. q, then q is
|
||||
# [1, 1] w.r.t. p; we use the ids of these two directions in the offsets
|
||||
# array).
|
||||
qsurfaceid_onto_p = occlusion_layers[
|
||||
0, offset[0] + 1, offset[1] + 1, 8 - direction
|
||||
]
|
||||
self.assertEqual(int(qsurfaceid_onto_p), -1)
|
||||
|
||||
|
||||
class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.N, self.H, self.W, self.K = 2, 3, 4, 5
|
||||
self.pixel_coords_screen = (
|
||||
torch.tile(
|
||||
torch.stack(
|
||||
torch.meshgrid(
|
||||
torch.arange(self.H), torch.arange(self.W), indexing="ij"
|
||||
),
|
||||
dim=-1,
|
||||
).reshape(1, self.H, self.W, 1, 2),
|
||||
(self.N, 1, 1, self.K, 1),
|
||||
).float()
|
||||
+ 0.5
|
||||
)
|
||||
self.colors = torch.ones((self.N, self.H, self.W, self.K, 4))
|
||||
|
||||
def test_all_equal(self):
|
||||
# If all colors are equal and on a regular grid, all weights and reweighted
|
||||
# colors should be equal given a specific splatting direction.
|
||||
splatting_colors_and_weights = _compute_splatting_colors_and_weights(
|
||||
self.pixel_coords_screen, self.colors * 0.2345, sigma=0.5, offsets=offsets
|
||||
)
|
||||
|
||||
# Splatting directly to the top/bottom/left/right should have the same strenght.
|
||||
non_diag_splats = splatting_colors_and_weights[
|
||||
:, :, :, :, torch.tensor([1, 3, 5, 7])
|
||||
]
|
||||
|
||||
# Same for diagonal splats.
|
||||
diag_splats = splatting_colors_and_weights[
|
||||
:, :, :, :, torch.tensor([0, 2, 6, 8])
|
||||
]
|
||||
|
||||
# And for self-splats.
|
||||
self_splats = splatting_colors_and_weights[:, :, :, :, torch.tensor([4])]
|
||||
|
||||
for splats in non_diag_splats, diag_splats, self_splats:
|
||||
# Colors should be equal.
|
||||
self.assertTrue(torch.all(splats[..., :4] == splats[0, 0, 0, 0, 0, 0]))
|
||||
|
||||
# Weights should be equal.
|
||||
self.assertTrue(torch.all(splats[..., 4] == splats[0, 0, 0, 0, 0, 4]))
|
||||
|
||||
# Non-diagonal weights should be greater than diagonal weights.
|
||||
self.assertGreater(
|
||||
non_diag_splats[0, 0, 0, 0, 0, 0], diag_splats[0, 0, 0, 0, 0, 0]
|
||||
)
|
||||
|
||||
# Self-splats should be strongest of all.
|
||||
self.assertGreater(
|
||||
self_splats[0, 0, 0, 0, 0, 0], non_diag_splats[0, 0, 0, 0, 0, 0]
|
||||
)
|
||||
|
||||
# Splatting colors should be reweighted proportionally to their splat weights.
|
||||
diag_self_color_ratio = (
|
||||
diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
|
||||
)
|
||||
diag_self_weight_ratio = (
|
||||
diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
|
||||
)
|
||||
self.assertEqual(diag_self_color_ratio, diag_self_weight_ratio)
|
||||
|
||||
non_diag_self_color_ratio = (
|
||||
non_diag_splats[0, 0, 0, 0, 0, 0] / self_splats[0, 0, 0, 0, 0, 0]
|
||||
)
|
||||
non_diag_self_weight_ratio = (
|
||||
non_diag_splats[0, 0, 0, 0, 0, 4] / self_splats[0, 0, 0, 0, 0, 4]
|
||||
)
|
||||
self.assertEqual(non_diag_self_color_ratio, non_diag_self_weight_ratio)
|
||||
|
||||
def test_zero_alpha_zero_weight(self):
|
||||
# Pixels with zero alpha do no splatting, but should still be splatted on.
|
||||
colors = self.colors.clone()
|
||||
colors[0, 1, 1, 0, 3] = 0.0
|
||||
splatting_colors_and_weights = _compute_splatting_colors_and_weights(
|
||||
self.pixel_coords_screen, colors, sigma=0.5, offsets=offsets
|
||||
)
|
||||
|
||||
# The transparent pixel should do no splatting.
|
||||
self.assertTrue(torch.all(splatting_colors_and_weights[0, 1, 1, 0] == 0.0))
|
||||
|
||||
# Splatting *onto* the transparent pixel should be unaffected.
|
||||
reference_weights_colors = splatting_colors_and_weights[0, 1, 1, 1]
|
||||
for direction, offset in enumerate(offsets):
|
||||
if direction == 4:
|
||||
continue # Ignore self-splats
|
||||
# We invert the direction to get the right (h, w, d) coordinate of each
|
||||
# pixel splatting *onto* the pixel with zero alpha.
|
||||
self.assertClose(
|
||||
splatting_colors_and_weights[
|
||||
0, 1 + offset[0], 1 + offset[1], 0, 8 - direction
|
||||
],
|
||||
reference_weights_colors[8 - direction],
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_random_inputs(self):
|
||||
pixel_coords_screen = (
|
||||
self.pixel_coords_screen
|
||||
+ torch.randn((self.N, self.H, self.W, self.K, 2)) * 0.1
|
||||
)
|
||||
colors = torch.rand((self.N, self.H, self.W, self.K, 4))
|
||||
splatting_colors_and_weights = _compute_splatting_colors_and_weights(
|
||||
pixel_coords_screen, colors, sigma=0.5, offsets=offsets
|
||||
)
|
||||
naive_colors_and_weights = compute_splatting_colors_and_weights_naive(
|
||||
pixel_coords_screen, colors, sigma=0.5
|
||||
)
|
||||
|
||||
self.assertClose(
|
||||
splatting_colors_and_weights, naive_colors_and_weights, atol=0.01
|
||||
)
|
||||
|
||||
|
||||
class TestOffsetSplats(TestCaseMixin, unittest.TestCase):
|
||||
def test_offset(self):
|
||||
device = torch.device("cuda:0")
|
||||
N, H, W, K = 2, 3, 4, 5
|
||||
colors_and_weights = torch.rand((N, H, W, K, 9, 5), device=device)
|
||||
crop_ids_h, crop_ids_w, _ = _precompute((N, H, W, K), device=device)
|
||||
offset_colors_and_weights = _offset_splats(
|
||||
colors_and_weights, crop_ids_h, crop_ids_w
|
||||
)
|
||||
|
||||
# Check each splatting direction individually, for clarity.
|
||||
# offset_x, offset_y = (-1, -1)
|
||||
direction = 0
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, 1:, 1:, :, direction],
|
||||
colors_and_weights[:, :-1, :-1, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (-1, 0)
|
||||
direction = 1
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, :, 1:, :, direction],
|
||||
colors_and_weights[:, :, :-1, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (-1, 1)
|
||||
direction = 2
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, :-1, 1:, :, direction],
|
||||
colors_and_weights[:, 1:, :-1, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, :, 0, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (0, -1)
|
||||
direction = 3
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, 1:, :, :, direction],
|
||||
colors_and_weights[:, :-1, :, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# self-splat
|
||||
direction = 4
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[..., direction, :],
|
||||
colors_and_weights[..., direction, :],
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (0, 1)
|
||||
direction = 5
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, :-1, :, :, direction],
|
||||
colors_and_weights[:, 1:, :, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (1, -1)
|
||||
direction = 6
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, 1:, :-1, :, direction],
|
||||
colors_and_weights[:, :-1, 1:, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, 0, :, :, direction] == 0.0)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (1, 0)
|
||||
direction = 7
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, :, :-1, :, direction],
|
||||
colors_and_weights[:, :, 1:, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
# offset_x, offset_y = (1, 1)
|
||||
direction = 8
|
||||
self.assertClose(
|
||||
offset_colors_and_weights[:, :-1, :-1, :, direction],
|
||||
colors_and_weights[:, 1:, 1:, :, direction],
|
||||
atol=0.001,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, -1, :, :, direction] == 0.0)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(offset_colors_and_weights[:, :, -1, :, direction] == 0.0)
|
||||
)
|
||||
|
||||
|
||||
class TestComputeSplattedColorsAndWeights(TestCaseMixin, unittest.TestCase):
|
||||
def test_accumulation_background(self):
|
||||
# Set occlusion_layers to all -1, so all splats are background splats.
|
||||
splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
|
||||
occlusion_layers = torch.zeros((1, 1, 1, 9)) - 1
|
||||
splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
|
||||
occlusion_layers, splat_colors_and_weights
|
||||
)
|
||||
|
||||
# Foreground splats (there are none).
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 0],
|
||||
torch.zeros((4)),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# Surface splats (there are none).
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 1],
|
||||
torch.zeros((4)),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# Background splats.
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 2],
|
||||
splat_colors_and_weights[0, 0, 0, :, :, :4].sum(dim=0).sum(dim=0),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_accumulation_middle(self):
|
||||
# Set occlusion_layers to all 0, so top splats are co-surface with splatted
|
||||
# pixels. Thus, the top splatting layer should be accumulated to surface, and
|
||||
# all other layers to background.
|
||||
splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
|
||||
occlusion_layers = torch.zeros((1, 1, 1, 9))
|
||||
splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
|
||||
occlusion_layers, splat_colors_and_weights
|
||||
)
|
||||
|
||||
# Foreground splats (there are none).
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 0],
|
||||
torch.zeros((4)),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# Surface splats
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 1],
|
||||
splat_colors_and_weights[0, 0, 0, 0, :, :4].sum(dim=0),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# Background splats
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 2],
|
||||
splat_colors_and_weights[0, 0, 0, 1:, :, :4].sum(dim=0).sum(dim=0),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_accumulation_foreground(self):
|
||||
# Set occlusion_layers to all 1. Then the top splatter is a foreground
|
||||
# splatter, mid splatter is surface, and bottom splatter is background.
|
||||
splat_colors_and_weights = torch.rand((1, 1, 1, 3, 9, 5))
|
||||
occlusion_layers = torch.zeros((1, 1, 1, 9)) + 1
|
||||
splatted_colors, splatted_weights = _compute_splatted_colors_and_weights(
|
||||
occlusion_layers, splat_colors_and_weights
|
||||
)
|
||||
|
||||
# Foreground splats
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 0],
|
||||
splat_colors_and_weights[0, 0, 0, 0:1, :, :4].sum(dim=0).sum(dim=0),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# Surface splats
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 1],
|
||||
splat_colors_and_weights[0, 0, 0, 1:2, :, :4].sum(dim=0).sum(dim=0),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
# Background splats
|
||||
self.assertClose(
|
||||
splatted_colors[0, 0, 0, :, 2],
|
||||
splat_colors_and_weights[0, 0, 0, 2:3, :, :4].sum(dim=0).sum(dim=0),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeAndComposeAllLayers(TestCaseMixin, unittest.TestCase):
|
||||
def test_background_color(self):
|
||||
# Background should always have alpha=0, and the chosen RGB.
|
||||
N, H, W = 2, 3, 4
|
||||
# Make a mask with background in the zeroth row of the first image.
|
||||
bg_mask = torch.zeros([N, H, W, 1, 1])
|
||||
bg_mask[0, :, 0] = 1
|
||||
|
||||
bg_color = torch.tensor([0.2, 0.3, 0.4])
|
||||
|
||||
color_layers = torch.rand((N, H, W, 4, 3)) * (1 - bg_mask)
|
||||
color_weights = torch.rand((N, H, W, 1, 3)) * (1 - bg_mask)
|
||||
|
||||
colors = _normalize_and_compose_all_layers(
|
||||
bg_color, color_layers, color_weights
|
||||
)
|
||||
|
||||
# Background RGB should be .2, .3, .4, and alpha should be 0.
|
||||
self.assertClose(
|
||||
torch.masked_select(colors, bg_mask.bool()[..., 0]),
|
||||
torch.tensor([0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0, 0.2, 0.3, 0.4, 0.0]),
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_compositing_opaque(self):
|
||||
# When all colors are opaque, only the foreground layer should be visible.
|
||||
N, H, W = 2, 3, 4
|
||||
color_layers = torch.rand((N, H, W, 4, 3))
|
||||
color_layers[..., 3, :] = 1.0
|
||||
color_weights = torch.ones((N, H, W, 1, 3))
|
||||
|
||||
out_colors = _normalize_and_compose_all_layers(
|
||||
torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
|
||||
)
|
||||
self.assertClose(out_colors, color_layers[..., 0], atol=0.001)
|
||||
|
||||
def test_compositing_transparencies(self):
|
||||
# When foreground layer is transparent and surface and bg are semi-transparent,
|
||||
# we should return a mix of the two latter.
|
||||
N, H, W = 2, 3, 4
|
||||
color_layers = torch.rand((N, H, W, 4, 3))
|
||||
color_layers[..., 3, 0] = 0.1 # fg
|
||||
color_layers[..., 3, 1] = 0.2 # surface
|
||||
color_layers[..., 3, 2] = 0.3 # bg
|
||||
color_weights = torch.ones((N, H, W, 1, 3))
|
||||
|
||||
out_colors = _normalize_and_compose_all_layers(
|
||||
torch.tensor([0.0, 0.0, 0.0]), color_layers, color_weights
|
||||
)
|
||||
self.assertClose(
|
||||
out_colors,
|
||||
color_layers[..., 0]
|
||||
+ 0.9 * (color_layers[..., 1] + 0.8 * color_layers[..., 2]),
|
||||
)
|
||||
Reference in New Issue
Block a user