Mirror of https://github.com/facebookresearch/pytorch3d.git
Synced 2025-08-02 20:02:49 +08:00

Blending fixes and test updates

Summary: Changed `torch.cumprod` to `torch.prod` in blending functions and added more tests and benchmark tests. This should fix the issue raised on GitHub.

Reviewed By: gkioxari

Differential Revision: D20163073

fbshipit-source-id: 4569fd37be11aa4435a3ce8736b55622c00ec718

Parent: ff19c642cb
Commit: ba11c0b59c
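To illustrate the change described in the summary, here is a minimal sketch. It is not the library code itself: the shapes, the sigma value, and the mask below are made-up stand-ins for what the blending functions receive from the rasterizer.

import torch

# Stand-ins for rasterizer outputs: signed 2D pixel-to-face distances and a
# mask of valid faces among the top K faces per pixel.
N, H, W, K = 2, 8, 8, 4
sigma = 1e-4
dists = torch.randn(N, H, W, K, requires_grad=True)
mask = (torch.rand(N, H, W, K) > 0.3).float()

# Per-face coverage probability, as in both blending functions.
prob = torch.sigmoid(-dists / sigma) * mask

# Previous formulation: cumprod had been swapped for exp(sum(log(...))) as a
# workaround, but log(1 - prob) is -inf wherever prob == 1.0, which can make
# the backward pass non-finite.
alpha_old = 1.0 - torch.exp(torch.log(1.0 - prob).sum(dim=-1))

# Formulation after this commit: a plain product over the K faces.
# A zero factor (a fully covering face) simply drives the product to 0.0,
# so the resulting alpha is exactly 1.0 and gradients stay well defined.
alpha_new = 1.0 - torch.prod(1.0 - prob, dim=-1)

Where no face fully covers the pixel, the two formulations agree numerically; the difference shows up in the saturated case and in the gradients.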
pytorch3d/renderer/blending.py

@@ -45,7 +45,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
     """
     Silhouette blending to return an RGBA image
       - **RGB** - choose color of the closest point.
-      - **A** - blend based on the 2D distance based probability map [0].
+      - **A** - blend based on the 2D distance based probability map [1].

     Args:
         colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
@@ -60,7 +60,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
     Returns:
         RGBA pixel_colors: (N, H, W, 4)

-    [0] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
+    [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
     """
     N, H, W, K = fragments.pix_to_face.shape
@@ -73,20 +73,13 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
     # the face. Therefore use -1.0 * fragments.dists to get the correct sign.
     prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask

-    # The cumulative product ensures that alpha will be 1 if at least 1 face
-    # fully covers the pixel as for that face prob will be 1.0
-    # TODO: investigate why torch.cumprod backwards is very slow for large
-    # values of K.
-    # Temporarily replace this with exp(sum(log))) using the fact that
-    # a*b = exp(log(a*b)) = exp(log(a) + log(b))
-    # alpha = 1.0 - torch.cumprod((1.0 - prob), dim=-1)[..., -1]
-
-    alpha = 1.0 - torch.exp(torch.log((1.0 - prob)).sum(dim=-1))
-
+    # The cumulative product ensures that alpha will be 0.0 if at least 1
+    # face fully covers the pixel as for that face, prob will be 1.0.
+    # This results in a multiplication by 0.0 because of the (1.0 - prob)
+    # term. Therefore 1.0 - alpha will be 1.0.
+    alpha = torch.prod((1.0 - prob), dim=-1)
     pixel_colors[..., :3] = colors[..., 0, :]  # Hard assign for RGB
-    pixel_colors[..., 3] = alpha
-
-    pixel_colors = torch.clamp(pixel_colors, min=0, max=1.0)
+    pixel_colors[..., 3] = 1.0 - alpha
     return torch.flip(pixel_colors, [1])


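In math terms, the alpha channel computed by the hunk above can be written as follows (a sketch using d_k for the signed 2D pixel-to-face distance of the k-th of the top K faces and sigma for BlendParams.sigma):

% Per-face coverage probability and the blended silhouette alpha.
p_k = \mathrm{sigmoid}\!\left(-\frac{d_k}{\sigma}\right),
\qquad
\alpha = 1 - \prod_{k=1}^{K} \left(1 - p_k\right)

If any face fully covers the pixel, so that p_k = 1, the product has a zero factor and alpha is exactly 1, which is what the updated comments describe.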
@@ -95,7 +88,7 @@ def softmax_rgb_blend(
 ) -> torch.Tensor:
     """
     RGB and alpha channel blending to return an RGBA image based on the method
-    proposed in [0]
+    proposed in [1]
       - **RGB** - blend the colors based on the 2D distance based probability map and
        relative z distances.
       - **A** - blend based on the 2D distance based probability map.
@@ -151,15 +144,11 @@ def softmax_rgb_blend(
     # Sigmoid probability map based on the distance of the pixel to the face.
     prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask

-    # The cumulative product ensures that alpha will be 1 if at least 1 face
-    # fully covers the pixel as for that face prob will be 1.0
-    # TODO: investigate why torch.cumprod backwards is very slow for large
-    # values of K.
-    # Temporarily replace this with exp(sum(log))) using the fact that
-    # a*b = exp(log(a*b)) = exp(log(a) + log(b))
-    # alpha = 1.0 - torch.cumprod((1.0 - prob), dim=-1)[..., -1]
-
-    alpha = 1.0 - torch.exp(torch.log((1.0 - prob_map)).sum(dim=-1))
+    # The cumulative product ensures that alpha will be 0.0 if at least 1
+    # face fully covers the pixel as for that face, prob will be 1.0.
+    # This results in a multiplication by 0.0 because of the (1.0 - prob)
+    # term. Therefore 1.0 - alpha will be 1.0.
+    alpha = torch.prod((1.0 - prob_map), dim=-1)

     # Weights for each face. Adjust the exponential by the max z to prevent
     # overflow. zbuf shape (N, H, W, K), find max over K.
@@ -178,8 +167,6 @@ def softmax_rgb_blend(
     weighted_colors = (weights[..., None] * colors).sum(dim=-2)
     weighted_background = (delta / denom) * background
     pix_colors[..., :3] = weighted_colors + weighted_background
-    pix_colors[..., 3] = alpha
+    pix_colors[..., 3] = 1.0 - alpha

-    # Clamp colors to the range 0-1 and flip y axis.
-    pix_colors = torch.clamp(pix_colors, min=0, max=1.0)
     return torch.flip(pix_colors, [1])
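For context, and hedging on the parts not visible in these hunks (the per-face weights w_k are defined earlier in the function; the comments say they combine the probability map with an exponential in depth, shifted by the per-pixel max z for stability, and the sketch below assumes the usual SoftRas normalization where denom = sum_k w_k + delta), the output that the last hunk assembles has the form:

% Softmax RGB blending as assembled by the visible code: a weighted sum of the
% K face colors plus a background term, with the silhouette alpha as before.
C = \sum_{k=1}^{K} \frac{w_k}{\sum_{j} w_j + \delta}\, c_k
  \;+\; \frac{\delta}{\sum_{j} w_j + \delta}\, c_{\mathrm{bg}},
\qquad
\alpha = 1 - \prod_{k=1}^{K} \left(1 - p_k\right)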
tests/bm_blending.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+
+from itertools import product
+from fvcore.common.benchmark import benchmark
+
+from test_blending import TestBlending
+
+
+def bm_blending() -> None:
+    devices = ["cpu", "cuda"]
+    kwargs_list = []
+    num_meshes = [16]
+    image_size = [128, 256]
+    faces_per_pixel = [50, 100]
+    test_cases = product(num_meshes, image_size, faces_per_pixel, devices)
+
+    for case in test_cases:
+        n, s, k, d = case
+        kwargs_list.append(
+            {
+                "num_meshes": n,
+                "image_size": s,
+                "faces_per_pixel": k,
+                "device": d,
+            }
+        )
+
+    benchmark(
+        TestBlending.bm_sigmoid_alpha_blending,
+        "SIGMOID_ALPHA_BLENDING_PYTORCH",
+        kwargs_list,
+        warmup_iters=1,
+    )
+
+    benchmark(
+        TestBlending.bm_softmax_blending,
+        "SOFTMAX_BLENDING_PYTORCH",
+        kwargs_list,
+        warmup_iters=1,
+    )
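A minimal way to exercise the new benchmark file, sketched here as an assumption about how it is meant to be driven: it needs fvcore installed, the tests/ directory on sys.path so the sibling import of test_blending resolves, and a CUDA device for the "cuda" entries in the kwargs grid.

# Hypothetical driver for the benchmark above; the path handling is an
# assumption about the repo layout (bm_blending.py living in tests/).
import sys

sys.path.append("tests")

from bm_blending import bm_blending

if __name__ == "__main__":
    # Runs the sigmoid and softmax blending benchmarks over the kwargs grid
    # defined in bm_blending (CPU and CUDA, two image sizes, two K values).
    bm_blending()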
tests/test_blending.py

@@ -14,7 +14,7 @@ from pytorch3d.renderer.blending import (
 from pytorch3d.renderer.mesh.rasterizer import Fragments


-def sigmoid_blend_naive(colors, fragments, blend_params):
+def sigmoid_blend_naive_loop(colors, fragments, blend_params):
     """
     Naive for loop based implementation of distance based alpha calculation.
     Only for test purposes.

@@ -41,10 +41,38 @@ def sigmoid_blend_naive(colors, fragments, blend_params):
                 pixel_colors[n, h, w, :3] = colors[n, h, w, 0, :]
                 pixel_colors[n, h, w, 3] = 1.0 - alpha

-    pixel_colors = torch.clamp(pixel_colors, min=0, max=1.0)
     return torch.flip(pixel_colors, [1])


+def sigmoid_blend_naive_loop_backward(
+    grad_images, images, fragments, blend_params
+):
+    pix_to_face = fragments.pix_to_face
+    dists = fragments.dists
+    sigma = blend_params.sigma
+
+    N, H, W, K = pix_to_face.shape
+    device = pix_to_face.device
+    grad_distances = torch.zeros((N, H, W, K), dtype=dists.dtype, device=device)
+    images = torch.flip(images, [1])
+    grad_images = torch.flip(grad_images, [1])
+
+    for n in range(N):
+        for h in range(H):
+            for w in range(W):
+                alpha = 1.0 - images[n, h, w, 3]
+                grad_alpha = grad_images[n, h, w, 3]
+                # Loop over k faces and calculate 2D distance based probability
+                # map.
+                for k in range(K):
+                    if pix_to_face[n, h, w, k] >= 0:
+                        prob = torch.sigmoid(-dists[n, h, w, k] / sigma)
+                        grad_distances[n, h, w, k] = (
+                            grad_alpha * (-1.0 / sigma) * prob * alpha
+                        )
+    return grad_distances
+
+
 def softmax_blend_naive(colors, fragments, blend_params):
     """
     Naive for loop based implementation of softmax blending.
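As a cross-check on the hand-written backward pass added above, here is a sketch of the derivation it appears to implement, using the same notation as before. Write A = prod_k (1 - p_k), so the stored alpha channel is 1 - A. Then:

% Gradient of the alpha channel with respect to a single distance d_k.
\frac{\partial p_k}{\partial d_k} = -\frac{1}{\sigma}\, p_k \left(1 - p_k\right),
\qquad
\frac{\partial A}{\partial d_k}
  = -\frac{\partial p_k}{\partial d_k} \prod_{j \neq k} \left(1 - p_j\right)
  = \frac{1}{\sigma}\, p_k\, A,
\qquad
\frac{\partial (1 - A)}{\partial d_k} = -\frac{1}{\sigma}\, p_k\, A

Multiplying by the upstream gradient of the alpha channel gives exactly `grad_alpha * (-1.0 / sigma) * prob * alpha` in the inner loop, since the code recovers A as `alpha = 1.0 - images[n, h, w, 3]`.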
@@ -76,7 +104,7 @@ def softmax_blend_naive(colors, fragments, blend_params):
         for h in range(H):
             for w in range(W):
                 alpha = 1.0
-                weights_k = torch.zeros(K)
+                weights_k = torch.zeros(K, device=device)
                 zmax = 0.0

                 # Loop over K to find max z.
@@ -102,7 +130,6 @@ def softmax_blend_naive(colors, fragments, blend_params):
                 pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
                 pixel_colors[n, h, w, 3] = 1.0 - alpha

-    pixel_colors = torch.clamp(pixel_colors, min=0, max=1.0)
     return torch.flip(pixel_colors, [1])


@@ -110,6 +137,37 @@ class TestBlending(unittest.TestCase):
     def setUp(self) -> None:
         torch.manual_seed(42)

+    def _compare_impls(
+        self,
+        fn1,
+        fn2,
+        args1,
+        args2,
+        grad_var1=None,
+        grad_var2=None,
+        compare_grads=True,
+    ):
+
+        out1 = fn1(*args1)
+        out2 = fn2(*args2)
+        self.assertTrue(torch.allclose(out1.cpu(), out2.cpu(), atol=1e-7))
+
+        # Check gradients
+        if not compare_grads:
+            return
+
+        grad_out = torch.randn_like(out1)
+        (out1 * grad_out).sum().backward()
+        self.assertTrue(hasattr(grad_var1, "grad"))
+
+        (out2 * grad_out).sum().backward()
+        self.assertTrue(hasattr(grad_var2, "grad"))
+        self.assertTrue(
+            torch.allclose(
+                grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5
+            )
+        )
+
     def test_hard_rgb_blend(self):
         N, H, W, K = 5, 10, 10, 20
         pix_to_face = torch.ones((N, H, W, K))
@@ -129,116 +187,246 @@ class TestBlending(unittest.TestCase):
         expected_vals[..., :3] = pix_cols
         self.assertTrue(torch.allclose(images, expected_vals))

-    def test_sigmoid_alpha_blend(self):
-        """
-        Test outputs of sigmoid alpha blend tensorised function match those of
-        the naive iterative version. Also check gradients match.
-        """
-        # Create dummy outputs of rasterization simulating a cube in the centre
-        # of the image with surrounding padded values.
-        N, S, K = 1, 8, 2
-        pix_to_face = -torch.ones((N, S, S, K), dtype=torch.int64)
-        h = int(S / 2)
-        pix_to_face_full = torch.randint(size=(N, h, h, K), low=0, high=100)
-        s = int(S / 4)
-        e = int(0.75 * S)
-        pix_to_face[:, s:e, s:e, :] = pix_to_face_full
-        bary_coords = torch.ones((N, S, S, K, 3))
-
-        # randomly flip the sign of the distance
-        # (-) means inside triangle, (+) means outside triangle.
+    def test_sigmoid_alpha_blend_manual_gradients(self):
+        # Create dummy outputs of rasterization
+        torch.manual_seed(231)
+        F = 32  # number of faces in the mesh
+        # The python loop version is really slow so only using small input sizes.
+        N, S, K = 2, 3, 2
+        device = torch.device("cuda")
+        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        colors = torch.randn((N, S, S, K, 3), device=device)
+        empty = torch.tensor([], device=device)
+
+        # # randomly flip the sign of the distance
+        # # (-) means inside triangle, (+) means outside triangle.
         random_sign_flip = torch.rand((N, S, S, K))
         random_sign_flip[random_sign_flip > 0.5] *= -1.0
-        dists = torch.randn(size=(N, S, S, K))
-        dists1 = dists * random_sign_flip
-        dists2 = dists1.clone()
-        dists1.requires_grad = True
+        dists = torch.randn(
+            size=(N, S, S, K), requires_grad=True, device=device
+        )
+        fragments = Fragments(
+            pix_to_face=pix_to_face,
+            bary_coords=empty,  # dummy
+            zbuf=empty,  # dummy
+            dists=dists,
+        )
+        blend_params = BlendParams(sigma=1e-3)
+        pix_cols = sigmoid_blend_naive_loop(colors, fragments, blend_params)
+        grad_out = torch.randn_like(pix_cols)
+
+        # Backward pass
+        pix_cols.backward(grad_out)
+        grad_dists = sigmoid_blend_naive_loop_backward(
+            grad_out, pix_cols, fragments, blend_params
+        )
+        self.assertTrue(torch.allclose(dists.grad, grad_dists, atol=1e-7))
+
+    def test_sigmoid_alpha_blend_python(self):
+        """
+        Test outputs of python tensorised function and python loop
+        """
+
+        # Create dummy outputs of rasterization
+        torch.manual_seed(231)
+        F = 32  # number of faces in the mesh
+        # The python loop version is really slow so only using small input sizes.
+        N, S, K = 2, 10, 5
+        device = torch.device("cuda")
+        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        colors = torch.randn((N, S, S, K, 3), device=device)
+        empty = torch.tensor([], device=device)
+
+        # # randomly flip the sign of the distance
+        # # (-) means inside triangle, (+) means outside triangle.
+        random_sign_flip = torch.rand((N, S, S, K))
+        random_sign_flip[random_sign_flip > 0.5] *= -1.0
+        dists1 = torch.randn(
+            size=(N, S, S, K), requires_grad=True, device=device
+        )
+        dists2 = dists1.detach().clone()
         dists2.requires_grad = True
-        colors = torch.randn_like(bary_coords)
         fragments1 = Fragments(
             pix_to_face=pix_to_face,
-            bary_coords=bary_coords,  # dummy
-            zbuf=pix_to_face,  # dummy
+            bary_coords=empty,  # dummy
+            zbuf=empty,  # dummy
             dists=dists1,
         )
         fragments2 = Fragments(
             pix_to_face=pix_to_face,
-            bary_coords=bary_coords,  # dummy
-            zbuf=pix_to_face,  # dummy
+            bary_coords=empty,  # dummy
+            zbuf=empty,  # dummy
             dists=dists2,
         )
-        blend_params = BlendParams(sigma=2e-1)
-        images = sigmoid_alpha_blend(colors, fragments1, blend_params)
-        images_naive = sigmoid_blend_naive(colors, fragments2, blend_params)
-        self.assertTrue(torch.allclose(images, images_naive))

-        torch.manual_seed(231)
-        images.sum().backward()
-        self.assertTrue(hasattr(dists1, "grad"))
-        images_naive.sum().backward()
-        self.assertTrue(hasattr(dists2, "grad"))
+        blend_params = BlendParams(sigma=1e-2)
+        args1 = (colors, fragments1, blend_params)
+        args2 = (colors, fragments2, blend_params)

-        self.assertTrue(torch.allclose(dists1.grad, dists2.grad, rtol=1e-5))
+        self._compare_impls(
+            sigmoid_alpha_blend,
+            sigmoid_blend_naive_loop,
+            args1,
+            args2,
+            dists1,
+            dists2,
+            compare_grads=True,
+        )

     def test_softmax_rgb_blend(self):
         # Create dummy outputs of rasterization simulating a cube in the centre
         # of the image with surrounding padded values.
         N, S, K = 1, 8, 2
-        pix_to_face = -torch.ones((N, S, S, K), dtype=torch.int64)
+        device = torch.device("cuda")
+        pix_to_face = -torch.ones(
+            (N, S, S, K), dtype=torch.int64, device=device
+        )
         h = int(S / 2)
-        pix_to_face_full = torch.randint(size=(N, h, h, K), low=0, high=100)
+        pix_to_face_full = torch.randint(
+            size=(N, h, h, K), low=0, high=100, device=device
+        )
         s = int(S / 4)
         e = int(0.75 * S)
         pix_to_face[:, s:e, s:e, :] = pix_to_face_full
-        bary_coords = torch.ones((N, S, S, K, 3))
-        random_sign_flip = torch.rand((N, S, S, K))
+        empty = torch.tensor([], device=device)
+
+        random_sign_flip = torch.rand((N, S, S, K), device=device)
         random_sign_flip[random_sign_flip > 0.5] *= -1.0
-        zbuf1 = torch.randn(size=(N, S, S, K))
+        zbuf1 = torch.randn(size=(N, S, S, K), device=device)

         # randomly flip the sign of the distance
         # (-) means inside triangle, (+) means outside triangle.
-        dists1 = torch.randn(size=(N, S, S, K)) * random_sign_flip
+        dists1 = (
+            torch.randn(size=(N, S, S, K), device=device) * random_sign_flip
+        )
         dists2 = dists1.clone()
         zbuf2 = zbuf1.clone()
         dists1.requires_grad = True
         dists2.requires_grad = True
-        zbuf1.requires_grad = True
-        zbuf2.requires_grad = True
-        colors = torch.randn_like(bary_coords)
+        colors = torch.randn((N, S, S, K, 3), device=device)
         fragments1 = Fragments(
             pix_to_face=pix_to_face,
-            bary_coords=bary_coords,  # dummy
+            bary_coords=empty,  # dummy
             zbuf=zbuf1,
             dists=dists1,
         )
         fragments2 = Fragments(
             pix_to_face=pix_to_face,
-            bary_coords=bary_coords,  # dummy
+            bary_coords=empty,  # dummy
             zbuf=zbuf2,
             dists=dists2,
         )
-        blend_params = BlendParams(sigma=1e-1)
-        images = softmax_rgb_blend(colors, fragments1, blend_params)
-        images_naive = softmax_blend_naive(colors, fragments2, blend_params)
-        self.assertTrue(torch.allclose(images, images_naive))

-        # Check gradients.
-        images.sum().backward()
-        self.assertTrue(hasattr(dists1, "grad"))
-        self.assertTrue(hasattr(zbuf1, "grad"))
-        images_naive.sum().backward()
-        self.assertTrue(hasattr(dists2, "grad"))
-        self.assertTrue(hasattr(zbuf2, "grad"))
+        blend_params = BlendParams(sigma=1e-3)
+        args1 = (colors, fragments1, blend_params)
+        args2 = (colors, fragments2, blend_params)
+        self._compare_impls(
+            softmax_rgb_blend,
+            softmax_blend_naive,
+            args1,
+            args2,
+            dists1,
+            dists2,
+            compare_grads=True,
+        )

-        self.assertTrue(torch.allclose(dists1.grad, dists2.grad, atol=2e-5))
-        self.assertTrue(torch.allclose(zbuf1.grad, zbuf2.grad, atol=2e-5))
+    @staticmethod
+    def bm_sigmoid_alpha_blending(
+        num_meshes: int = 16,
+        image_size: int = 128,
+        faces_per_pixel: int = 100,
+        device: str = "cpu",
+    ):
+        if torch.cuda.is_available() and "cuda:" in device:
+            # If a device other than the default is used, set the device explicitly.
+            torch.cuda.set_device(device)
+
+        device = torch.device(device)
+        torch.manual_seed(231)
+
+        # Create dummy outputs of rasterization
+        N, S, K = num_meshes, image_size, faces_per_pixel
+        F = 32  # num faces in the mesh
+        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        colors = torch.randn((N, S, S, K, 3), device=device)
+        empty = torch.tensor([], device=device)
+
+        # # randomly flip the sign of the distance
+        # # (-) means inside triangle, (+) means outside triangle.
+        random_sign_flip = torch.rand((N, S, S, K), device=device)
+        random_sign_flip[random_sign_flip > 0.5] *= -1.0
+        dists1 = torch.randn(
+            size=(N, S, S, K), requires_grad=True, device=device
+        )
+        fragments = Fragments(
+            pix_to_face=pix_to_face,
+            bary_coords=empty,  # dummy
+            zbuf=empty,  # dummy
+            dists=dists1,
+        )
+        blend_params = BlendParams(sigma=1e-3)
+        torch.cuda.synchronize()
+
+        def fn():
+            # test forward and backward pass
+            images = sigmoid_alpha_blend(colors, fragments, blend_params)
+            images.sum().backward()
+            torch.cuda.synchronize()
+
+        return fn
+
+    @staticmethod
+    def bm_softmax_blending(
+        num_meshes: int = 16,
+        image_size: int = 128,
+        faces_per_pixel: int = 100,
+        device: str = "cpu",
+    ):
+        if torch.cuda.is_available() and "cuda:" in device:
+            # If a device other than the default is used, set the device explicitly.
+            torch.cuda.set_device(device)
+
+        device = torch.device(device)
+        torch.manual_seed(231)
+
+        # Create dummy outputs of rasterization
+        N, S, K = num_meshes, image_size, faces_per_pixel
+        F = 32  # num faces in the mesh
+        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        colors = torch.randn((N, S, S, K, 3), device=device)
+        empty = torch.tensor([], device=device)
+
+        # # randomly flip the sign of the distance
+        # # (-) means inside triangle, (+) means outside triangle.
+        random_sign_flip = torch.rand((N, S, S, K), device=device)
+        random_sign_flip[random_sign_flip > 0.5] *= -1.0
+        dists1 = torch.randn(
+            size=(N, S, S, K), requires_grad=True, device=device
+        )
+        zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
+        fragments = Fragments(
+            pix_to_face=pix_to_face,
+            bary_coords=empty,  # dummy
+            zbuf=zbuf,
+            dists=dists1,
+        )
+        blend_params = BlendParams(sigma=1e-3)
+
+        torch.cuda.synchronize()
+
+        def fn():
+            # test forward and backward pass
+            images = softmax_rgb_blend(colors, fragments, blend_params)
+            images.sum().backward()
+            torch.cuda.synchronize()
+
+        return fn

     def test_blend_params(self):
         """Test colour parameter of BlendParams().
         Assert passed value overrides default value.
         """
         bp_default = BlendParams()
         bp_new = BlendParams(background_color=(0.5, 0.5, 0.5))
         self.assertEqual(bp_new.background_color, (0.5, 0.5, 0.5))