diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 8d50c52f..4676aa50 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -45,7 +45,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: """ Silhouette blending to return an RGBA image - **RGB** - choose color of the closest point. - - **A** - blend based on the 2D distance based probability map [0]. + - **A** - blend based on the 2D distance based probability map [1]. Args: colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel. @@ -60,7 +60,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: Returns: RGBA pixel_colors: (N, H, W, 4) - [0] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based + [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based 3D Reasoning', ICCV 2019 """ N, H, W, K = fragments.pix_to_face.shape @@ -73,20 +73,13 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: # the face. Therefore use -1.0 * fragments.dists to get the correct sign. prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask - # The cumulative product ensures that alpha will be 1 if at least 1 face - # fully covers the pixel as for that face prob will be 1.0 - # TODO: investigate why torch.cumprod backwards is very slow for large - # values of K. - # Temporarily replace this with exp(sum(log))) using the fact that - # a*b = exp(log(a*b)) = exp(log(a) + log(b)) - # alpha = 1.0 - torch.cumprod((1.0 - prob), dim=-1)[..., -1] - - alpha = 1.0 - torch.exp(torch.log((1.0 - prob)).sum(dim=-1)) - + # The cumulative product ensures that alpha will be 0.0 if at least 1 + # face fully covers the pixel as for that face, prob will be 1.0. + # This results in a multiplication by 0.0 because of the (1.0 - prob) + # term. Therefore 1.0 - alpha will be 1.0. 
+ alpha = torch.prod((1.0 - prob), dim=-1) pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB - pixel_colors[..., 3] = alpha - - pixel_colors = torch.clamp(pixel_colors, min=0, max=1.0) + pixel_colors[..., 3] = 1.0 - alpha return torch.flip(pixel_colors, [1]) @@ -95,7 +88,7 @@ def softmax_rgb_blend( ) -> torch.Tensor: """ RGB and alpha channel blending to return an RGBA image based on the method - proposed in [0] + proposed in [1] - **RGB** - blend the colors based on the 2D distance based probability map and relative z distances. - **A** - blend based on the 2D distance based probability map. @@ -151,15 +144,11 @@ def softmax_rgb_blend( # Sigmoid probability map based on the distance of the pixel to the face. prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask - # The cumulative product ensures that alpha will be 1 if at least 1 face - # fully covers the pixel as for that face prob will be 1.0 - # TODO: investigate why torch.cumprod backwards is very slow for large - # values of K. - # Temporarily replace this with exp(sum(log))) using the fact that - # a*b = exp(log(a*b)) = exp(log(a) + log(b)) - # alpha = 1.0 - torch.cumprod((1.0 - prob), dim=-1)[..., -1] - - alpha = 1.0 - torch.exp(torch.log((1.0 - prob_map)).sum(dim=-1)) + # The cumulative product ensures that alpha will be 0.0 if at least 1 + # face fully covers the pixel as for that face, prob will be 1.0. + # This results in a multiplication by 0.0 because of the (1.0 - prob) + # term. Therefore 1.0 - alpha will be 1.0. + alpha = torch.prod((1.0 - prob_map), dim=-1) # Weights for each face. Adjust the exponential by the max z to prevent # overflow. zbuf shape (N, H, W, K), find max over K. 
def bm_blending() -> None:
    """
    Benchmark the PyTorch sigmoid-alpha and softmax blending functions.

    Builds one keyword-argument dict per combination of batch size, image
    size, faces per pixel and device, then passes the list to ``benchmark``
    for ``TestBlending.bm_sigmoid_alpha_blending`` and
    ``TestBlending.bm_softmax_blending``.
    """
    # Function-scope import: this module does not import torch at file level.
    import torch

    # Only benchmark on the GPU when one is usable. Requesting "cuda" on a
    # CPU-only machine makes the benchmark setup raise before anything runs.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")

    num_meshes = [16]
    image_size = [128, 256]
    faces_per_pixel = [50, 100]

    kwargs_list = [
        {
            "num_meshes": n,
            "image_size": s,
            "faces_per_pixel": k,
            "device": d,
        }
        for n, s, k, d in product(
            num_meshes, image_size, faces_per_pixel, devices
        )
    ]

    benchmark(
        TestBlending.bm_sigmoid_alpha_blending,
        "SIGMOID_ALPHA_BLENDING_PYTORCH",
        kwargs_list,
        warmup_iters=1,
    )

    benchmark(
        TestBlending.bm_softmax_blending,
        "SOFTMAX_BLENDING_PYTORCH",
        kwargs_list,
        warmup_iters=1,
    )
def sigmoid_blend_naive_loop_backward(
    grad_images, images, fragments, blend_params
):
    """
    Naive for loop based backward pass of the distance based alpha
    calculation. Only for test purposes.

    Args:
        grad_images: upstream gradient of the blended image, indexed as
            [n, h, w, c]; only channel 3 (alpha) is read.
        images: blended image from the forward pass, indexed as
            [n, h, w, c]; only channel 3 (alpha) is read.
        fragments: object exposing `pix_to_face` and `dists`, both indexed
            as (N, H, W, K).
        blend_params: object exposing `sigma`.

    Returns:
        grad_distances: (N, H, W, K) tensor with the gradient for each entry
        of `fragments.dists`. Entries with a negative `pix_to_face` stay 0.
    """
    face_ids = fragments.pix_to_face
    distances = fragments.dists
    sigma = blend_params.sigma
    num_images, height, width, num_faces = face_ids.shape

    # The forward pass returns its image flipped along dim 1, so flip both
    # inputs back to line up with the pix_to_face / dists indexing.
    unflipped_images = torch.flip(images, [1])
    unflipped_grads = torch.flip(grad_images, [1])

    grad_distances = torch.zeros(
        (num_images, height, width, num_faces),
        dtype=distances.dtype,
        device=face_ids.device,
    )

    for n in range(num_images):
        for h in range(height):
            for w in range(width):
                # 1 - (image alpha) recovers the product of (1 - prob) terms.
                alpha = 1.0 - unflipped_images[n, h, w, 3]
                grad_alpha = unflipped_grads[n, h, w, 3]
                for k in range(num_faces):
                    # Negative indices carry no gradient.
                    if face_ids[n, h, w, k] < 0:
                        continue
                    prob = torch.sigmoid(-distances[n, h, w, k] / sigma)
                    grad_distances[n, h, w, k] = (
                        grad_alpha * (-1.0 / sigma) * prob * alpha
                    )
    return grad_distances
@@ -102,7 +130,6 @@ def softmax_blend_naive(colors, fragments, blend_params): pixel_colors[n, h, w, :3] += (delta / denom) * bk_color pixel_colors[n, h, w, 3] = 1.0 - alpha - pixel_colors = torch.clamp(pixel_colors, min=0, max=1.0) return torch.flip(pixel_colors, [1]) @@ -110,6 +137,37 @@ class TestBlending(unittest.TestCase): def setUp(self) -> None: torch.manual_seed(42) + def _compare_impls( + self, + fn1, + fn2, + args1, + args2, + grad_var1=None, + grad_var2=None, + compare_grads=True, + ): + + out1 = fn1(*args1) + out2 = fn2(*args2) + self.assertTrue(torch.allclose(out1.cpu(), out2.cpu(), atol=1e-7)) + + # Check gradients + if not compare_grads: + return + + grad_out = torch.randn_like(out1) + (out1 * grad_out).sum().backward() + self.assertTrue(hasattr(grad_var1, "grad")) + + (out2 * grad_out).sum().backward() + self.assertTrue(hasattr(grad_var2, "grad")) + self.assertTrue( + torch.allclose( + grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5 + ) + ) + def test_hard_rgb_blend(self): N, H, W, K = 5, 10, 10, 20 pix_to_face = torch.ones((N, H, W, K)) @@ -129,116 +187,246 @@ class TestBlending(unittest.TestCase): expected_vals[..., :3] = pix_cols self.assertTrue(torch.allclose(images, expected_vals)) - def test_sigmoid_alpha_blend(self): - """ - Test outputs of sigmoid alpha blend tensorised function match those of - the naive iterative version. Also check gradients match. - """ + def test_sigmoid_alpha_blend_manual_gradients(self): + # Create dummy outputs of rasterization + torch.manual_seed(231) + F = 32 # number of faces in the mesh + # The python loop version is really slow so only using small input sizes. + N, S, K = 2, 3, 2 + device = torch.device("cuda") + pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1 + colors = torch.randn((N, S, S, K, 3), device=device) + empty = torch.tensor([], device=device) - # Create dummy outputs of rasterization simulating a cube in the centre - # of the image with surrounding padded values. 
def test_sigmoid_alpha_blend_python(self):
    """
    Test that the tensorised python implementation (sigmoid_alpha_blend)
    and the python loop implementation (sigmoid_blend_naive_loop) produce
    matching outputs and matching gradients with respect to dists.
    """

    # Create dummy outputs of rasterization
    torch.manual_seed(231)
    F = 32  # number of faces in the mesh
    # The python loop version is really slow so only using small input sizes.
    N, S, K = 2, 10, 5
    device = torch.device("cuda")
    # Values lie in [-1, F - 1]; -1 marks a pixel slot with no face.
    pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
    colors = torch.randn((N, S, S, K, 3), device=device)
    empty = torch.tensor([], device=device)

    # Signed distances: (-) means inside triangle, (+) means outside
    # triangle. torch.randn already produces both signs, so no separate
    # sign-flip tensor is needed.
    dists1 = torch.randn(
        size=(N, S, S, K), requires_grad=True, device=device
    )
    # Independent leaf copy so each implementation gets its own .grad.
    dists2 = dists1.detach().clone()
    dists2.requires_grad = True

    fragments1 = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=empty,  # dummy
        zbuf=empty,  # dummy
        dists=dists1,
    )
    fragments2 = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=empty,  # dummy
        zbuf=empty,  # dummy
        dists=dists2,
    )

    blend_params = BlendParams(sigma=1e-2)
    args1 = (colors, fragments1, blend_params)
    args2 = (colors, fragments2, blend_params)

    self._compare_impls(
        sigmoid_alpha_blend,
        sigmoid_blend_naive_loop,
        args1,
        args2,
        dists1,
        dists2,
        compare_grads=True,
    )
N, S, K = 1, 8, 2 - pix_to_face = -torch.ones((N, S, S, K), dtype=torch.int64) + device = torch.device("cuda") + pix_to_face = -torch.ones( + (N, S, S, K), dtype=torch.int64, device=device + ) h = int(S / 2) - pix_to_face_full = torch.randint(size=(N, h, h, K), low=0, high=100) + pix_to_face_full = torch.randint( + size=(N, h, h, K), low=0, high=100, device=device + ) s = int(S / 4) e = int(0.75 * S) pix_to_face[:, s:e, s:e, :] = pix_to_face_full - bary_coords = torch.ones((N, S, S, K, 3)) + empty = torch.tensor([], device=device) - random_sign_flip = torch.rand((N, S, S, K)) + random_sign_flip = torch.rand((N, S, S, K), device=device) random_sign_flip[random_sign_flip > 0.5] *= -1.0 - zbuf1 = torch.randn(size=(N, S, S, K)) + zbuf1 = torch.randn(size=(N, S, S, K), device=device) # randomly flip the sign of the distance # (-) means inside triangle, (+) means outside triangle. - dists1 = torch.randn(size=(N, S, S, K)) * random_sign_flip + dists1 = ( + torch.randn(size=(N, S, S, K), device=device) * random_sign_flip + ) dists2 = dists1.clone() zbuf2 = zbuf1.clone() dists1.requires_grad = True dists2.requires_grad = True - zbuf1.requires_grad = True - zbuf2.requires_grad = True - colors = torch.randn_like(bary_coords) + colors = torch.randn((N, S, S, K, 3), device=device) fragments1 = Fragments( pix_to_face=pix_to_face, - bary_coords=bary_coords, # dummy + bary_coords=empty, # dummy zbuf=zbuf1, dists=dists1, ) fragments2 = Fragments( pix_to_face=pix_to_face, - bary_coords=bary_coords, # dummy + bary_coords=empty, # dummy zbuf=zbuf2, dists=dists2, ) - blend_params = BlendParams(sigma=1e-1) - images = softmax_rgb_blend(colors, fragments1, blend_params) - images_naive = softmax_blend_naive(colors, fragments2, blend_params) - self.assertTrue(torch.allclose(images, images_naive)) - # Check gradients. 
@staticmethod
def bm_sigmoid_alpha_blending(
    num_meshes: int = 16,
    image_size: int = 128,
    faces_per_pixel: int = 100,
    device: str = "cpu",
):
    """
    Build a benchmark closure for sigmoid_alpha_blend.

    Args:
        num_meshes: batch size N of the random rasterization outputs.
        image_size: height and width S of the (square) image.
        faces_per_pixel: number of faces K stored per pixel.
        device: torch device string, e.g. "cpu", "cuda" or "cuda:1".

    Returns:
        A zero-argument callable that runs one forward and one backward
        pass of sigmoid_alpha_blend on the prepared inputs.
    """
    if torch.cuda.is_available() and "cuda:" in device:
        # If a device other than the default is used, set the device
        # explicitly.
        torch.cuda.set_device(device)

    device = torch.device(device)
    # torch.cuda.synchronize() raises without a usable CUDA runtime, so it
    # must only be called when the benchmark actually runs on the GPU.
    on_cuda = device.type == "cuda"
    torch.manual_seed(231)

    # Create dummy outputs of rasterization
    N, S, K = num_meshes, image_size, faces_per_pixel
    F = 32  # num faces in the mesh
    # Values lie in [-1, F - 1]; -1 marks a pixel slot with no face.
    pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
    colors = torch.randn((N, S, S, K, 3), device=device)
    empty = torch.tensor([], device=device)

    # Signed distances: (-) means inside triangle, (+) means outside
    # triangle. torch.randn already produces both signs.
    dists1 = torch.randn(
        size=(N, S, S, K), requires_grad=True, device=device
    )
    fragments = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=empty,  # dummy
        zbuf=empty,  # dummy
        dists=dists1,
    )
    blend_params = BlendParams(sigma=1e-3)
    if on_cuda:
        # Make sure input creation has finished before timing starts.
        torch.cuda.synchronize()

    def fn():
        # test forward and backward pass
        images = sigmoid_alpha_blend(colors, fragments, blend_params)
        images.sum().backward()
        if on_cuda:
            # Wait for the queued GPU work so it is included in the timing.
            torch.cuda.synchronize()

    return fn

@staticmethod
def bm_softmax_blending(
    num_meshes: int = 16,
    image_size: int = 128,
    faces_per_pixel: int = 100,
    device: str = "cpu",
):
    """
    Build a benchmark closure for softmax_rgb_blend.

    Args:
        num_meshes: batch size N of the random rasterization outputs.
        image_size: height and width S of the (square) image.
        faces_per_pixel: number of faces K stored per pixel.
        device: torch device string, e.g. "cpu", "cuda" or "cuda:1".

    Returns:
        A zero-argument callable that runs one forward and one backward
        pass of softmax_rgb_blend on the prepared inputs.
    """
    if torch.cuda.is_available() and "cuda:" in device:
        # If a device other than the default is used, set the device
        # explicitly.
        torch.cuda.set_device(device)

    device = torch.device(device)
    # torch.cuda.synchronize() raises without a usable CUDA runtime, so it
    # must only be called when the benchmark actually runs on the GPU.
    on_cuda = device.type == "cuda"
    torch.manual_seed(231)

    # Create dummy outputs of rasterization
    N, S, K = num_meshes, image_size, faces_per_pixel
    F = 32  # num faces in the mesh
    # Values lie in [-1, F - 1]; -1 marks a pixel slot with no face.
    pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
    colors = torch.randn((N, S, S, K, 3), device=device)
    empty = torch.tensor([], device=device)

    # Signed distances: (-) means inside triangle, (+) means outside
    # triangle. torch.randn already produces both signs.
    dists1 = torch.randn(
        size=(N, S, S, K), requires_grad=True, device=device
    )
    zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
    fragments = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=empty,  # dummy
        zbuf=zbuf,
        dists=dists1,
    )
    blend_params = BlendParams(sigma=1e-3)

    if on_cuda:
        # Make sure input creation has finished before timing starts.
        torch.cuda.synchronize()

    def fn():
        # test forward and backward pass
        images = softmax_rgb_blend(colors, fragments, blend_params)
        images.sum().backward()
        if on_cuda:
            # Wait for the queued GPU work so it is included in the timing.
            torch.cuda.synchronize()

    return fn