mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Softmax blending small fix
Summary: Small fix to the softmax blending function. To avoid overflow in the exponential for the softmax, the exponent is shifted by the maximum value. In the final calculation of the color there is a weighted sum between the pixel color and the background color - in order for the sum to be correct, the background color also needs to be handled in the same way witt the shifted exponent. Reviewed By: gkioxari Differential Revision: D23148301 fbshipit-source-id: 86066586ee7d3ce7bd4a2076b12ce191fbd151a7
This commit is contained in:
parent
8e9ff15faf
commit
5852b74d12
@ -3,7 +3,6 @@
|
||||
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
|
||||
@ -162,9 +161,8 @@ def softmax_rgb_blend(
|
||||
if not torch.is_tensor(background):
|
||||
background = torch.tensor(background, dtype=torch.float32, device=device)
|
||||
|
||||
# Background color
|
||||
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
|
||||
delta = torch.tensor(delta, device=device)
|
||||
# Weight for background color
|
||||
eps = 1e-10
|
||||
|
||||
# Mask for padded pixels.
|
||||
mask = fragments.pix_to_face >= 0
|
||||
@ -189,15 +187,18 @@ def softmax_rgb_blend(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
|
||||
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
|
||||
|
||||
# Also apply exp normalize trick for the background color weight.
|
||||
# Clamp to ensure delta is never 0.
|
||||
delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps)
|
||||
|
||||
# Normalize weights.
|
||||
# weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
|
||||
denom = weights_num.sum(dim=-1)[..., None] + delta
|
||||
weights = weights_num / denom
|
||||
|
||||
# Sum: weights * textures + background color
|
||||
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
|
||||
weighted_background = (delta / denom) * background
|
||||
pixel_colors[..., :3] = weighted_colors + weighted_background
|
||||
weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
|
||||
weighted_background = delta * background
|
||||
pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
|
||||
pixel_colors[..., 3] = 1.0 - alpha
|
||||
|
||||
return pixel_colors
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.blending import (
|
||||
@ -97,21 +96,18 @@ def softmax_blend_naive(colors, fragments, blend_params):
|
||||
# Near and far clipping planes
|
||||
zfar = 100.0
|
||||
znear = 1.0
|
||||
eps = 1e-10
|
||||
|
||||
bk_color = blend_params.background_color
|
||||
if not torch.is_tensor(bk_color):
|
||||
bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device)
|
||||
|
||||
# Background color component
|
||||
delta = np.exp(1e-10 / gamma) * 1e-10
|
||||
delta = torch.tensor(delta).to(device=device)
|
||||
|
||||
for n in range(N):
|
||||
for h in range(H):
|
||||
for w in range(W):
|
||||
alpha = 1.0
|
||||
weights_k = torch.zeros(K, device=device)
|
||||
zmax = 0.0
|
||||
zmax = torch.tensor(0.0, device=device)
|
||||
|
||||
# Loop over K to find max z.
|
||||
for k in range(K):
|
||||
@ -129,11 +125,13 @@ def softmax_blend_naive(colors, fragments, blend_params):
|
||||
alpha *= 1.0 - prob # cumulative product
|
||||
weights_k[k] = prob * torch.exp((zinv - zmax) / gamma)
|
||||
|
||||
# Clamp to ensure delta is never 0
|
||||
delta = torch.exp((eps - zmax) / blend_params.gamma).clamp(min=eps)
|
||||
delta = delta.to(device)
|
||||
denom = weights_k.sum() + delta
|
||||
weights = weights_k / denom
|
||||
cols = (weights[..., None] * colors[n, h, w, :, :]).sum(dim=0)
|
||||
pixel_colors[n, h, w, :3] = cols
|
||||
pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
|
||||
cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0)
|
||||
pixel_colors[n, h, w, :3] = cols + delta * bk_color
|
||||
pixel_colors[n, h, w, :3] /= denom
|
||||
pixel_colors[n, h, w, 3] = 1.0 - alpha
|
||||
|
||||
return pixel_colors
|
||||
@ -160,6 +158,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
(out2 * grad_out).sum().backward()
|
||||
self.assertTrue(hasattr(grad_var2, "grad"))
|
||||
|
||||
self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
|
||||
|
||||
def test_hard_rgb_blend(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user