Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-19 22:00:35 +08:00
C++/CUDA implementation of sigmoid alpha blend
Summary: C++/CUDA implementation of the forward and backward passes for the sigmoid alpha blending function. This is slightly faster than the vectorized Python implementation but, more importantly, uses less memory because fewer intermediate tensors are created.

Reviewed By: gkioxari

Differential Revision: D19980671

fbshipit-source-id: 0779055d2c68b1f20fb0870e60046077ef4613ff
Commit bce396df93 (parent dc08c30583), committed by Facebook GitHub Bot.
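For orientation: sigmoid alpha blending treats each face overlapping a pixel as covering it with probability sigmoid(-dist / sigma), and composites alpha as the probability that at least one face covers the pixel. A minimal sketch of the forward pass, mirroring the vectorized reference added in this diff (the function name and the sigma default are illustrative):

    import torch

    def sigmoid_alpha_forward(colors, pix_to_face, dists, sigma=1e-4):
        # colors: (N, H, W, K, 3); pix_to_face: (N, H, W, K), -1 where no face;
        # dists: (N, H, W, K) signed point-to-face distances.
        N, H, W, K = pix_to_face.shape
        pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
        mask = pix_to_face >= 0  # drop padded (-1) entries
        prob = torch.sigmoid(-dists / sigma) * mask  # per-face coverage probability
        pixel_colors[..., :3] = colors[..., 0, :]  # RGB taken from the nearest face
        # alpha: probability that at least one face covers the pixel
        pixel_colors[..., 3] = 1.0 - torch.prod(1.0 - prob, dim=-1)
        return pixel_colors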
@@ -8,17 +8,24 @@ from test_blending import TestBlending
 def bm_blending() -> None:
-    devices = ["cpu", "cuda"]
+    devices = ["cuda"]
     kwargs_list = []
-    num_meshes = [16]
-    image_size = [128, 256]
+    num_meshes = [8]
+    image_size = [64, 128, 256]
     faces_per_pixel = [50, 100]
-    test_cases = product(num_meshes, image_size, faces_per_pixel, devices)
+    backend = ["pytorch", "custom"]
+    test_cases = product(num_meshes, image_size, faces_per_pixel, devices, backend)

     for case in test_cases:
-        n, s, k, d = case
+        n, s, k, d, b = case
         kwargs_list.append(
-            {"num_meshes": n, "image_size": s, "faces_per_pixel": k, "device": d}
+            {
+                "num_meshes": n,
+                "image_size": s,
+                "faces_per_pixel": k,
+                "device": d,
+                "backend": b,
+            }
         )

     benchmark(
@@ -28,6 +35,7 @@ def bm_blending() -> None:
         warmup_iters=1,
     )

+    kwargs_list = [case for case in kwargs_list if case["backend"] == "pytorch"]
     benchmark(
         TestBlending.bm_softmax_blending,
         "SOFTMAX_BLENDING_PYTORCH",
@@ -44,6 +44,16 @@ def sigmoid_blend_naive_loop(colors, fragments, blend_params):
     return pixel_colors


+def sigmoid_alpha_blend_vectorized(colors, fragments, blend_params) -> torch.Tensor:
+    N, H, W, K = fragments.pix_to_face.shape
+    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
+    mask = fragments.pix_to_face >= 0
+    prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
+    pixel_colors[..., :3] = colors[..., 0, :]
+    pixel_colors[..., 3] = 1.0 - torch.prod((1.0 - prob), dim=-1)
+    return pixel_colors
+
+
 def sigmoid_blend_naive_loop_backward(grad_images, images, fragments, blend_params):
     pix_to_face = fragments.pix_to_face
     dists = fragments.dists
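The C++/CUDA kernels themselves are not shown in this excerpt, but the backward pass they must implement follows by differentiating the product formula in sigmoid_alpha_blend_vectorized. A rough PyTorch rendering of the alpha-channel gradient with respect to dists (illustrative names, not the actual kernel code):

    import torch

    def sigmoid_alpha_backward_alpha_channel(grad_alpha, pix_to_face, dists, sigma):
        # alpha = 1 - prod_k (1 - prob_k), with prob_k = sigmoid(-dists_k / sigma) * mask_k
        mask = (pix_to_face >= 0).to(dists.dtype)
        s = torch.sigmoid(-dists / sigma)
        prob = s * mask
        one_minus = 1.0 - prob
        total = torch.prod(one_minus, dim=-1, keepdim=True)
        # d(alpha)/d(prob_k) = prod_{j != k} (1 - prob_j) = total / (1 - prob_k)
        dalpha_dprob = total / one_minus.clamp(min=1e-10)
        # d(prob_k)/d(dists_k) = -mask_k * s_k * (1 - s_k) / sigma  (sigmoid chain rule)
        dprob_ddist = -mask * s * (1.0 - s) / sigma
        # grad_alpha: (N, H, W); broadcast over the K face dimension
        return grad_alpha.unsqueeze(-1) * dalpha_dprob * dprob_ddist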
@@ -136,10 +146,9 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
     def _compare_impls(
         self, fn1, fn2, args1, args2, grad_var1=None, grad_var2=None, compare_grads=True
     ):
-
         out1 = fn1(*args1)
         out2 = fn2(*args2)
-        self.assertTrue(torch.allclose(out1.cpu(), out2.cpu(), atol=1e-7))
+        self.assertClose(out1.cpu()[..., 3], out2.cpu()[..., 3], atol=1e-7)

         # Check gradients
         if not compare_grads:
@@ -151,9 +160,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):

         (out2 * grad_out).sum().backward()
         self.assertTrue(hasattr(grad_var2, "grad"))
-        self.assertTrue(
-            torch.allclose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
-        )
+        self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)

     def test_hard_rgb_blend(self):
         N, H, W, K = 5, 10, 10, 20
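The comparison pattern in _compare_impls generalizes: run both implementations, push the same random upstream gradient through each, and compare the input gradients. A standalone sketch of that pattern (fn1/fn2 and x1/x2 are placeholders, with x1/x2 leaf tensors that have requires_grad=True):

    import torch

    def compare_grads(fn1, fn2, x1, x2, atol=2e-5):
        out1, out2 = fn1(x1), fn2(x2)
        grad_out = torch.randn_like(out1)  # identical upstream gradient for both
        (out1 * grad_out).sum().backward()
        (out2 * grad_out).sum().backward()
        assert torch.allclose(x1.grad.cpu(), x2.grad.cpu(), atol=atol)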
@@ -223,18 +230,15 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
         torch.manual_seed(231)
         F = 32  # number of faces in the mesh
-        # The python loop version is really slow so only using small input sizes.
-        N, S, K = 2, 10, 5
+        N, S, K = 1, 4, 1
         device = torch.device("cuda")
-        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        pix_to_face = torch.randint(low=-1, high=F, size=(N, S, S, K), device=device)
         colors = torch.randn((N, S, S, K, 3), device=device)
         empty = torch.tensor([], device=device)

-        # # randomly flip the sign of the distance
-        # # (-) means inside triangle, (+) means outside triangle.
-        random_sign_flip = torch.rand((N, S, S, K))
-        random_sign_flip[random_sign_flip > 0.5] *= -1.0
-        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
-        dists2 = dists1.detach().clone()
+        dists1 = torch.randn(size=(N, S, S, K), device=device)
+        dists2 = dists1.clone()
+        dists1.requires_grad = True
+        dists2.requires_grad = True

         fragments1 = Fragments(
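The reworked dists1/dists2 setup clones before enabling gradients, so each copy is an independent autograd leaf that accumulates its own .grad. A small illustration of why the order matters:

    import torch

    a = torch.randn(3)
    b = a.clone()            # plain copy; no autograd history yet
    a.requires_grad = True   # both become independent leaf tensors
    b.requires_grad = True
    (a * 2).sum().backward()
    (b * 3).sum().backward()
    print(a.grad, b.grad)    # tensor([2., 2., 2.]) tensor([3., 3., 3.])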
@@ -256,7 +260,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):

         self._compare_impls(
             sigmoid_alpha_blend,
-            sigmoid_blend_naive_loop,
+            sigmoid_alpha_blend_vectorized,
             args1,
             args2,
             dists1,
@@ -324,26 +328,21 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
         num_meshes: int = 16,
         image_size: int = 128,
         faces_per_pixel: int = 100,
-        device: str = "cpu",
+        device="cuda",
+        backend: str = "pytorch",
     ):
-        if torch.cuda.is_available() and "cuda:" in device:
-            # If a device other than the default is used, set the device explicity.
-            torch.cuda.set_device(device)
-
         device = torch.device(device)
         torch.manual_seed(231)

         # Create dummy outputs of rasterization
         N, S, K = num_meshes, image_size, faces_per_pixel
         F = 32  # num faces in the mesh
-        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        pix_to_face = torch.randint(
+            low=-1, high=F + 1, size=(N, S, S, K), device=device
+        )
         colors = torch.randn((N, S, S, K, 3), device=device)
         empty = torch.tensor([], device=device)

-        # # randomly flip the sign of the distance
-        # # (-) means inside triangle, (+) means outside triangle.
-        random_sign_flip = torch.rand((N, S, S, K), device=device)
-        random_sign_flip[random_sign_flip > 0.5] *= -1.0
         dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
         fragments = Fragments(
             pix_to_face=pix_to_face,
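Note that torch.randint samples from the half-open interval [low, high), so the two pix_to_face forms in this hunk are not byte-for-byte equivalent: the old expression yields indices in [-1, F - 1], while the new one (high=F + 1) also admits F. For dummy benchmark inputs this is harmless; a quick range check (illustrative):

    import torch

    F = 32
    old = torch.randint(F + 1, size=(100000,)) - 1            # values in [-1, F - 1]
    new = torch.randint(low=-1, high=F + 1, size=(100000,))   # values in [-1, F]
    print(int(old.max()), int(new.max()))  # almost surely prints: 31 32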
@@ -352,11 +351,18 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
             dists=dists1,
         )
         blend_params = BlendParams(sigma=1e-3)

+        blend_fn = (
+            sigmoid_alpha_blend_vectorized
+            if backend == "pytorch"
+            else sigmoid_alpha_blend
+        )
+
+        torch.cuda.synchronize()
+
         def fn():
             # test forward and backward pass
-            images = sigmoid_alpha_blend(colors, fragments, blend_params)
+            images = blend_fn(colors, fragments, blend_params)
             images.sum().backward()
+            torch.cuda.synchronize()
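The added torch.cuda.synchronize() calls matter because CUDA kernels launch asynchronously: without a sync, a host-side timer can stop before the forward and backward kernels actually finish. A typical timing pattern under that assumption:

    import time
    import torch

    def time_fn(fn, iters=10):
        torch.cuda.synchronize()  # drain pending GPU work before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            fn()  # fn() itself synchronizes after backward, as in the benchmark above
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters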
@@ -368,6 +374,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
         image_size: int = 128,
         faces_per_pixel: int = 100,
         device: str = "cpu",
+        backend: str = "pytorch",
     ):
         if torch.cuda.is_available() and "cuda:" in device:
             # If a device other than the default is used, set the device explicity.
@@ -379,14 +386,12 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
         # Create dummy outputs of rasterization
         N, S, K = num_meshes, image_size, faces_per_pixel
         F = 32  # num faces in the mesh
-        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
+        pix_to_face = torch.randint(
+            low=-1, high=F + 1, size=(N, S, S, K), device=device
+        )
         colors = torch.randn((N, S, S, K, 3), device=device)
         empty = torch.tensor([], device=device)

-        # # randomly flip the sign of the distance
-        # # (-) means inside triangle, (+) means outside triangle.
-        random_sign_flip = torch.rand((N, S, S, K), device=device)
-        random_sign_flip[random_sign_flip > 0.5] *= -1.0
         dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
         zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
         fragments = Fragments(