Softmax blending small fix

Summary:
Small fix to the softmax blending function.

To avoid overflow in the exponentials for the softmax, each exponent is shifted by the maximum value. The final pixel color is a weighted sum of the face colors and the background color; for this sum to be correct, the background color weight also needs to be computed with the same shifted exponent.
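In symbols (notation mine, matching the code below: $\bar z_k$ is the normalized inverse depth of face $k$ at the pixel, $p_k$ its coverage probability, $C_k$ its color, $\gamma$ the blend parameter, $\epsilon$ a small constant, and $z_{\max} = \max_k \bar z_k$), the corrected blend is

\[
C \;=\; \frac{\sum_k p_k\, e^{(\bar z_k - z_{\max})/\gamma}\, C_k \;+\; e^{(\epsilon - z_{\max})/\gamma}\, C_{bg}}{\sum_j p_j\, e^{(\bar z_j - z_{\max})/\gamma} \;+\; e^{(\epsilon - z_{\max})/\gamma}}
\]

Multiplying numerator and denominator by $e^{z_{\max}/\gamma}$ recovers the unshifted softmax blend, so the result is mathematically unchanged, while every exponent is now non-positive and cannot overflow. Before this fix only the face weights carried the $-z_{\max}$ shift, so the background weight was on a different scale.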

Reviewed By: gkioxari

Differential Revision: D23148301

fbshipit-source-id: 86066586ee7d3ce7bd4a2076b12ce191fbd151a7
Authored by Nikhila Ravi on 2020-08-17 11:57:32 -07:00; committed by Facebook GitHub Bot
parent 8e9ff15faf
commit 5852b74d12
2 changed files with 18 additions and 18 deletions

pytorch3d/renderer/blending.py

@@ -3,7 +3,6 @@
 from typing import NamedTuple, Sequence
 
-import numpy as np
 import torch
 
 # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
@@ -162,9 +161,8 @@ def softmax_rgb_blend(
     if not torch.is_tensor(background):
         background = torch.tensor(background, dtype=torch.float32, device=device)
 
-    # Background color
-    delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
-    delta = torch.tensor(delta, device=device)
+    # Weight for background color
+    eps = 1e-10
 
     # Mask for padded pixels.
     mask = fragments.pix_to_face >= 0
@@ -189,15 +187,18 @@ def softmax_rgb_blend(
     # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
 
+    # Also apply exp normalize trick for the background color weight.
+    # Clamp to ensure delta is never 0.
+    delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps)
+
     # Normalize weights.
     # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
     denom = weights_num.sum(dim=-1)[..., None] + delta
-    weights = weights_num / denom
 
     # Sum: weights * textures + background color
-    weighted_colors = (weights[..., None] * colors).sum(dim=-2)
-    weighted_background = (delta / denom) * background
-    pixel_colors[..., :3] = weighted_colors + weighted_background
+    weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
+    weighted_background = delta * background
+    pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
     pixel_colors[..., 3] = 1.0 - alpha
 
     return pixel_colors
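
For illustration, a minimal standalone sketch of the fixed computation above for a single pixel; all names and numeric values here are toy stand-ins, not real rasterizer outputs:

import torch

# Toy single-pixel example of the shifted softmax blend (illustrative values;
# K = 3 faces overlap the pixel).
gamma, eps = 0.5, 1e-10
z_inv = torch.tensor([0.9, 0.8, 0.5])      # normalized inverse depths per face
prob_map = torch.tensor([0.7, 0.2, 0.1])   # per-face coverage probabilities
colors = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.0, 0.0, 1.0]])   # per-face RGB
background = torch.tensor([1.0, 1.0, 1.0])

z_inv_max = z_inv.max()
# Shift every face exponent by z_inv_max so the largest is exp(0) = 1.
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / gamma)
# The background weight gets the same shift; clamp so it is never exactly 0.
delta = torch.exp((eps - z_inv_max) / gamma).clamp(min=eps)
denom = weights_num.sum() + delta
# Weighted sum of face colors and background over a common denominator.
pixel_rgb = ((weights_num[:, None] * colors).sum(dim=0) + delta * background) / denom

Because the face weights and delta carry the same -z_inv_max shift, the shift cancels in the ratio, so dividing the weighted colors and the background term by one shared denom reproduces the unshifted softmax blend without ever evaluating a large exponent.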

tests/test_blending.py

@@ -2,7 +2,6 @@
 import unittest
 
-import numpy as np
 import torch
 
 from common_testing import TestCaseMixin
 from pytorch3d.renderer.blending import (
@@ -97,21 +96,18 @@ def softmax_blend_naive(colors, fragments, blend_params):
     # Near and far clipping planes
     zfar = 100.0
     znear = 1.0
+    eps = 1e-10
 
     bk_color = blend_params.background_color
     if not torch.is_tensor(bk_color):
         bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device)
 
-    # Background color component
-    delta = np.exp(1e-10 / gamma) * 1e-10
-    delta = torch.tensor(delta).to(device=device)
-
     for n in range(N):
         for h in range(H):
             for w in range(W):
                 alpha = 1.0
                 weights_k = torch.zeros(K, device=device)
-                zmax = 0.0
+                zmax = torch.tensor(0.0, device=device)
 
                 # Loop over K to find max z.
                 for k in range(K):
@@ -129,11 +125,13 @@ def softmax_blend_naive(colors, fragments, blend_params):
                     alpha *= 1.0 - prob  # cumulative product
                     weights_k[k] = prob * torch.exp((zinv - zmax) / gamma)
 
+                # Clamp to ensure delta is never 0
+                delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps)
+                delta = delta.to(device)
                 denom = weights_k.sum() + delta
-                weights = weights_k / denom
-                cols = (weights[..., None] * colors[n, h, w, :, :]).sum(dim=0)
-                pixel_colors[n, h, w, :3] = cols
-                pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
+                cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0)
+                pixel_colors[n, h, w, :3] = cols + delta * bk_color
+                pixel_colors[n, h, w, :3] /= denom
                 pixel_colors[n, h, w, 3] = 1.0 - alpha
 
     return pixel_colors
@@ -160,6 +158,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
         (out2 * grad_out).sum().backward()
         self.assertTrue(hasattr(grad_var2, "grad"))
         self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
 
+
     def test_hard_rgb_blend(self):