barycentric clipping in cuda/c++

Summary:
Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting.

Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster.

Reviewed By: gkioxari

Differential Revision: D21705503

fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
Nikhila Ravi
2020-07-16 10:15:30 -07:00
committed by Facebook GitHub Bot
parent bce396df93
commit cc70950f40
13 changed files with 611 additions and 55 deletions

View File

@@ -0,0 +1,112 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import (
Fragments,
MeshRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.mesh.utils import (
_clip_barycentric_coordinates,
_interpolate_zbuf,
)
from pytorch3d.utils.ico_sphere import ico_sphere
def baryclip_cuda(
num_meshes: int = 8,
ico_level: int = 5,
image_size: int = 64,
faces_per_pixel: int = 50,
device="cuda",
):
# Init meshes
sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
# Init transform
R, T = look_at_view_transform(1.0, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
# Init rasterizer
raster_settings = RasterizationSettings(
image_size=image_size,
blur_radius=1e-4,
faces_per_pixel=faces_per_pixel,
clip_barycentric_coords=True,
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
torch.cuda.synchronize()
def raster_fn():
rasterizer(sphere_meshes)
torch.cuda.synchronize()
return raster_fn
def baryclip_pytorch(
num_meshes: int = 8,
ico_level: int = 5,
image_size: int = 64,
faces_per_pixel: int = 50,
device="cuda",
):
# Init meshes
sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes)
# Init transform
R, T = look_at_view_transform(1.0, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
# Init rasterizer
raster_settings = RasterizationSettings(
image_size=image_size,
blur_radius=1e-4,
faces_per_pixel=faces_per_pixel,
clip_barycentric_coords=False,
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
torch.cuda.synchronize()
def raster_fn():
fragments = rasterizer(sphere_meshes)
# Clip bary and reinterpolate
clipped_bary_coords = _clip_barycentric_coordinates(fragments.bary_coords)
clipped_zbuf = _interpolate_zbuf(
fragments.pix_to_face, clipped_bary_coords, sphere_meshes
)
fragments = Fragments(
bary_coords=clipped_bary_coords,
zbuf=clipped_zbuf,
dists=fragments.dists,
pix_to_face=fragments.pix_to_face,
)
torch.cuda.synchronize()
return raster_fn
def bm_barycentric_clip() -> None:
if torch.cuda.is_available():
kwargs_list = []
num_meshes = [1, 8]
ico_level = [0, 4]
image_size = [64, 128, 256]
faces_per_pixel = [10, 75, 100]
test_cases = product(num_meshes, ico_level, image_size, faces_per_pixel)
for case in test_cases:
n, ic, im, nf = case
kwargs_list.append(
{
"num_meshes": n,
"ico_level": ic,
"image_size": im,
"faces_per_pixel": nf,
}
)
benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1)
benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 43 KiB

After

Width:  |  Height:  |  Size: 43 KiB

View File

@@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes,
rasterize_meshes_python,
)
from pytorch3d.renderer.mesh.utils import (
_clip_barycentric_coordinates,
_interpolate_zbuf,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
@@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1)
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
def test_simple_cpu_naive(self):
@@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
verts2.requires_grad = True
meshes_cuda = Meshes(verts=[verts2], faces=[faces2])
args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel)
args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0)
barycentric_clip = True
args_cpu = (
meshes_cpu,
image_size,
radius,
faces_per_pixel,
None,
None,
False,
barycentric_clip,
False,
)
args_cuda = (
meshes_cuda,
image_size,
radius,
faces_per_pixel,
0,
0,
False,
barycentric_clip,
False,
)
self._compare_impls(
rasterize_meshes,
rasterize_meshes,
@@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
idxs_cuda[:K] = sorted(idxs_cuda[:K])
self.assertEqual(idxs_cpu, idxs_cuda)
def test_python_vs_cpp_bary_clip(self):
torch.manual_seed(232)
N = 2
V = 10
F = 5
verts1 = torch.randn(N, V, 3, requires_grad=True)
verts2 = verts1.detach().clone().requires_grad_(True)
faces = torch.randint(V, size=(N, F, 3))
meshes1 = Meshes(verts1, faces)
meshes2 = Meshes(verts2, faces)
kwargs = {"image_size": 24, "clip_barycentric_coords": True}
fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
fn2 = functools.partial(rasterize_meshes_python, meshes2, **kwargs)
args = ()
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
def test_cpp_vs_cuda_bary_clip(self):
meshes = ico_sphere(2, device=torch.device("cpu"))
verts1, faces1 = meshes.get_mesh_verts_faces(0)
verts1.requires_grad = True
meshes1 = Meshes(verts=[verts1], faces=[faces1])
device = get_random_cuda_device()
verts2 = verts1.detach().to(device).requires_grad_(True)
faces2 = faces1.detach().clone().to(device)
meshes2 = Meshes(verts=[verts2], faces=[faces2])
kwargs = {"image_size": 64, "clip_barycentric_coords": True}
fn1 = functools.partial(rasterize_meshes, meshes1, **kwargs)
fn2 = functools.partial(rasterize_meshes, meshes2, bin_size=0, **kwargs)
args = ()
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
def test_python_vs_cpp_perspective_correct(self):
torch.manual_seed(232)
N = 2
@@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self.assertLess(zbuf_f_bary_diff, 1e-4)
self.assertLess(zbuf_t_bary_diff, 1e-4)
def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None):
# fmt: off
verts = torch.tensor([
[-0.4, -0.4, 10], # noqa: E241, E201
[ 0.4, -0.4, 10], # noqa: E241, E201
[ 0.0, 0.4, 20], # noqa: E241, E201
], dtype=torch.float32, device=device)
# fmt: on
faces = torch.tensor([[0, 1, 2]], device=device)
meshes = Meshes(verts=[verts], faces=[faces])
kwargs = {
"meshes": meshes,
"image_size": 5,
"faces_per_pixel": 1,
"blur_radius": 0.2,
"perspective_correct": False,
"clip_barycentric_coords": False, # Initially set this to false
}
if bin_size != -1:
kwargs["bin_size"] = bin_size
# Run with and without perspective correction
idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
# fmt: off
expected_bary = torch.tensor([
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-0.2500, -0.2500, 1.5000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
],
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-0.5000, 0.5000, 1.0000], # noqa: E241, E201
[-0.0000, -0.0000, 1.0000], # noqa: E241, E201
[ 0.5000, -0.5000, 1.0000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
],
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[-0.2500, 0.7500, 0.5000], # noqa: E241, E201
[ 0.2500, 0.2500, 0.5000], # noqa: E241, E201
[ 0.7500, -0.2500, 0.5000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
],
[
[-0.5000, 1.5000, -0.0000], # noqa: E241, E201
[-0.0000, 1.0000, -0.0000], # noqa: E241, E201
[ 0.5000, 0.5000, -0.0000], # noqa: E241, E201
[ 1.0000, -0.0000, -0.0000], # noqa: E241, E201
[ 1.5000, -0.5000, 0.0000] # noqa: E241, E201
],
[
[-1.0000, -1.0000, -1.0000], # noqa: E241, E201
[ 0.2500, 1.2500, -0.5000], # noqa: E241, E201
[ 0.7500, 0.7500, -0.5000], # noqa: E241, E201
[ 1.2500, 0.2500, -0.5000], # noqa: E241, E201
[-1.0000, -1.0000, -1.0000] # noqa: E241, E201
]
], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3)
# fmt: on
self.assertClose(expected_bary, bary_f, atol=1e-4)
# calculate the expected clipped barycentrics and zbuf
expected_bary_clipped = _clip_barycentric_coordinates(expected_bary)
expected_z_clipped = _interpolate_zbuf(idx_f, expected_bary_clipped, meshes)
kwargs["clip_barycentric_coords"] = True
idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
self.assertClose(expected_bary_clipped, bary_t, atol=1e-4)
self.assertClose(expected_z_clipped, zbuf_t, atol=1e-4)
def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None):
"""
All verts are behind the camera so nothing should get rasterized.

View File

@@ -212,6 +212,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=80,
clip_barycentric_coords=True,
)
# Init rasterizer settings
@@ -269,11 +270,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
# the cow is facing the -z direction.
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
# Init renderer
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=TexturedSoftPhongShader(
lights=lights, cameras=cameras, materials=materials
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
),
)
@@ -346,6 +355,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=100,
clip_barycentric_coords=True,
)
# Load reference image