mesh rasterizer settings fix

Summary:
Fix default setting of `max_faces_per_bin` and update mesh rasterization benchmark tests.
The previous setting of `max_faces_per_bin` was wrong and for larger mesh sizes and batch sizes it was causing a significant slow down due to an unecessarily large intermediate tensor being created.

Reviewed By: gkioxari

Differential Revision: D22301819

fbshipit-source-id: d5e817f5b917fb5633c9c6a8634b6c8ff65e3508
This commit is contained in:
Nikhila Ravi 2020-06-30 12:42:42 -07:00 committed by Facebook GitHub Bot
parent 88f579389f
commit dd4a35cf9f
3 changed files with 49 additions and 30 deletions

View File

@ -130,7 +130,7 @@ def rasterize_meshes(
)
if max_faces_per_bin is None:
max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5))
max_faces_per_bin = int(max(10000, meshes._F / 5))
# pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`.
return _RasterizeFaceVerts.apply(

View File

@ -13,6 +13,8 @@ from test_rasterize_meshes import TestRasterizeMeshes
# 1: (42 verts, 80 faces)
# 3: (642 verts, 1280 faces)
# 4: (2562 verts, 5120 faces)
# 5: (10242 verts, 20480 faces)
# 6: (40962 verts, 81920 faces)
def bm_rasterize_meshes() -> None:
@ -22,6 +24,7 @@ def bm_rasterize_meshes() -> None:
"ico_level": 0,
"image_size": 10, # very slow with large image size
"blur_radius": 0.0,
"faces_per_pixel": 3,
}
]
benchmark(
@ -35,12 +38,19 @@ def bm_rasterize_meshes() -> None:
num_meshes = [1]
ico_level = [1]
image_size = [64, 128]
blur = [0.0, 1e-8, 1e-4]
test_cases = product(num_meshes, ico_level, image_size, blur)
blur = [1e-6]
faces_per_pixel = [3, 50]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
for case in test_cases:
n, ic, im, b = case
n, ic, im, b, f = case
kwargs_list.append(
{"num_meshes": n, "ico_level": ic, "image_size": im, "blur_radius": b}
{
"num_meshes": n,
"ico_level": ic,
"image_size": im,
"blur_radius": b,
"faces_per_pixel": f,
}
)
benchmark(
TestRasterizeMeshes.rasterize_meshes_cpu_with_init,
@ -51,26 +61,22 @@ def bm_rasterize_meshes() -> None:
if torch.cuda.is_available():
kwargs_list = []
num_meshes = [1, 8]
ico_level = [0, 1, 3, 4]
num_meshes = [8, 16]
ico_level = [4, 5, 6]
image_size = [64, 128, 512]
blur = [0.0, 1e-8, 1e-4]
bin_size = [0, 8, 32]
test_cases = product(num_meshes, ico_level, image_size, blur, bin_size)
# only keep cases where bin_size == 0 or image_size / bin_size < 16
test_cases = [
elem for elem in test_cases if (elem[-1] == 0 or elem[-3] / elem[-1] < 16)
]
blur = [1e-6]
faces_per_pixel = [50]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
for case in test_cases:
n, ic, im, b, bn = case
n, ic, im, b, f = case
kwargs_list.append(
{
"num_meshes": n,
"ico_level": ic,
"image_size": im,
"blur_radius": b,
"bin_size": bn,
"max_faces_per_bin": 200,
"faces_per_pixel": f,
}
)
benchmark(

View File

@ -1004,26 +1004,42 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
@staticmethod
def rasterize_meshes_python_with_init(
num_meshes: int, ico_level: int, image_size: int, blur_radius: float
num_meshes: int,
ico_level: int,
image_size: int,
blur_radius: float,
faces_per_pixel: int,
):
device = torch.device("cpu")
meshes = ico_sphere(ico_level, device)
meshes_batch = meshes.extend(num_meshes)
def rasterize():
rasterize_meshes_python(meshes_batch, image_size, blur_radius)
rasterize_meshes_python(
meshes_batch, image_size, blur_radius, faces_per_pixel
)
return rasterize
@staticmethod
def rasterize_meshes_cpu_with_init(
num_meshes: int, ico_level: int, image_size: int, blur_radius: float
num_meshes: int,
ico_level: int,
image_size: int,
blur_radius: float,
faces_per_pixel: int,
):
meshes = ico_sphere(ico_level, torch.device("cpu"))
meshes_batch = meshes.extend(num_meshes)
def rasterize():
rasterize_meshes(meshes_batch, image_size, blur_radius, bin_size=0)
rasterize_meshes(
meshes_batch,
image_size,
blur_radius,
faces_per_pixel=faces_per_pixel,
bin_size=0,
)
return rasterize
@ -1033,18 +1049,15 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
ico_level: int,
image_size: int,
blur_radius: float,
bin_size: int,
max_faces_per_bin: int,
faces_per_pixel: int,
):
meshes = ico_sphere(ico_level, get_random_cuda_device())
device = get_random_cuda_device()
meshes = ico_sphere(ico_level, device)
meshes_batch = meshes.extend(num_meshes)
torch.cuda.synchronize()
torch.cuda.synchronize(device)
def rasterize():
rasterize_meshes(
meshes_batch, image_size, blur_radius, 8, bin_size, max_faces_per_bin
)
torch.cuda.synchronize()
rasterize_meshes(meshes_batch, image_size, blur_radius, faces_per_pixel)
torch.cuda.synchronize(device)
return rasterize