diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index a23937f3..80ef33ef 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -3,7 +3,6 @@ from typing import NamedTuple, Sequence -import numpy as np import torch # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. @@ -162,9 +161,8 @@ def softmax_rgb_blend( if not torch.is_tensor(background): background = torch.tensor(background, dtype=torch.float32, device=device) - # Background color - delta = np.exp(1e-10 / blend_params.gamma) * 1e-10 - delta = torch.tensor(delta, device=device) + # Weight for background color + eps = 1e-10 # Mask for padded pixels. mask = fragments.pix_to_face >= 0 @@ -189,15 +187,18 @@ def softmax_rgb_blend( # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`. weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma) + # Also apply exp normalize trick for the background color weight. + # Clamp to ensure delta is never 0. + delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps) + # Normalize weights. # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum. denom = weights_num.sum(dim=-1)[..., None] + delta - weights = weights_num / denom # Sum: weights * textures + background color - weighted_colors = (weights[..., None] * colors).sum(dim=-2) - weighted_background = (delta / denom) * background - pixel_colors[..., :3] = weighted_colors + weighted_background + weighted_colors = (weights_num[..., None] * colors).sum(dim=-2) + weighted_background = delta * background + pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom pixel_colors[..., 3] = 1.0 - alpha return pixel_colors diff --git a/tests/test_blending.py b/tests/test_blending.py index cb9ba41a..038ae961 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -2,7 +2,6 @@ import unittest -import numpy as np import torch from common_testing import TestCaseMixin from pytorch3d.renderer.blending import ( @@ -97,21 +96,18 @@ def softmax_blend_naive(colors, fragments, blend_params): # Near and far clipping planes zfar = 100.0 znear = 1.0 + eps = 1e-10 bk_color = blend_params.background_color if not torch.is_tensor(bk_color): bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device) - # Background color component - delta = np.exp(1e-10 / gamma) * 1e-10 - delta = torch.tensor(delta).to(device=device) - for n in range(N): for h in range(H): for w in range(W): alpha = 1.0 weights_k = torch.zeros(K, device=device) - zmax = 0.0 + zmax = torch.tensor(0.0, device=device) # Loop over K to find max z. for k in range(K): @@ -129,11 +125,13 @@ def softmax_blend_naive(colors, fragments, blend_params): alpha *= 1.0 - prob # cumulative product weights_k[k] = prob * torch.exp((zinv - zmax) / gamma) + # Clamp to ensure delta is never 0 + delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps) + delta = delta.to(device) denom = weights_k.sum() + delta - weights = weights_k / denom - cols = (weights[..., None] * colors[n, h, w, :, :]).sum(dim=0) - pixel_colors[n, h, w, :3] = cols - pixel_colors[n, h, w, :3] += (delta / denom) * bk_color + cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0) + pixel_colors[n, h, w, :3] = cols + delta * bk_color + pixel_colors[n, h, w, :3] /= denom pixel_colors[n, h, w, 3] = 1.0 - alpha return pixel_colors @@ -160,6 +158,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase): (out2 * grad_out).sum().backward() self.assertTrue(hasattr(grad_var2, "grad")) + self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5) def test_hard_rgb_blend(self):