mesh rasterizer settings fix

Summary:
Fix default setting of `max_faces_per_bin` and update mesh rasterization benchmark tests.
The previous setting of `max_faces_per_bin` was wrong and for larger mesh sizes and batch sizes it was causing a significant slow down due to an unecessarily large intermediate tensor being created.

Reviewed By: gkioxari

Differential Revision: D22301819

fbshipit-source-id: d5e817f5b917fb5633c9c6a8634b6c8ff65e3508
This commit is contained in:
Nikhila Ravi
2020-06-30 12:42:42 -07:00
committed by Facebook GitHub Bot
parent 88f579389f
commit dd4a35cf9f
3 changed files with 49 additions and 30 deletions

View File

@@ -1004,26 +1004,42 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
@staticmethod
def rasterize_meshes_python_with_init(
num_meshes: int, ico_level: int, image_size: int, blur_radius: float
num_meshes: int,
ico_level: int,
image_size: int,
blur_radius: float,
faces_per_pixel: int,
):
device = torch.device("cpu")
meshes = ico_sphere(ico_level, device)
meshes_batch = meshes.extend(num_meshes)
def rasterize():
rasterize_meshes_python(meshes_batch, image_size, blur_radius)
rasterize_meshes_python(
meshes_batch, image_size, blur_radius, faces_per_pixel
)
return rasterize
@staticmethod
def rasterize_meshes_cpu_with_init(
num_meshes: int, ico_level: int, image_size: int, blur_radius: float
num_meshes: int,
ico_level: int,
image_size: int,
blur_radius: float,
faces_per_pixel: int,
):
meshes = ico_sphere(ico_level, torch.device("cpu"))
meshes_batch = meshes.extend(num_meshes)
def rasterize():
rasterize_meshes(meshes_batch, image_size, blur_radius, bin_size=0)
rasterize_meshes(
meshes_batch,
image_size,
blur_radius,
faces_per_pixel=faces_per_pixel,
bin_size=0,
)
return rasterize
@@ -1033,18 +1049,15 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
ico_level: int,
image_size: int,
blur_radius: float,
bin_size: int,
max_faces_per_bin: int,
faces_per_pixel: int,
):
meshes = ico_sphere(ico_level, get_random_cuda_device())
device = get_random_cuda_device()
meshes = ico_sphere(ico_level, device)
meshes_batch = meshes.extend(num_meshes)
torch.cuda.synchronize()
torch.cuda.synchronize(device)
def rasterize():
rasterize_meshes(
meshes_batch, image_size, blur_radius, 8, bin_size, max_faces_per_bin
)
torch.cuda.synchronize()
rasterize_meshes(meshes_batch, image_size, blur_radius, faces_per_pixel)
torch.cuda.synchronize(device)
return rasterize