From 5852b74d122f95fc7eeb041cd11593f647b9fc7a Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Mon, 17 Aug 2020 11:57:32 -0700 Subject: [PATCH] Softmax blending small fix Summary: Small fix to the softmax blending function. To avoid overflow in the exponential for the softmax, the exponent is shifted by the maximum value. In the final calculation of the color there is a weighted sum between the pixel color and the background color - in order for the sum to be correct, the background color also needs to be handled in the same way witt the shifted exponent. Reviewed By: gkioxari Differential Revision: D23148301 fbshipit-source-id: 86066586ee7d3ce7bd4a2076b12ce191fbd151a7 --- pytorch3d/renderer/blending.py | 17 +++++++++-------- tests/test_blending.py | 19 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index a23937f3..80ef33ef 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -3,7 +3,6 @@ from typing import NamedTuple, Sequence -import numpy as np import torch # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. @@ -162,9 +161,8 @@ def softmax_rgb_blend( if not torch.is_tensor(background): background = torch.tensor(background, dtype=torch.float32, device=device) - # Background color - delta = np.exp(1e-10 / blend_params.gamma) * 1e-10 - delta = torch.tensor(delta, device=device) + # Weight for background color + eps = 1e-10 # Mask for padded pixels. mask = fragments.pix_to_face >= 0 @@ -189,15 +187,18 @@ def softmax_rgb_blend( # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`. weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma) + # Also apply exp normalize trick for the background color weight. + # Clamp to ensure delta is never 0. + delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps) + # Normalize weights. # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum. denom = weights_num.sum(dim=-1)[..., None] + delta - weights = weights_num / denom # Sum: weights * textures + background color - weighted_colors = (weights[..., None] * colors).sum(dim=-2) - weighted_background = (delta / denom) * background - pixel_colors[..., :3] = weighted_colors + weighted_background + weighted_colors = (weights_num[..., None] * colors).sum(dim=-2) + weighted_background = delta * background + pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom pixel_colors[..., 3] = 1.0 - alpha return pixel_colors diff --git a/tests/test_blending.py b/tests/test_blending.py index cb9ba41a..038ae961 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -2,7 +2,6 @@ import unittest -import numpy as np import torch from common_testing import TestCaseMixin from pytorch3d.renderer.blending import ( @@ -97,21 +96,18 @@ def softmax_blend_naive(colors, fragments, blend_params): # Near and far clipping planes zfar = 100.0 znear = 1.0 + eps = 1e-10 bk_color = blend_params.background_color if not torch.is_tensor(bk_color): bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device) - # Background color component - delta = np.exp(1e-10 / gamma) * 1e-10 - delta = torch.tensor(delta).to(device=device) - for n in range(N): for h in range(H): for w in range(W): alpha = 1.0 weights_k = torch.zeros(K, device=device) - zmax = 0.0 + zmax = torch.tensor(0.0, device=device) # Loop over K to find max z. for k in range(K): @@ -129,11 +125,13 @@ def softmax_blend_naive(colors, fragments, blend_params): alpha *= 1.0 - prob # cumulative product weights_k[k] = prob * torch.exp((zinv - zmax) / gamma) + # Clamp to ensure delta is never 0 + delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps) + delta = delta.to(device) denom = weights_k.sum() + delta - weights = weights_k / denom - cols = (weights[..., None] * colors[n, h, w, :, :]).sum(dim=0) - pixel_colors[n, h, w, :3] = cols - pixel_colors[n, h, w, :3] += (delta / denom) * bk_color + cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0) + pixel_colors[n, h, w, :3] = cols + delta * bk_color + pixel_colors[n, h, w, :3] /= denom pixel_colors[n, h, w, 3] = 1.0 - alpha return pixel_colors @@ -160,6 +158,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase): (out2 * grad_out).sum().backward() self.assertTrue(hasattr(grad_var2, "grad")) + self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5) def test_hard_rgb_blend(self):