pytorch3d/tests/test_blending.py
Nikhila Ravi bce396df93 C++/CUDA implementation of sigmoid alpha blend
Summary:
C++/CUDA implementation of forward and backward passes for the sigmoid alpha blending function.

This is slightly faster than the vectorized implementation in Python, but more importantly uses less memory due to fewer tensors being created.

Reviewed By: gkioxari

Differential Revision: D19980671

fbshipit-source-id: 0779055d2c68b1f20fb0870e60046077ef4613ff
2020-07-16 10:17:28 -07:00

420 lines
15 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
def sigmoid_blend_naive_loop(colors, fragments, blend_params):
    """
    Reference (slow) per-pixel implementation of sigmoid alpha blending.
    Only for test purposes.

    For every pixel the RGB channels are copied from the k = 0 entry of
    `colors`, and the alpha channel is set to
    ``1 - prod_k (1 - sigmoid(-dists[k] / sigma))`` where the product only
    covers entries whose `pix_to_face` index is non-negative.

    Args:
        colors: tensor indexed as [n, h, w, k, :] holding per-face RGB values.
        fragments: object exposing `pix_to_face` and `dists`, both indexed
            as [n, h, w, k].
        blend_params: object exposing `sigma`.

    Returns:
        Tensor of shape (N, H, W, 4) with RGBA values, created with the dtype
        of `colors` on the device of `pix_to_face`.
    """
    face_ids = fragments.pix_to_face
    distances = fragments.dists
    sigma = blend_params.sigma
    N, H, W, K = face_ids.shape
    rgba = torch.ones((N, H, W, 4), dtype=colors.dtype, device=face_ids.device)
    for n in range(N):
        for h in range(H):
            for w in range(W):
                # Running product of (1 - prob) over the valid faces.
                transparency = 1.0
                for k in range(K):
                    if face_ids[n, h, w, k] < 0:
                        # Entries with a negative face index are skipped.
                        continue
                    prob = torch.sigmoid(-distances[n, h, w, k] / sigma)
                    transparency = transparency * (1.0 - prob)
                rgba[n, h, w, :3] = colors[n, h, w, 0, :]
                rgba[n, h, w, 3] = 1.0 - transparency
    return rgba
def sigmoid_alpha_blend_vectorized(colors, fragments, blend_params) -> torch.Tensor:
    """
    Vectorized sigmoid alpha blending written with plain tensor operations.

    The alpha channel is ``1 - prod_k (1 - prob_k)`` where
    ``prob_k = sigmoid(-dists / sigma)`` is zeroed for entries whose
    `pix_to_face` index is negative. RGB is taken from the k = 0 entry of
    `colors`.

    Args:
        colors: tensor indexed as [..., k, :] holding per-face RGB values.
        fragments: object exposing `pix_to_face` (4 dimensions) and `dists`.
        blend_params: object exposing `sigma`.

    Returns:
        Tensor of shape (N, H, W, 4) with RGBA values, using the dtype and
        device of `colors`.
    """
    # Zero out the probability of entries with a negative face index so they
    # contribute a factor of 1 to the product below.
    valid = fragments.pix_to_face >= 0
    face_prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * valid
    transparency = torch.prod(1.0 - face_prob, dim=-1)

    N, H, W, _ = fragments.pix_to_face.shape
    rgba = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
    rgba[..., :3] = colors[..., 0, :]
    rgba[..., 3] = 1.0 - transparency
    return rgba
def sigmoid_blend_naive_loop_backward(grad_images, images, fragments, blend_params):
    """
    Hand-written backward pass matching `sigmoid_blend_naive_loop`.
    Only for test purposes.

    For each entry with a non-negative face index the gradient w.r.t. the
    distance is ``grad_alpha * (-1 / sigma) * prob * (1 - image_alpha)``,
    with ``prob = sigmoid(-dist / sigma)``. All other entries stay at zero.

    Args:
        grad_images: upstream gradient indexed as [n, h, w, c]; only channel
            3 is read.
        images: forward output indexed as [n, h, w, c]; only channel 3 is
            read.
        fragments: object exposing `pix_to_face` and `dists`, both indexed
            as [n, h, w, k].
        blend_params: object exposing `sigma`.

    Returns:
        Tensor of shape (N, H, W, K) with the gradient for `dists`.
    """
    face_ids = fragments.pix_to_face
    distances = fragments.dists
    sigma = blend_params.sigma
    N, H, W, K = face_ids.shape
    grad_dists = torch.zeros(
        (N, H, W, K), dtype=distances.dtype, device=face_ids.device
    )
    for n in range(N):
        for h in range(H):
            for w in range(W):
                # 1 - alpha channel, i.e. the product of (1 - prob) that the
                # forward pass accumulated for this pixel.
                transparency = 1.0 - images[n, h, w, 3]
                upstream = grad_images[n, h, w, 3]
                for k in range(K):
                    if face_ids[n, h, w, k] < 0:
                        # Entries with a negative face index get no gradient.
                        continue
                    prob = torch.sigmoid(-distances[n, h, w, k] / sigma)
                    grad_dists[n, h, w, k] = (
                        upstream * (-1.0 / sigma) * prob * transparency
                    )
    return grad_dists
def softmax_blend_naive(colors, fragments, blend_params):
    """
    Reference (slow) per-pixel implementation of softmax blending.
    Only for test purposes.

    Per pixel, each entry with a non-negative face index receives the weight
    ``sigmoid(-dist / sigma) * exp((zinv - zmax) / gamma)`` where
    ``zinv = (zfar - z) / (zfar - znear)`` and ``zmax`` is the largest
    ``zinv`` among those entries (never below 0). The RGB output is the
    weight-normalized sum of the face colors plus a background term, and the
    alpha channel is ``1 - prod_k (1 - sigmoid(-dist_k / sigma))``.

    Args:
        colors: tensor indexed as [n, h, w, k, :] holding per-face RGB values.
        fragments: object exposing `pix_to_face`, `dists` and `zbuf`, all
            indexed as [n, h, w, k].
        blend_params: object exposing `sigma`, `gamma` and
            `background_color` (a tensor or a sequence of numbers).

    Returns:
        Tensor of shape (N, H, W, 4) with RGBA values.
    """
    face_ids = fragments.pix_to_face
    distances = fragments.dists
    depths = fragments.zbuf
    sigma = blend_params.sigma
    gamma = blend_params.gamma
    N, H, W, K = face_ids.shape
    device = face_ids.device
    rgba = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)

    # Near and far clipping planes
    zfar = 100.0
    znear = 1.0
    depth_range = zfar - znear

    background = blend_params.background_color
    if not torch.is_tensor(background):
        background = torch.tensor(background, dtype=colors.dtype, device=device)

    # Background color component
    delta = torch.tensor(np.exp(1e-10 / gamma) * 1e-10).to(device=device)

    for n in range(N):
        for h in range(H):
            for w in range(W):
                # Indices of the entries that take part in the blend.
                valid = [k for k in range(K) if face_ids[n, h, w, k] >= 0]

                # First pass: largest inverse depth among the valid entries.
                zinv_max = 0.0
                for k in valid:
                    zinv = (zfar - depths[n, h, w, k]) / depth_range
                    if zinv > zinv_max:
                        zinv_max = zinv

                # Second pass: distance based probability and depth based
                # weight for every valid entry.
                transparency = 1.0
                face_weights = torch.zeros(K, device=device)
                for k in valid:
                    zinv = (zfar - depths[n, h, w, k]) / depth_range
                    prob = torch.sigmoid(-distances[n, h, w, k] / sigma)
                    transparency = transparency * (1.0 - prob)
                    face_weights[k] = prob * torch.exp((zinv - zinv_max) / gamma)

                # delta keeps the denominator non-zero when no entry is valid.
                denom = face_weights.sum() + delta
                normalized = face_weights / denom
                blended = (normalized[..., None] * colors[n, h, w, :, :]).sum(dim=0)
                rgba[n, h, w, :3] = blended
                rgba[n, h, w, :3] += (delta / denom) * background
                rgba[n, h, w, 3] = 1.0 - transparency
    return rgba
class TestBlending(TestCaseMixin, unittest.TestCase):
    """Tests and benchmark factories for the renderer blending functions."""

    def setUp(self) -> None:
        # Fixed seed so randomly generated inputs are reproducible.
        torch.manual_seed(42)

    def _compare_impls(
        self, fn1, fn2, args1, args2, grad_var1=None, grad_var2=None, compare_grads=True
    ):
        """
        Run two implementations and compare their alpha channel and gradients.

        Args:
            fn1, fn2: callables invoked as ``fn(*args)``.
            args1, args2: positional arguments for `fn1` and `fn2`.
            grad_var1, grad_var2: tensors whose `.grad` is compared after
                backpropagating the same random upstream gradient through
                both outputs. Required when `compare_grads` is True.
            compare_grads: if False, only the forward outputs are compared.
        """
        out1 = fn1(*args1)
        out2 = fn2(*args2)
        # Only the alpha channel (index 3) of the outputs is compared.
        self.assertClose(out1.cpu()[..., 3], out2.cpu()[..., 3], atol=1e-7)

        # Check gradients
        if not compare_grads:
            return

        grad_out = torch.randn_like(out1)
        (out1 * grad_out).sum().backward()
        # Every tensor has a `grad` attribute, so checking for its presence
        # proves nothing; require that backward actually populated it.
        self.assertIsNotNone(getattr(grad_var1, "grad", None))

        (out2 * grad_out).sum().backward()
        self.assertIsNotNone(getattr(grad_var2, "grad", None))

        self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)

    def test_hard_rgb_blend(self):
        """Check foreground, background and alpha output of hard_rgb_blend."""
        N, H, W, K = 5, 10, 10, 20
        pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
        bary_coords = torch.ones((N, H, W, K, 3))
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=bary_coords,
            zbuf=pix_to_face,  # dummy
            dists=pix_to_face,  # dummy
        )
        colors = torch.randn((N, H, W, K, 3))
        blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
        images = hard_rgb_blend(colors, fragments, blend_params)

        # Examine if the foreground colors are correct.
        is_foreground = pix_to_face[..., 0] >= 0
        self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])

        # Examine if the background colors are correct.
        for i in range(3):  # i.e. RGB
            channel_color = blend_params.background_color[i]
            self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())

        # Examine the alpha channel is correct
        self.assertTrue(images[..., 3].eq(1).all())

    def test_sigmoid_alpha_blend_manual_gradients(self):
        """
        Compare autograd gradients of the naive loop forward pass against the
        hand-written naive loop backward pass.
        """
        # Create dummy outputs of rasterization
        torch.manual_seed(231)
        F = 32  # number of faces in the mesh
        # The python loop version is really slow so only using small input sizes.
        N, S, K = 2, 3, 2
        device = torch.device("cuda")
        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        # randn already yields both signs for the distances.
        dists = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists,
        )
        blend_params = BlendParams(sigma=1e-3)
        pix_cols = sigmoid_blend_naive_loop(colors, fragments, blend_params)
        grad_out = torch.randn_like(pix_cols)

        # Backward pass
        pix_cols.backward(grad_out)
        grad_dists = sigmoid_blend_naive_loop_backward(
            grad_out, pix_cols, fragments, blend_params
        )
        self.assertTrue(torch.allclose(dists.grad, grad_dists, atol=1e-7))

    def test_sigmoid_alpha_blend_python(self):
        """
        Test outputs and gradients of sigmoid_alpha_blend against the
        vectorized python implementation.
        """
        # Create dummy outputs of rasterization
        torch.manual_seed(231)
        F = 32  # number of faces in the mesh
        N, S, K = 1, 4, 1
        device = torch.device("cuda")
        pix_to_face = torch.randint(low=-1, high=F, size=(N, S, S, K), device=device)
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        # Two independent leaf tensors with identical values so that each
        # implementation accumulates its gradient separately.
        dists1 = torch.randn(size=(N, S, S, K), device=device)
        dists2 = dists1.clone()
        dists1.requires_grad = True
        dists2.requires_grad = True

        fragments1 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists1,
        )
        fragments2 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists2,
        )

        blend_params = BlendParams(sigma=1e-2)
        args1 = (colors, fragments1, blend_params)
        args2 = (colors, fragments2, blend_params)

        self._compare_impls(
            sigmoid_alpha_blend,
            sigmoid_alpha_blend_vectorized,
            args1,
            args2,
            dists1,
            dists2,
            compare_grads=True,
        )

    def test_softmax_rgb_blend(self):
        """
        Test outputs and gradients of softmax_rgb_blend against the naive
        python loop implementation.
        """
        # Create dummy outputs of rasterization simulating a cube in the center
        # of the image with surrounding padded values.
        N, S, K = 1, 8, 2
        device = torch.device("cuda")
        pix_to_face = torch.full(
            (N, S, S, K), fill_value=-1, dtype=torch.int64, device=device
        )
        h = int(S / 2)
        pix_to_face_full = torch.randint(
            size=(N, h, h, K), low=0, high=100, device=device
        )
        s = int(S / 4)
        e = int(0.75 * S)
        pix_to_face[:, s:e, s:e, :] = pix_to_face_full
        empty = torch.tensor([], device=device)

        random_sign_flip = torch.rand((N, S, S, K), device=device)
        random_sign_flip[random_sign_flip > 0.5] *= -1.0
        zbuf1 = torch.randn(size=(N, S, S, K), device=device)

        # randomly flip the sign of the distance
        # (-) means inside triangle, (+) means outside triangle.
        dists1 = torch.randn(size=(N, S, S, K), device=device) * random_sign_flip
        dists2 = dists1.clone()
        zbuf2 = zbuf1.clone()
        dists1.requires_grad = True
        dists2.requires_grad = True
        colors = torch.randn((N, S, S, K, 3), device=device)

        fragments1 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=zbuf1,
            dists=dists1,
        )
        fragments2 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=zbuf2,
            dists=dists2,
        )

        blend_params = BlendParams(sigma=1e-3)
        args1 = (colors, fragments1, blend_params)
        args2 = (colors, fragments2, blend_params)
        self._compare_impls(
            softmax_rgb_blend,
            softmax_blend_naive,
            args1,
            args2,
            dists1,
            dists2,
            compare_grads=True,
        )

    @staticmethod
    def bm_sigmoid_alpha_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 100,
        device="cuda",
        backend: str = "pytorch",
    ):
        """
        Build a benchmark closure for sigmoid alpha blending.

        Args:
            num_meshes: batch size N of the dummy rasterization output.
            image_size: height and width S of the dummy images.
            faces_per_pixel: number of faces K per pixel.
            device: device on which the inputs are created.
            backend: "pytorch" selects sigmoid_alpha_blend_vectorized, any
                other value selects sigmoid_alpha_blend.

        Returns:
            A zero-argument callable running one forward and backward pass.
        """
        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        F = 32  # num faces in the mesh
        pix_to_face = torch.randint(
            low=-1, high=F + 1, size=(N, S, S, K), device=device
        )
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists1,
        )
        blend_params = BlendParams(sigma=1e-3)

        blend_fn = (
            sigmoid_alpha_blend_vectorized
            if backend == "pytorch"
            else sigmoid_alpha_blend
        )

        # Only synchronize when the work actually runs on a CUDA device;
        # calling it for a CPU benchmark is unnecessary and raises on builds
        # without CUDA.
        use_cuda = device.type == "cuda"
        if use_cuda:
            torch.cuda.synchronize()

        def fn():
            # test forward and backward pass
            images = blend_fn(colors, fragments, blend_params)
            images.sum().backward()
            if use_cuda:
                torch.cuda.synchronize()

        return fn

    @staticmethod
    def bm_softmax_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 100,
        device: str = "cpu",
        backend: str = "pytorch",
    ):
        """
        Build a benchmark closure for softmax_rgb_blend.

        Args:
            num_meshes: batch size N of the dummy rasterization output.
            image_size: height and width S of the dummy images.
            faces_per_pixel: number of faces K per pixel.
            device: device string on which the inputs are created.
            backend: accepted for signature parity with
                bm_sigmoid_alpha_blending; not used here.

        Returns:
            A zero-argument callable running one forward and backward pass.
        """
        if torch.cuda.is_available() and "cuda:" in device:
            # If a device other than the default is used, set the device explicitly.
            torch.cuda.set_device(device)

        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        F = 32  # num faces in the mesh
        pix_to_face = torch.randint(
            low=-1, high=F + 1, size=(N, S, S, K), device=device
        )
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face, bary_coords=empty, zbuf=zbuf, dists=dists1  # dummy
        )
        blend_params = BlendParams(sigma=1e-3)

        # The default device is "cpu", so synchronizing unconditionally would
        # fail on machines without CUDA; only do it for CUDA devices.
        use_cuda = device.type == "cuda"
        if use_cuda:
            torch.cuda.synchronize()

        def fn():
            # test forward and backward pass
            images = softmax_rgb_blend(colors, fragments, blend_params)
            images.sum().backward()
            if use_cuda:
                torch.cuda.synchronize()

        return fn

    def test_blend_params(self):
        """Test color parameter of BlendParams().
        Assert passed value overrides default value.
        """
        bp_default = BlendParams()
        bp_new = BlendParams(background_color=(0.5, 0.5, 0.5))
        self.assertEqual(bp_new.background_color, (0.5, 0.5, 0.5))
        self.assertEqual(bp_default.background_color, (1.0, 1.0, 1.0))