Accumulate points (#4)

Summary:
Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing

Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/4

Reviewed By: nikhilaravi

Differential Revision: D20522422

Pulled By: gkioxari

fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
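All three ops share one call signature. As a minimal usage sketch (shapes and argument order taken from the tests added below; the sizes here are illustrative):

    import torch
    from pytorch3d.renderer.compositing import alpha_composite

    C, P = 3, 32                # feature channels, number of points
    N, K, H, W = 2, 4, 8, 8     # batch, points per pixel, image size
    features = torch.randn(C, P)                          # packed per-point features
    alphas = torch.rand(N, K, H, W)                       # per-point opacities
    points_idx = torch.randint(P + 1, (N, K, H, W)) - 1   # indices into features; -1 = empty slot
    images = alpha_composite(points_idx, alphas, features)  # -> (N, C, H, W)

weighted_sum and norm_weighted_sum take the same (points_idx, alphas, features) arguments.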
Committed by: Facebook GitHub Bot
Parent: 5218f45c2c
Commit: 53599770dd
tests/test_compositing.py (new file, 442 lines)
@@ -0,0 +1,442 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
import torch

from pytorch3d.renderer.compositing import (
    alpha_composite,
    norm_weighted_sum,
    weighted_sum,
)

class TestAccumulatePoints(unittest.TestCase):

    # NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)

    @staticmethod
    def accumulate_alphacomposite_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of alpha_composite.
        Inputs / Outputs: same as alpha_composite.
        """

        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)

        for b in range(0, B):
            for c in range(0, C):
                for i in range(0, W):
                    for j in range(0, H):
                        t_alpha = 1
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]

                            # A negative index marks an empty z-buffer slot.
                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += (
                                features[c, n_idx] * alpha * t_alpha
                            )
                            # Attenuate the transmittance for points behind.
                            t_alpha = (1 - alpha) * t_alpha

        return output
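
    # The loop above is front-to-back alpha compositing: per pixel,
    # output = sum_k f_k * alpha_k * prod_{m < k} (1 - alpha_m),
    # where k indexes the z-buffer entries from nearest to farthest and
    # t_alpha carries the accumulated transmittance prod (1 - alpha_m).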

    @staticmethod
    def accumulate_weightedsum_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of weighted_sum rasterization.
        Inputs / Outputs: same as weighted_sum.
        """
        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)

        for b in range(0, B):
            for c in range(0, C):
                for i in range(0, W):
                    for j in range(0, H):

                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]

                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += features[c, n_idx] * alpha

        return output
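
    # The plain weighted sum is output = sum_k f_k * alpha_k with no
    # normalisation, so accumulated values can exceed the feature range
    # (see the 1.30 entries expected in _simple_wsum below).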

    @staticmethod
    def accumulate_weightedsumnorm_python(points_idx, alphas, features):
        """
        Naive pure PyTorch implementation of norm_weighted_sum.
        Inputs / Outputs: same as norm_weighted_sum.
        """

        B, K, H, W = points_idx.size()
        C = features.size(0)

        output = torch.zeros(B, C, H, W, dtype=alphas.dtype)

        for b in range(0, B):
            for c in range(0, C):
                for i in range(0, W):
                    for j in range(0, H):
                        # First pass: total alpha for this pixel.
                        t_alpha = 0
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]

                            if n_idx < 0:
                                continue

                            t_alpha += alphas[b, k, j, i]

                        # Clamp to avoid dividing by zero at empty pixels.
                        t_alpha = max(t_alpha, 1e-4)

                        # Second pass: accumulate the normalised sum.
                        for k in range(0, K):
                            n_idx = points_idx[b, k, j, i]

                            if n_idx < 0:
                                continue

                            alpha = alphas[b, k, j, i]
                            output[b, c, j, i] += (
                                features[c, n_idx] * alpha / t_alpha
                            )

        return output
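
    # The normalised variant divides by the total alpha, i.e.
    # output = sum_k f_k * alpha_k / max(sum_k alpha_k, 1e-4),
    # with the 1e-4 clamp guarding against division by zero at pixels
    # whose z-buffer slots are all empty.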

    def test_python(self):
        device = torch.device("cpu")
        self._simple_alphacomposite(
            self.accumulate_alphacomposite_python, device
        )
        self._simple_wsum(self.accumulate_weightedsum_python, device)
        self._simple_wsumnorm(self.accumulate_weightedsumnorm_python, device)

    def test_cpu(self):
        device = torch.device("cpu")
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_cuda(self):
        device = torch.device("cuda:0")
        self._simple_alphacomposite(alpha_composite, device)
        self._simple_wsum(weighted_sum, device)
        self._simple_wsumnorm(norm_weighted_sum, device)

    def test_python_vs_cpu_vs_cuda(self):
        self._python_vs_cpu_vs_cuda(
            self.accumulate_alphacomposite_python, alpha_composite
        )
        self._python_vs_cpu_vs_cuda(
            self.accumulate_weightedsumnorm_python, norm_weighted_sum
        )
        self._python_vs_cpu_vs_cuda(
            self.accumulate_weightedsum_python, weighted_sum
        )

    def _python_vs_cpu_vs_cuda(self, accumulate_func_python, accumulate_func):
        torch.manual_seed(231)
        device = torch.device("cpu")

        W = 8
        C = 3
        P = 32

        for d in ["cpu", "cuda"]:
            # TODO(gkioxari) add torch.float64 to types after double precision
            # support is added to atomicAdd
            for t in [torch.float32]:
                device = torch.device(d)

                # Create values
                alphas = torch.rand(2, 4, W, W, dtype=t).to(device)
                alphas.requires_grad = True
                alphas_cpu = alphas.detach().cpu()
                alphas_cpu.requires_grad = True

                features = torch.randn(C, P, dtype=t).to(device)
                features.requires_grad = True
                features_cpu = features.detach().cpu()
                features_cpu.requires_grad = True

                # Indices in [-1, P - 1]; -1 marks an empty z-buffer slot.
                inds = torch.randint(P + 1, size=(2, 4, W, W)).to(device) - 1
                inds_cpu = inds.detach().cpu()

                args_cuda = (inds, alphas, features)
                args_cpu = (inds_cpu, alphas_cpu, features_cpu)

                self._compare_impls(
                    accumulate_func_python,
                    accumulate_func,
                    args_cpu,
                    args_cuda,
                    (alphas_cpu, features_cpu),
                    (alphas, features),
                    compare_grads=True,
                )

    def _compare_impls(
        self, fn1, fn2, args1, args2, grads1, grads2, compare_grads=False
    ):
        res1 = fn1(*args1)
        res2 = fn2(*args2)

        self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))

        if not compare_grads:
            return

        # Compare gradients
        torch.manual_seed(231)
        grad_res = torch.randn_like(res1)
        loss1 = (res1 * grad_res).sum()
        loss1.backward()

        grads1 = [gradsi.grad.data.clone().cpu() for gradsi in grads1]
        grad_res = grad_res.to(res2)

        loss2 = (res2 * grad_res).sum()
        loss2.backward()
        grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]

        for i in range(0, len(grads1)):
            self.assertTrue(
                torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)
            )
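
    # Backpropagating loss = (res * grad_res).sum() makes d(loss)/d(res)
    # equal to the random tensor grad_res, so the assertions above compare
    # both implementations' vector-Jacobian products for one common cotangent.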

    def _simple_wsum(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor(
            [[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
        ).to(device)

        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0,  0,  0,  0],   # noqa: E241, E201
                            [0, -1, -1, -1],   # noqa: E241, E201
                            [0,  1,  1,  0],   # noqa: E241, E201
                            [0,  0,  0,  0],   # noqa: E241, E201
                        ],
                        [
                            [2, 2,  2, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)

        self.assertTrue(result.shape == (1, 2, 4, 4))
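
        # Hand-computed expectation: at a corner pixel the z-buffer holds
        # point 0 (feature 0.1, alpha 0.5) and point 2 (feature 0.6, alpha 0.5),
        # so the weighted sum is 0.5 * 0.1 + 0.5 * 0.6 = 0.35.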
        true_result = torch.Tensor(
            [
                [
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.30],
                        [0.35, 1.30, 1.30, 0.35],
                        [0.35, 0.35, 0.05, 0.35],
                    ],
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.30],
                        [0.35, 1.30, 1.30, 0.35],
                        [0.35, 0.35, 0.05, 0.35],
                    ],
                ]
            ]
        ).to(device)

        self.assertTrue(
            torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
        )

    def _simple_wsumnorm(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor(
            [[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
        ).to(device)

        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0,  0,  0,  0],   # noqa: E241, E201
                            [0, -1, -1, -1],   # noqa: E241, E201
                            [0,  1,  1,  0],   # noqa: E241, E201
                            [0,  0,  0,  0],   # noqa: E241, E201
                        ],
                        [
                            [2, 2,  2, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)

        self.assertTrue(result.shape == (1, 2, 4, 4))

        true_result = torch.Tensor(
            [
                [
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.60],
                        [0.35, 0.65, 0.65, 0.35],
                        [0.35, 0.35, 0.10, 0.35],
                    ],
                    [
                        [0.35, 0.35, 0.35, 0.35],
                        [0.35, 0.90, 0.90, 0.60],
                        [0.35, 0.65, 0.65, 0.35],
                        [0.35, 0.35, 0.10, 0.35],
                    ],
                ]
            ]
        ).to(device)

        self.assertTrue(
            torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
        )

    def _simple_alphacomposite(self, accum_func, device):
        # Initialise variables
        features = torch.Tensor(
            [[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
        ).to(device)

        alphas = torch.Tensor(
            [
                [
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                    [
                        [0.5, 0.5, 0.5, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 1.0, 1.0, 0.5],
                        [0.5, 0.5, 0.5, 0.5],
                    ],
                ]
            ]
        ).to(device)

        # fmt: off
        points_idx = (
            torch.Tensor(
                [
                    [
                        [
                            [0,  0,  0,  0],   # noqa: E241, E201
                            [0, -1, -1, -1],   # noqa: E241, E201
                            [0,  1,  1,  0],   # noqa: E241, E201
                            [0,  0,  0,  0],   # noqa: E241, E201
                        ],
                        [
                            [2, 2,  2, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 3,  3, 2],  # noqa: E241, E201
                            [2, 2, -1, 2],  # noqa: E241, E201
                        ],
                    ]
                ]
            )
            .long()
            .to(device)
        )
        # fmt: on

        result = accum_func(points_idx, alphas, features)

        self.assertTrue(result.shape == (1, 2, 4, 4))
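
        # Hand-computed expectation: at a corner pixel, point 0 contributes
        # 0.1 * 0.5 * 1.0 = 0.05 and point 2, behind it, contributes
        # 0.6 * 0.5 * (1 - 0.5) = 0.15, for a composited value of 0.20.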
        true_result = torch.Tensor(
            [
                [
                    [
                        [0.20, 0.20, 0.20, 0.20],
                        [0.20, 0.90, 0.90, 0.30],
                        [0.20, 0.40, 0.40, 0.20],
                        [0.20, 0.20, 0.05, 0.20],
                    ],
                    [
                        [0.20, 0.20, 0.20, 0.20],
                        [0.20, 0.90, 0.90, 0.30],
                        [0.20, 0.40, 0.40, 0.20],
                        [0.20, 0.20, 0.05, 0.20],
                    ],
                ]
            ]
        ).to(device)

        self.assertTrue((result == true_result).all().item())
@@ -76,7 +76,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
         verts_torch = verts.detach().clone().to(dtype)
         verts_torch.requires_grad = True
         faces_torch = faces.detach().clone()
-        areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python(
+        (
+            areas_torch,
+            normals_torch,
+        ) = TestFaceAreasNormals.face_areas_normals_python(
             verts_torch, faces_torch
         )
         self.assertClose(areas_torch, areas, atol=1e-7)
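The new suite can be run from the repository root with the standard unittest runner, for example (exact invocation depends on the local setup):

    python -m unittest tests.test_compositing -v

test_cuda requires a CUDA build of the extension; the remaining tests run on CPU.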