diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index 217974d6..2aa28f58 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -262,7 +262,10 @@ class CamerasBase(TensorProperties):
         # We don't flip xy because we assume that world points are in
         # PyTorch3D coordinates, and thus conversion from screen to ndc
         # is a mere scaling from image to [-1, 1] scale.
-        return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)
+        image_size = kwargs.get("image_size", self.get_image_size())
+        return get_screen_to_ndc_transform(
+            self, with_xyflip=False, image_size=image_size
+        )
 
     def transform_points_ndc(
         self, points, eps: Optional[float] = None, **kwargs
@@ -318,8 +321,9 @@ class CamerasBase(TensorProperties):
             new_points: transformed points with the same shape as the input.
         """
         points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs)
+        image_size = kwargs.get("image_size", self.get_image_size())
         return get_ndc_to_screen_transform(
-            self, with_xyflip=True, **kwargs
+            self, with_xyflip=True, image_size=image_size
         ).transform_points(points_ndc, eps=eps)
 
     def clone(self):
@@ -923,7 +927,7 @@ class PerspectiveCameras(CamerasBase):
             K: (optional) A calibration matrix of shape (N, 4, 4)
                 If provided, don't need focal_length, principal_point
             image_size: (height, width) of image size.
-                A tensor of shape (N, 2). Required for screen cameras.
+                A tensor of shape (N, 2) or a list/tuple. Required for screen cameras.
             device: torch.device or string
         """
         # The initializer formats all inputs to torch tensors and broadcasts
@@ -1044,8 +1048,9 @@
         pr_point_fix_transform = Transform3d(
             matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
         )
+        image_size = kwargs.get("image_size", self.get_image_size())
         screen_to_ndc_transform = get_screen_to_ndc_transform(
-            self, with_xyflip=False, **kwargs
+            self, with_xyflip=False, image_size=image_size
         )
         ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
 
@@ -1105,7 +1110,7 @@ class OrthographicCameras(CamerasBase):
         K: Optional[torch.Tensor] = None,
         device: Device = "cpu",
         in_ndc: bool = True,
-        image_size: Optional[torch.Tensor] = None,
+        image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
     ) -> None:
         """
 
@@ -1123,7 +1128,7 @@
             K: (optional) A calibration matrix of shape (N, 4, 4)
                 If provided, don't need focal_length, principal_point, image_size
             image_size: (height, width) of image size.
-                A tensor of shape (N, 2). Required for screen cameras.
+                A tensor of shape (N, 2) or list/tuple. Required for screen cameras.
             device: torch.device or string
         """
         # The initializer formats all inputs to torch tensors and broadcasts
@@ -1241,8 +1246,9 @@
         pr_point_fix_transform = Transform3d(
             matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
         )
+        image_size = kwargs.get("image_size", self.get_image_size())
         screen_to_ndc_transform = get_screen_to_ndc_transform(
-            self, with_xyflip=False, **kwargs
+            self, with_xyflip=False, image_size=image_size
         )
         ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
 
@@ -1537,7 +1543,9 @@ def look_at_view_transform(
 
 
 def get_ndc_to_screen_transform(
-    cameras, with_xyflip: bool = False, **kwargs
+    cameras,
+    with_xyflip: bool = False,
+    image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
 ) -> Transform3d:
     """
     PyTorch3D NDC to screen conversion.
@@ -1563,7 +1571,6 @@ def get_ndc_to_screen_transform(
 
     """
     # We require the image size, which is necessary for the transform
-    image_size = kwargs.get("image_size", cameras.get_image_size())
     if image_size is None:
         msg = "For NDC to screen conversion, image_size=(height, width) needs to be specified."
         raise ValueError(msg)
@@ -1605,7 +1612,9 @@
 
 
 def get_screen_to_ndc_transform(
-    cameras, with_xyflip: bool = False, **kwargs
+    cameras,
+    with_xyflip: bool = False,
+    image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
 ) -> Transform3d:
     """
     Screen to PyTorch3D NDC conversion.
@@ -1631,6 +1640,8 @@
 
     """
     transform = get_ndc_to_screen_transform(
-        cameras, with_xyflip=with_xyflip, **kwargs
+        cameras,
+        with_xyflip=with_xyflip,
+        image_size=image_size,
     ).inverse()
     return transform
diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py
index d6ae1c7e..1517bb5d 100644
--- a/tests/test_render_meshes.py
+++ b/tests/test_render_meshes.py
@@ -1146,3 +1146,54 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         )
 
         self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_cameras_kwarg(self):
+        """
+        Test that when cameras are passed in as a kwarg the rendering
+        works as expected
+        """
+        device = torch.device("cuda:0")
+
+        # Init mesh
+        sphere_mesh = ico_sphere(5, device)
+        verts_padded = sphere_mesh.verts_padded()
+        faces_padded = sphere_mesh.faces_padded()
+        feats = torch.ones_like(verts_padded, device=device)
+        textures = TexturesVertex(verts_features=feats)
+        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
+
+        # No elevation or azimuth rotation
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+        for cam_type in (
+            FoVPerspectiveCameras,
+            FoVOrthographicCameras,
+            PerspectiveCameras,
+            OrthographicCameras,
+        ):
+            cameras = cam_type(device=device, R=R, T=T)
+
+            # Init shader settings
+            materials = Materials(device=device)
+            lights = PointLights(device=device)
+            lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+
+            raster_settings = RasterizationSettings(
+                image_size=512, blur_radius=0.0, faces_per_pixel=1
+            )
+            rasterizer = MeshRasterizer(raster_settings=raster_settings)
+            blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
+
+            shader = HardPhongShader(
+                lights=lights,
+                materials=materials,
+                blend_params=blend_params,
+            )
+            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+
+            # Cameras can be passed into the renderer in the forward pass
+            images = renderer(sphere_mesh, cameras=cameras)
+            rgb = images.squeeze()[..., :3].cpu().numpy()
+            image_ref = load_rgb_image(
+                "test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
+            )
+            self.assertClose(rgb, image_ref, atol=0.05)
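Usage note (reviewer sketch, not part of the patch): with this change, image_size is resolved once at each call site (per-call kwarg first, then the value stored on the camera) and passed explicitly, so unrelated kwargs such as cameras forwarded through the renderer (the case exercised by test_cameras_kwarg) no longer leak into the transform helpers. A minimal sketch of what a caller sees, assuming a screen-space PerspectiveCameras; the intrinsics and point values below are illustrative only:

import torch
from pytorch3d.renderer import PerspectiveCameras

# Screen-space camera: intrinsics in pixels, image_size is (height, width).
cameras = PerspectiveCameras(
    focal_length=torch.tensor([[500.0, 500.0]]),
    principal_point=torch.tensor([[256.0, 256.0]]),
    in_ndc=False,
    image_size=torch.tensor([[512, 512]]),
)
points = torch.tensor([[[0.1, 0.2, 2.0]]])  # (N, P, 3) points in world space

# Uses the image_size the camera was constructed with.
pts_default = cameras.transform_points_screen(points)

# Per-call override via the kwarg; a plain (height, width) tuple also works,
# matching the Optional[Union[List, Tuple, torch.Tensor]] annotation above.
pts_override = cameras.transform_points_screen(points, image_size=(720, 1280))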