diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index 1a4d4abc..3ea2ae16 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -92,23 +92,20 @@ class MeshRasterizer(nn.Module): msg = "Cameras must be specified either at initialization \ or in the forward pass of MeshRasterizer" raise ValueError(msg) - verts_world = meshes_world.verts_padded() - verts_world_packed = meshes_world.verts_packed() - verts_screen = cameras.transform_points(verts_world, **kwargs) # NOTE: Retaining view space z coordinate for now. # TODO: Revisit whether or not to transform z coordinate to [-1, 1] or # [0, 1] range. - view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T) - verts_view = view_transform.transform_points(verts_world) + verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points( + verts_world + ) + verts_screen = cameras.get_projection_transform(**kwargs).transform_points( + verts_view + ) verts_screen[..., 2] = verts_view[..., 2] - - # Offset verts of input mesh to reuse cached padded/packed calculations. - pad_to_packed_idx = meshes_world.verts_padded_to_packed_idx() - verts_screen_packed = verts_screen.view(-1, 3)[pad_to_packed_idx, :] - verts_packed_offset = verts_screen_packed - verts_world_packed - return meshes_world.offset_verts(verts_packed_offset) + meshes_screen = meshes_world.update_padded(new_verts_padded=verts_screen) + return meshes_screen def forward(self, meshes_world, **kwargs) -> Fragments: """ diff --git a/tests/bm_mesh_rasterizer_transform.py b/tests/bm_mesh_rasterizer_transform.py new file mode 100644 index 00000000..97672504 --- /dev/null +++ b/tests/bm_mesh_rasterizer_transform.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer +from pytorch3d.utils.ico_sphere import ico_sphere + + +def rasterize_transform_with_init(num_meshes: int, ico_level: int = 5, device="cuda"): + # Init meshes + sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) + # Init transform + R, T = look_at_view_transform(1.0, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + # Init rasterizer + rasterizer = MeshRasterizer(cameras=cameras) + + torch.cuda.synchronize() + + def raster_fn(): + rasterizer.transform(sphere_meshes) + torch.cuda.synchronize() + + return raster_fn + + +def bm_mesh_rasterizer_transform() -> None: + if torch.cuda.is_available(): + kwargs_list = [] + num_meshes = [1, 8] + ico_level = [0, 1, 3, 4] + test_cases = product(num_meshes, ico_level) + for case in test_cases: + n, ic = case + kwargs_list.append({"num_meshes": n, "ico_level": ic}) + benchmark( + rasterize_transform_with_init, + "MESH_RASTERIZER", + kwargs_list, + warmup_iters=1, + )