mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
barycentric clipping in cuda/c++
Summary: Added support for barycentric clipping in the C++/CUDA rasterization kernels which can be switched on/off via a rasterization setting. Added tests and a benchmark to compare with the current implementation in PyTorch - for some cases of large image size/faces per pixel the cuda version is 10x faster. Reviewed By: gkioxari Differential Revision: D21705503 fbshipit-source-id: e835c0f927f1e5088ca89020aef5ff27ac3a8769
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bce396df93
commit
cc70950f40
112
tests/bm_barycentric_clipping.py
Normal file
112
tests/bm_barycentric_clipping.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.renderer.mesh.rasterizer import (
|
||||
Fragments,
|
||||
MeshRasterizer,
|
||||
RasterizationSettings,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.utils import (
|
||||
_clip_barycentric_coordinates,
|
||||
_interpolate_zbuf,
|
||||
)
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
|
||||
def baryclip_cuda(
    num_meshes: int = 8,
    ico_level: int = 5,
    image_size: int = 64,
    faces_per_pixel: int = 50,
    device="cuda",
):
    """
    Prepare a benchmark closure that rasterizes with
    `clip_barycentric_coords=True`, i.e. clipping is requested from the
    rasterizer itself rather than applied afterwards in Python.

    Args:
        num_meshes: number of copies of the ico-sphere in the batch.
        ico_level: subdivision level passed to `ico_sphere`.
        image_size: value forwarded to `RasterizationSettings.image_size`.
        faces_per_pixel: value forwarded to
            `RasterizationSettings.faces_per_pixel`.
        device: device on which the meshes and cameras are created.

    Returns:
        A zero-argument callable which runs one rasterization pass and then
        synchronizes CUDA, so that the timed region covers the kernel work.
    """
    # Scene: one ico-sphere replicated `num_meshes` times.
    meshes = ico_sphere(ico_level, device).extend(num_meshes)

    # Camera built from look_at_view_transform(1.0, 0.0, 0.0).
    rotation, translation = look_at_view_transform(1.0, 0.0, 0.0)
    cameras = OpenGLPerspectiveCameras(device=device, R=rotation, T=translation)

    # Rasterizer with barycentric clipping switched on in the settings.
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-4,
        faces_per_pixel=faces_per_pixel,
        clip_barycentric_coords=True,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)

    # Flush pending setup work so it is not attributed to the first timed call.
    torch.cuda.synchronize()

    def raster_fn():
        rasterizer(meshes)
        torch.cuda.synchronize()

    return raster_fn
|
||||
|
||||
|
||||
def baryclip_pytorch(
    num_meshes: int = 8,
    ico_level: int = 5,
    image_size: int = 64,
    faces_per_pixel: int = 50,
    device="cuda",
):
    """
    Prepare a benchmark closure that rasterizes with
    `clip_barycentric_coords=False` and then performs the clipping and the
    z-buffer re-interpolation afterwards with the Python helpers
    `_clip_barycentric_coordinates` and `_interpolate_zbuf`.

    Args:
        num_meshes: number of copies of the ico-sphere in the batch.
        ico_level: subdivision level passed to `ico_sphere`.
        image_size: value forwarded to `RasterizationSettings.image_size`.
        faces_per_pixel: value forwarded to
            `RasterizationSettings.faces_per_pixel`.
        device: device on which the meshes and cameras are created.

    Returns:
        A zero-argument callable which rasterizes, clips, re-interpolates,
        rebuilds the `Fragments` and finally synchronizes CUDA.
    """
    # Scene: one ico-sphere replicated `num_meshes` times.
    meshes = ico_sphere(ico_level, device).extend(num_meshes)

    # Camera built from look_at_view_transform(1.0, 0.0, 0.0).
    rotation, translation = look_at_view_transform(1.0, 0.0, 0.0)
    cameras = OpenGLPerspectiveCameras(device=device, R=rotation, T=translation)

    # Rasterizer with barycentric clipping switched off in the settings;
    # the clipping is done by hand inside the timed closure instead.
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=1e-4,
        faces_per_pixel=faces_per_pixel,
        clip_barycentric_coords=False,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=settings)

    # Flush pending setup work so it is not attributed to the first timed call.
    torch.cuda.synchronize()

    def raster_fn():
        frags = rasterizer(meshes)

        # Clip the barycentrics, then recompute depth from the clipped values.
        bary = _clip_barycentric_coordinates(frags.bary_coords)
        zbuf = _interpolate_zbuf(frags.pix_to_face, bary, meshes)

        # The result is unused, but building it is part of the measured work.
        frags = Fragments(
            bary_coords=bary,
            zbuf=zbuf,
            dists=frags.dists,
            pix_to_face=frags.pix_to_face,
        )
        torch.cuda.synchronize()

    return raster_fn
|
||||
|
||||
|
||||
def bm_barycentric_clip() -> None:
    """
    Benchmark barycentric clipping done through the rasterizer setting
    (`baryclip_cuda`) against clipping done afterwards in Python
    (`baryclip_pytorch`) over the same grid of configurations.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    # One kwargs dict per point of the cartesian product of the settings below;
    # the name order matches the order of the value lists given to `product`.
    names = ("num_meshes", "ico_level", "image_size", "faces_per_pixel")
    grid = product([1, 8], [0, 4], [64, 128, 256], [10, 75, 100])
    kwargs_list = [dict(zip(names, values)) for values in grid]

    benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1)
    benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1)
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 43 KiB |
@@ -10,6 +10,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
|
||||
rasterize_meshes,
|
||||
rasterize_meshes_python,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.utils import (
|
||||
_clip_barycentric_coordinates,
|
||||
_interpolate_zbuf,
|
||||
)
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.utils import ico_sphere
|
||||
|
||||
@@ -21,6 +25,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
|
||||
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
|
||||
|
||||
def test_simple_cpu_naive(self):
|
||||
@@ -170,8 +175,29 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
verts2.requires_grad = True
|
||||
meshes_cuda = Meshes(verts=[verts2], faces=[faces2])
|
||||
|
||||
args_cpu = (meshes_cpu, image_size, radius, faces_per_pixel)
|
||||
args_cuda = (meshes_cuda, image_size, radius, faces_per_pixel, 0, 0)
|
||||
barycentric_clip = True
|
||||
args_cpu = (
|
||||
meshes_cpu,
|
||||
image_size,
|
||||
radius,
|
||||
faces_per_pixel,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
barycentric_clip,
|
||||
False,
|
||||
)
|
||||
args_cuda = (
|
||||
meshes_cuda,
|
||||
image_size,
|
||||
radius,
|
||||
faces_per_pixel,
|
||||
0,
|
||||
0,
|
||||
False,
|
||||
barycentric_clip,
|
||||
False,
|
||||
)
|
||||
self._compare_impls(
|
||||
rasterize_meshes,
|
||||
rasterize_meshes,
|
||||
@@ -333,6 +359,39 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
idxs_cuda[:K] = sorted(idxs_cuda[:K])
|
||||
self.assertEqual(idxs_cpu, idxs_cuda)
|
||||
|
||||
def test_python_vs_cpp_bary_clip(self):
    """
    Compare `rasterize_meshes` with `rasterize_meshes_python` on random CPU
    meshes with barycentric clipping enabled, including gradients w.r.t. the
    vertices.
    """
    torch.manual_seed(232)
    num_meshes, num_verts, num_faces = 2, 10, 5

    # Two independent leaf tensors holding the same values, so that each
    # implementation accumulates its own gradient.
    verts_a = torch.randn(num_meshes, num_verts, 3, requires_grad=True)
    verts_b = verts_a.detach().clone().requires_grad_(True)
    faces = torch.randint(num_verts, size=(num_meshes, num_faces, 3))
    meshes_a = Meshes(verts_a, faces)
    meshes_b = Meshes(verts_b, faces)

    raster_kwargs = {"image_size": 24, "clip_barycentric_coords": True}
    fn_a = functools.partial(rasterize_meshes, meshes_a, **raster_kwargs)
    fn_b = functools.partial(rasterize_meshes_python, meshes_b, **raster_kwargs)

    # Everything is already bound by the partials, so no positional args.
    no_args = ()
    self._compare_impls(
        fn_a, fn_b, no_args, no_args, verts_a, verts_b, compare_grads=True
    )
|
||||
|
||||
def test_cpp_vs_cuda_bary_clip(self):
    """
    Compare `rasterize_meshes` on a CPU ico-sphere with the same call on a
    CUDA copy (bin_size=0) with barycentric clipping enabled, including
    gradients w.r.t. the vertices.
    """
    sphere = ico_sphere(2, device=torch.device("cpu"))

    # CPU mesh whose vertices track gradients.
    verts_cpu, faces_cpu = sphere.get_mesh_verts_faces(0)
    verts_cpu.requires_grad = True
    meshes_cpu = Meshes(verts=[verts_cpu], faces=[faces_cpu])

    # CUDA copy with its own leaf tensor for the vertices.
    cuda_device = get_random_cuda_device()
    verts_cuda = verts_cpu.detach().to(cuda_device).requires_grad_(True)
    faces_cuda = faces_cpu.detach().clone().to(cuda_device)
    meshes_cuda = Meshes(verts=[verts_cuda], faces=[faces_cuda])

    raster_kwargs = {"image_size": 64, "clip_barycentric_coords": True}
    fn_cpu = functools.partial(rasterize_meshes, meshes_cpu, **raster_kwargs)
    fn_cuda = functools.partial(
        rasterize_meshes, meshes_cuda, bin_size=0, **raster_kwargs
    )

    # Everything is already bound by the partials, so no positional args.
    no_args = ()
    self._compare_impls(
        fn_cpu, fn_cuda, no_args, no_args, verts_cpu, verts_cuda, compare_grads=True
    )
|
||||
|
||||
def test_python_vs_cpp_perspective_correct(self):
|
||||
torch.manual_seed(232)
|
||||
N = 2
|
||||
@@ -621,6 +680,82 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertLess(zbuf_f_bary_diff, 1e-4)
|
||||
self.assertLess(zbuf_t_bary_diff, 1e-4)
|
||||
|
||||
def _test_barycentric_clipping(self, rasterize_meshes_fn, device, bin_size=None):
    """
    Rasterize a single triangle onto a 5x5 image twice, first with
    `clip_barycentric_coords=False` and then with it set to True, and check
    the barycentric coordinates and z-buffer of the second run against
    values derived from the first with the Python helpers.

    Args:
        rasterize_meshes_fn: rasterization function under test; called with
            keyword arguments only and expected to return a 4-tuple.
        device: device on which the test tensors are created.
        bin_size: forwarded as `bin_size` unless it equals -1, in which case
            the keyword is omitted from the call altogether.
    """
    tri_verts = torch.tensor(
        [[-0.4, -0.4, 10.0], [0.4, -0.4, 10.0], [0.0, 0.4, 20.0]],
        dtype=torch.float32,
        device=device,
    )
    tri_faces = torch.tensor([[0, 1, 2]], device=device)
    meshes = Meshes(verts=[tri_verts], faces=[tri_faces])

    raster_kwargs = {
        "meshes": meshes,
        "image_size": 5,
        "faces_per_pixel": 1,
        "blur_radius": 0.2,
        "perspective_correct": False,
        "clip_barycentric_coords": False,  # first pass: no clipping
    }
    if bin_size != -1:
        raster_kwargs["bin_size"] = bin_size

    # First pass, without barycentric clipping.
    pix_to_face, _, bary_unclipped, _ = rasterize_meshes_fn(**raster_kwargs)

    # Pixels expected to hold -1 in every barycentric component
    # (presumably those the face does not reach - confirm against rasterizer).
    bg = [-1.0, -1.0, -1.0]
    # fmt: off
    expected_bary = torch.tensor([
        [bg, bg, [-0.25, -0.25, 1.5], bg, bg],
        [bg, [-0.5, 0.5, 1.0], [-0.0, -0.0, 1.0], [0.5, -0.5, 1.0], bg],
        [bg, [-0.25, 0.75, 0.5], [0.25, 0.25, 0.5], [0.75, -0.25, 0.5], bg],
        [
            [-0.5, 1.5, -0.0],
            [-0.0, 1.0, -0.0],
            [0.5, 0.5, -0.0],
            [1.0, -0.0, -0.0],
            [1.5, -0.5, 0.0],
        ],
        [bg, [0.25, 1.25, -0.5], [0.75, 0.75, -0.5], [1.25, 0.25, -0.5], bg],
    ], dtype=torch.float32, device=device).view(1, 5, 5, 1, 3)
    # fmt: on

    self.assertClose(expected_bary, bary_unclipped, atol=1e-4)

    # Reference values for the clipped run, computed with the Python helpers
    # from the unclipped expectation and the first pass's pixel-to-face map.
    expected_bary_clipped = _clip_barycentric_coordinates(expected_bary)
    expected_z_clipped = _interpolate_zbuf(
        pix_to_face, expected_bary_clipped, meshes
    )

    # Second pass, with barycentric clipping done by the function under test.
    raster_kwargs["clip_barycentric_coords"] = True
    _, zbuf_clipped, bary_clipped, _ = rasterize_meshes_fn(**raster_kwargs)

    self.assertClose(expected_bary_clipped, bary_clipped, atol=1e-4)
    self.assertClose(expected_z_clipped, zbuf_clipped, atol=1e-4)
|
||||
|
||||
def _test_behind_camera(self, rasterize_meshes_fn, device, bin_size=None):
|
||||
"""
|
||||
All verts are behind the camera so nothing should get rasterized.
|
||||
|
||||
@@ -212,6 +212,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
image_size=512,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||
faces_per_pixel=80,
|
||||
clip_barycentric_coords=True,
|
||||
)
|
||||
|
||||
# Init rasterizer settings
|
||||
@@ -269,11 +270,19 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
# the cow is facing the -z direction.
|
||||
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
|
||||
|
||||
blend_params = BlendParams(
|
||||
sigma=1e-1,
|
||||
gamma=1e-4,
|
||||
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
|
||||
)
|
||||
# Init renderer
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||
shader=TexturedSoftPhongShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -346,6 +355,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
image_size=512,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||
faces_per_pixel=100,
|
||||
clip_barycentric_coords=True,
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
|
||||
Reference in New Issue
Block a user