Softmax blending small fix

Summary:
Small fix to the softmax blending function.

To avoid overflow in the exponential for the softmax, the exponent is shifted by the maximum value. The final pixel color is a weighted sum of the face colors and the background color; for this sum to be correct, the background color weight must be computed with the same shifted exponent.
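For illustration, a minimal sketch of the exp normalize trick with made-up values (`gamma`, `eps`, and `z_inv` stand in for the blend parameters and inverse depths; this is not the library code):

    import torch

    gamma, eps = 1e-4, 1e-10
    z_inv = torch.tensor([0.9, 0.8])  # illustrative inverse depths

    # Unshifted softmax weights overflow for small gamma:
    torch.exp(z_inv / gamma)  # tensor([inf, inf])

    # Exp normalize trick: shift every exponent by the max value.
    z_inv_max = z_inv.max()
    w_fg = torch.exp((z_inv - z_inv_max) / gamma)  # finite

    # The background weight must use the same shift to stay on the
    # same scale as w_fg:
    w_bg = torch.exp((eps - z_inv_max) / gamma).clamp(min=eps)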

Reviewed By: gkioxari

Differential Revision: D23148301

fbshipit-source-id: 86066586ee7d3ce7bd4a2076b12ce191fbd151a7
Author: Nikhila Ravi, 2020-08-17 11:57:32 -07:00 (committed by Facebook GitHub Bot)
parent 8e9ff15faf
commit 5852b74d12
2 changed files with 18 additions and 18 deletions

pytorch3d/renderer/blending.py

@@ -3,7 +3,6 @@
 from typing import NamedTuple, Sequence
-import numpy as np
 import torch
 # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
@@ -162,9 +161,8 @@ def softmax_rgb_blend(
     if not torch.is_tensor(background):
         background = torch.tensor(background, dtype=torch.float32, device=device)
 
-    # Background color
-    delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
-    delta = torch.tensor(delta, device=device)
+    # Weight for background color
+    eps = 1e-10
 
     # Mask for padded pixels.
     mask = fragments.pix_to_face >= 0
@@ -189,15 +187,18 @@ def softmax_rgb_blend(
     # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
 
+    # Also apply exp normalize trick for the background color weight.
+    # Clamp to ensure delta is never 0.
+    delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps)
+
     # Normalize weights.
     # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
     denom = weights_num.sum(dim=-1)[..., None] + delta
-    weights = weights_num / denom
 
     # Sum: weights * textures + background color
-    weighted_colors = (weights[..., None] * colors).sum(dim=-2)
-    weighted_background = (delta / denom) * background
-    pixel_colors[..., :3] = weighted_colors + weighted_background
+    weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
+    weighted_background = delta * background
+    pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
     pixel_colors[..., 3] = 1.0 - alpha
 
     return pixel_colors
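As a sanity check on the change above, a rough sketch with illustrative values (`gamma` is chosen large enough that the unshifted reference stays finite):

    import torch

    gamma, eps = 1e-1, 1e-10
    z_inv = torch.tensor([0.9, 0.8])
    z_inv_max = z_inv.max()
    w_fg = torch.exp((z_inv - z_inv_max) / gamma)

    # Reference background-to-face weight ratio from the unshifted softmax:
    ratio_ref = torch.exp((eps - z_inv) / gamma)  # ~[1.2e-4, 3.4e-4]

    delta_old = torch.exp(torch.tensor(eps / gamma)) * eps           # constant ~1e-10
    delta_new = torch.exp((eps - z_inv_max) / gamma).clamp(min=eps)  # shift-consistent

    print(delta_old / w_fg)  # ~[1e-10, 2.7e-10] -- wrong scale
    print(delta_new / w_fg)  # ~[1.2e-4, 3.4e-4] -- matches ratio_ref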

tests/test_blending.py

@@ -2,7 +2,6 @@
 import unittest
-import numpy as np
 import torch
 from common_testing import TestCaseMixin
 from pytorch3d.renderer.blending import (
@@ -97,21 +96,18 @@ def softmax_blend_naive(colors, fragments, blend_params):
     # Near and far clipping planes
     zfar = 100.0
     znear = 1.0
+    eps = 1e-10
 
     bk_color = blend_params.background_color
     if not torch.is_tensor(bk_color):
         bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device)
 
-    # Background color component
-    delta = np.exp(1e-10 / gamma) * 1e-10
-    delta = torch.tensor(delta).to(device=device)
-
     for n in range(N):
         for h in range(H):
             for w in range(W):
                 alpha = 1.0
                 weights_k = torch.zeros(K, device=device)
-                zmax = 0.0
+                zmax = torch.tensor(0.0, device=device)
 
                 # Loop over K to find max z.
                 for k in range(K):
@@ -129,11 +125,13 @@ def softmax_blend_naive(colors, fragments, blend_params):
                     alpha *= 1.0 - prob  # cumulative product
                     weights_k[k] = prob * torch.exp((zinv - zmax) / gamma)
 
+                # Clamp to ensure delta is never 0
+                delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps)
+                delta = delta.to(device)
                 denom = weights_k.sum() + delta
-                weights = weights_k / denom
-                cols = (weights[..., None] * colors[n, h, w, :, :]).sum(dim=0)
-                pixel_colors[n, h, w, :3] = cols
-                pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
+                cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0)
+                pixel_colors[n, h, w, :3] = cols + delta * bk_color
+                pixel_colors[n, h, w, :3] /= denom
                 pixel_colors[n, h, w, 3] = 1.0 - alpha
 
     return pixel_colors
@@ -160,6 +158,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
         (out2 * grad_out).sum().backward()
         self.assertTrue(hasattr(grad_var2, "grad"))
+        self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
 
     def test_hard_rgb_blend(self):
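The `assertClose` added above follows the usual gradient-parity pattern; a self-contained sketch of that pattern, with `torch.softmax` standing in for the two blend implementations:

    import torch

    x1 = torch.rand(4, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)

    out1 = torch.softmax(x1, dim=0)  # stand-in for the vectorized blend
    out2 = torch.softmax(x2, dim=0)  # stand-in for the naive reference

    grad_out = torch.rand(4)
    (out1 * grad_out).sum().backward()
    (out2 * grad_out).sum().backward()
    assert torch.allclose(x1.grad, x2.grad, atol=2e-5)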