diff --git a/tests/benchmarks/bm_blending.py b/tests/benchmarks/bm_blending.py index 5759ad40..7b5349c2 100644 --- a/tests/benchmarks/bm_blending.py +++ b/tests/benchmarks/bm_blending.py @@ -16,7 +16,7 @@ def bm_blending() -> None: kwargs_list = [] num_meshes = [8] image_size = [64, 128, 256] - faces_per_pixel = [50, 100] + faces_per_pixel = [2, 50, 100] backend = ["pytorch", "custom"] test_cases = product(num_meshes, image_size, faces_per_pixel, devices, backend) @@ -47,6 +47,28 @@ def bm_blending() -> None: warmup_iters=1, ) + kwargs_list = [] + faces_per_pixel = [2, 10] + backend = ["pytorch"] + test_cases = product(num_meshes, image_size, faces_per_pixel, devices, backend) + for case in test_cases: + n, s, k, d, b = case + kwargs_list.append( + { + "num_meshes": n, + "image_size": s, + "faces_per_pixel": k, + "device": d, + "backend": b, + } + ) + benchmark( + TestBlending.bm_splatter_blending, + "SPLATTER_BLENDING_PYTORCH", + kwargs_list, + warmup_iters=1, + ) + if __name__ == "__main__": bm_blending() diff --git a/tests/test_blending.py b/tests/test_blending.py index 4503c1f7..13c494fd 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -14,7 +14,9 @@ from pytorch3d.renderer.blending import ( sigmoid_alpha_blend, softmax_rgb_blend, ) +from pytorch3d.renderer.cameras import FoVPerspectiveCameras from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.renderer.splatter_blend import SplatterBlender def sigmoid_blend_naive_loop(colors, fragments, blend_params): @@ -412,6 +414,54 @@ class TestBlending(TestCaseMixin, unittest.TestCase): return fn + @staticmethod + def bm_splatter_blending( + num_meshes: int = 16, + image_size: int = 128, + faces_per_pixel: int = 2, + use_jit: bool = False, + device: str = "cpu", + backend: str = "pytorch", + ): + if torch.cuda.is_available() and "cuda:" in device: + # If a device other than the default is used, set the device explicity. + torch.cuda.set_device(device) + + device = torch.device(device) + torch.manual_seed(231) + + # Create dummy outputs of rasterization + N, S, K = num_meshes, image_size, faces_per_pixel + F = 32 # num faces in the mesh + + pixel_coords_camera = torch.randn( + (N, S, S, K, 3), device=device, requires_grad=True + ) + cameras = FoVPerspectiveCameras(device=device) + colors = torch.randn((N, S, S, K, 3), device=device) + background_mask = torch.randint( + low=-1, high=F + 1, size=(N, S, S, K), device=device + ) + background_mask = torch.full((N, S, S, K), False, dtype=bool, device=device) + blend_params = BlendParams(sigma=0.5) + + torch.cuda.synchronize() + splatter_blender = SplatterBlender((N, S, S, K), colors.device) + + def fn(): + # test forward and backward pass + images = splatter_blender( + colors, + pixel_coords_camera, + cameras, + background_mask, + blend_params, + ) + images.sum().backward() + torch.cuda.synchronize() + + return fn + def test_blend_params(self): """Test color parameter of BlendParams(). Assert passed value overrides default value.