mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
barycentric clipping in cuda/c++
Summary: Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting. Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster. Reviewed By: gkioxari Differential Revision: D21705503 fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bce396df93
commit
cc70950f40
@@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
|
||||
rasterize_meshes,
|
||||
rasterize_meshes_python,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.utils import (
|
||||
_clip_barycentric_coordinates,
|
||||
_interpolate_zbuf,
|
||||
)
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.utils import ico_sphere
|
||||
|
||||
@@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
|
||||
|
||||
def test_simple_cpu_naive(self):
|
||||
@@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
verts2.requires_grad = True
|
||||
meshes_cuda = Meshes(verts=[verts2], faces=[faces2])
|
||||
|
||||
args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel)
|
||||
args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0)
|
||||
barycentric_clip = True
|
||||
args_cpu = (
|
||||
meshes_cpu,
|
||||
image_size,
|
||||
radius,
|
||||
faces_per_pixel,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
barycentric_clip,
|
||||
False,
|
||||
)
|
||||
args_cuda = (
|
||||
meshes_cuda,
|
||||
image_size,
|
||||
radius,
|
||||
faces_per_pixel,
|
||||
0,
|
||||
0,
|
||||
False,
|
||||
barycentric_clip,
|
||||
False,
|
||||
)
|
||||
self._compare_impls(
|
||||
rasterize_meshes,
|
||||
rasterize_meshes,
|
||||
@@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
idxs_cuda[:K] = sorted(idxs_cuda[:K])
|
||||
self.assertEqual(idxs_cpu, idxs_cuda)
|
||||
|
||||
def test_python_vs_cpp_bary_clip(self):
|
||||
torch.manual_seed(232)
|
||||
N = 2
|
||||
V = 10
|
||||
F = 5
|
||||
verts1 = torch.randn(N, V, 3, requires_grad=True)
|
||||
verts2 = verts1.detach().clone().requires_grad_(True)
|
||||
faces = torch.randint(V, size=(N, F, 3))
|
||||
meshes1 = Meshes(verts1, faces)
|
||||
meshes2 = Meshes(verts2, faces)
|
||||
|
||||
kwargs = {"image_size": 24, "clip_barycentric_coords": True}
|
||||
fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
|
||||
fn2 = functools.partial(rasterize_meshes_python, meshes2, **kwargs)
|
||||
args = ()
|
||||
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
|
||||
|
||||
def test_cpp_vs_cuda_bary_clip(self):
|
||||
meshes = ico_sphere(2, device=torch.device("cpu"))
|
||||
verts1, faces1 = meshes.get_mesh_verts_faces(0)
|
||||
verts1.requires_grad = True
|
||||
meshes1 = Meshes(verts=[verts1], faces=[faces1])
|
||||
device = get_random_cuda_device()
|
||||
verts2 = verts1.detach().to(device).requires_grad_(True)
|
||||
faces2 = faces1.detach().clone().to(device)
|
||||
meshes2 = Meshes(verts=[verts2], faces=[faces2])
|
||||
|
||||
kwargs = {"image_size": 64, "clip_barycentric_coords": True}
|
||||
fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
|
||||
fn2 = functools.partial(rasterize_meshes, meshes2, bin_size=0, **kwargs)
|
||||
args = ()
|
||||
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
|
||||
|
||||
def test_python_vs_cpp_perspective_correct(self):
|
||||
torch.manual_seed(232)
|
||||
N = 2
|
||||
@@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertLess(zbuf_f_bary_diff, 1e-4)
|
||||
self.assertLess(zbuf_t_bary_diff, 1e-4)
|
||||
|
||||
def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None):
|
||||
# fmt: off
|
||||
verts = torch.tensor([
|
||||
[-0.4, -0.4, 10], # noqa: E241, E201
|
||||
[ 0.4, -0.4, 10], # noqa: E241, E201
|
||||
[ 0.0, 0.4, 20], # noqa: E241, E201
|
||||
], dtype=torch.float32, device=device)
|
||||
# fmt: on
|
||||
faces = torch.tensor([[0, 1, 2]], device=device)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
kwargs = {
|
||||
"meshes": meshes,
|
||||
"image_size": 5,
|
||||
"faces_per_pixel": 1,
|
||||
"blur_radius": 0.2,
|
||||
"perspective_correct": False,
|
||||
"clip_barycentric_coords": False, # Initially set this to false
|
||||
}
|
||||
if bin_size != -1:
|
||||
kwargs["bin_size"] = bin_size
|
||||
|
||||
# Run with and without perspective correction
|
||||
idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
|
||||
|
||||
# fmt: off
|
||||
expected_bary = torch.tensor([
|
||||
[
|
||||
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
||||
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
||||
[-0.2500, -0.2500, 1.5000], # noqa: E241, E201
|
||||
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
||||
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
||||
[-0.5000, 0.5000, 1.0000], # noqa: E241, E201
|
||||
[-0.0000, -0.0000, 1.0000], # noqa: E241, E201
|
||||
[ 0.5000, -0.5000, 1.0000], # noqa: E241, E201
|
||||
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
||||
[-0.2500, 0.7500, 0.5000], # noqa: E241, E201
|
||||
[ 0.2500, 0.2500, 0.5000], # noqa: E241, E201
|
||||
[ 0.7500, -0.2500, 0.5000], # noqa: E241, E201
|
||||
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[-0.5000, 1.5000, -0.0000], # noqa: E241, E201
|
||||
[-0.0000, 1.0000, -0.0000], # noqa: E241, E201
|
||||
[ 0.5000, 0.5000, -0.0000], # noqa: E241, E201
|
||||
[ 1.0000, -0.0000, -0.0000], # noqa: E241, E201
|
||||
[ 1.5000, -0.5000, 0.0000] # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
||||
[ 0.2500, 1.2500, -0.5000], # noqa: E241, E201
|
||||
[ 0.7500, 0.7500, -0.5000], # noqa: E241, E201
|
||||
[ 1.2500, 0.2500, -0.5000], # noqa: E241, E201
|
||||
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
|
||||
]
|
||||
], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3)
|
||||
# fmt: on
|
||||
|
||||
self.assertClose(expected_bary, bary_f, atol=1e-4)
|
||||
|
||||
# calculate the expected clipped barycentrics and zbuf
|
||||
expected_bary_clipped = _clip_barycentric_coordinates(expected_bary)
|
||||
expected_z_clipped = _interpolate_zbuf(idx_f, expected_bary_clipped, meshes)
|
||||
|
||||
kwargs["clip_barycentric_coords"] = True
|
||||
idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
|
||||
|
||||
self.assertClose(expected_bary_clipped, bary_t, atol=1e-4)
|
||||
self.assertClose(expected_z_clipped, zbuf_t, atol=1e-4)
|
||||
|
||||
def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None):
|
||||
"""
|
||||
All verts are behind the camera so nothing should get rasterized.
|
||||
|
||||
Reference in New Issue
Block a user