diff --git a/pytorch3d/renderer/points/pulsar/unified.py b/pytorch3d/renderer/points/pulsar/unified.py
index 5c271f34..1b968958 100644
--- a/pytorch3d/renderer/points/pulsar/unified.py
+++ b/pytorch3d/renderer/points/pulsar/unified.py
@@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from ....transforms import matrix_to_rotation_6d
+from ....utils import pulsar_from_cameras_projection
 from ...cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
@@ -102,7 +102,7 @@ class PulsarPointsRenderer(nn.Module):
             height=height,
             max_num_balls=max_num_spheres,
             orthogonal_projection=orthogonal_projection,
-            right_handed_system=True,
+            right_handed_system=False,
             n_channels=n_channels,
             **kwargs,
         )
@@ -359,24 +359,28 @@ class PulsarPointsRenderer(nn.Module):
     def _extract_extrinsics(
         self, kwargs, cloud_idx
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Extract the extrinsic information from the kwargs for a specific point cloud.
+
+        Instead of implementing a direct translation from the PyTorch3D to the Pulsar
+        camera model, we chain the two conversions of PyTorch3D->OpenCV and
+        OpenCV->Pulsar for better maintainability (PyTorch3D->OpenCV is maintained and
+        tested by the core PyTorch3D team, whereas OpenCV->Pulsar is maintained and
+        tested by the Pulsar team).
+        """
         # Shorthand:
         cameras = self.rasterizer.cameras
         R = kwargs.get("R", cameras.R)[cloud_idx]
         T = kwargs.get("T", cameras.T)[cloud_idx]
-        norm_mat = torch.tensor(
-            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
-            dtype=torch.float32,
-            device=R.device,
+        tmp_cams = PerspectiveCameras(
+            R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
         )
-        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
-        norm_mat = torch.tensor(
-            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
-            dtype=torch.float32,
-            device=R.device,
+        size_tensor = torch.tensor(
+            [[self.renderer._renderer.height, self.renderer._renderer.width]]
         )
-        cam_rot = torch.matmul(norm_mat, cam_rot)
-        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
-        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
+        pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
+        cam_pos = pulsar_cam[0, :3]
+        cam_rot = pulsar_cam[0, 3:9]
         return cam_pos, cam_rot
 
     def _get_vert_rad(
@@ -547,15 +551,17 @@ class PulsarPointsRenderer(nn.Module):
                 otherargs["bg_col"] = bg_col
             # Go!
             images.append(
-                self.renderer(
-                    vert_pos=vert_pos,
-                    vert_col=vert_col,
-                    vert_rad=vert_rad,
-                    cam_params=cam_params,
-                    gamma=gamma,
-                    max_depth=zfar,
-                    min_depth=znear,
-                    **otherargs,
+                torch.flipud(
+                    self.renderer(
+                        vert_pos=vert_pos,
+                        vert_col=vert_col,
+                        vert_rad=vert_rad,
+                        cam_params=cam_params,
+                        gamma=gamma,
+                        max_depth=zfar,
+                        min_depth=znear,
+                        **otherargs,
+                    )
                 )
             )
         return torch.stack(images, dim=0)
diff --git a/pytorch3d/utils/__init__.py b/pytorch3d/utils/__init__.py
index 90e7fe8a..ec0eeeb8 100644
--- a/pytorch3d/utils/__init__.py
+++ b/pytorch3d/utils/__init__.py
@@ -7,6 +7,8 @@
 from .camera_conversions import (
     cameras_from_opencv_projection,
     opencv_from_cameras_projection,
+    pulsar_from_opencv_projection,
+    pulsar_from_cameras_projection,
 )
 from .ico_sphere import ico_sphere
 from .torus import torus
diff --git a/pytorch3d/utils/camera_conversions.py b/pytorch3d/utils/camera_conversions.py
index 866090cf..265f68c7 100644
--- a/pytorch3d/utils/camera_conversions.py
+++ b/pytorch3d/utils/camera_conversions.py
@@ -4,12 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Tuple
 
 import torch
 
 from ..renderer import PerspectiveCameras
-from ..transforms import so3_exp_map, so3_log_map
+from ..transforms import matrix_to_rotation_6d
+
+
+LOGGER = logging.getLogger(__name__)
 
 
 def cameras_from_opencv_projection(
@@ -54,7 +58,6 @@
     Returns:
         cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
     """
-
     focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
     principal_point = camera_matrix[:, :2, 2]
 
@@ -68,7 +71,7 @@
     # For R, T we flip x, y axes (opencv screen space has an opposite
    # orientation of screen axes).
     # We also transpose R (opencv multiplies points from the opposite=left side).
-    R_pytorch3d = R.permute(0, 2, 1)
+    R_pytorch3d = R.clone().permute(0, 2, 1)
     T_pytorch3d = tvec.clone()
     R_pytorch3d[:, :, :2] *= -1
     T_pytorch3d[:, :2] *= -1
@@ -103,20 +106,22 @@
         cameras: A batch of `N` cameras in the PyTorch3D convention.
         image_size: A tensor of shape `(N, 2)` containing the sizes of the images
             (height, width) attached to each camera.
+        return_as_rotmat (bool): If set to True, return the full 3x3 rotation
+            matrices. Otherwise, return an axis-angle vector (default).
 
     Returns:
         R: A batch of rotation matrices of shape `(N, 3, 3)`.
         tvec: A batch of translation vectors of shape `(N, 3)`.
         camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
     """
-    R_pytorch3d = cameras.R
-    T_pytorch3d = cameras.T
+    R_pytorch3d = cameras.R.clone()  # pyre-ignore
+    T_pytorch3d = cameras.T.clone()  # pyre-ignore
     focal_pytorch3d = cameras.focal_length
     p0_pytorch3d = cameras.principal_point
-    T_pytorch3d[:, :2] *= -1  # pyre-ignore
-    R_pytorch3d[:, :, :2] *= -1  # pyre-ignore
-    tvec = T_pytorch3d.clone()  # pyre-ignore
-    R = R_pytorch3d.permute(0, 2, 1)  # pyre-ignore
+    T_pytorch3d[:, :2] *= -1
+    R_pytorch3d[:, :, :2] *= -1
+    tvec = T_pytorch3d
+    R = R_pytorch3d.permute(0, 2, 1)
 
     # Retype the image_size correctly and flip to width, height.
     image_size_wh = image_size.to(R).flip(dims=(1,))
@@ -130,3 +135,151 @@ def opencv_from_cameras_projection(
     camera_matrix[:, 0, 0] = focal_length[:, 0]
     camera_matrix[:, 1, 1] = focal_length[:, 1]
     return R, tvec, camera_matrix
+
+
+def pulsar_from_opencv_projection(
+    R: torch.Tensor,
+    tvec: torch.Tensor,
+    camera_matrix: torch.Tensor,
+    image_size: torch.Tensor,
+    znear: float = 0.1,
+) -> torch.Tensor:
+    """
+    Convert OpenCV style camera parameters to Pulsar style camera parameters.
+
+    Note:
+        * Pulsar does NOT support different focal lengths for x and y.
+          For conversion, we use the average of fx and fy.
+        * The Pulsar renderer MUST use a left-handed coordinate system for this
+          mapping to work.
+        * The resulting image will be vertically flipped - which has to be
+          addressed AFTER rendering by the user.
+        * The parameters `R, tvec, camera_matrix` correspond to the outputs
+          of `cv2.decomposeProjectionMatrix`.
+
+    Args:
+        R: A batch of rotation matrices of shape `(N, 3, 3)`.
+        tvec: A batch of translation vectors of shape `(N, 3)`.
+        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
+        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
+            (height, width) attached to each camera.
+        znear (float): The near clipping value to use for Pulsar.
+
+    Returns:
+        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
+            convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
+            c_x, c_y).
+    """
+    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
+    assert len(R.size()) == 3, "This function requires batched inputs!"
+    assert len(tvec.size()) in (2, 3), "This function requires batched inputs!"
+
+    # Validate parameters.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+    assert torch.all(
+        image_size_wh > 0
+    ), "height and width must be positive but min is: %s" % (
+        str(image_size_wh.min().item())
+    )
+    assert (
+        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
+    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
+        camera_matrix.size(1),
+        camera_matrix.size(2),
+    )
+    assert (
+        R.size(1) == 3 and R.size(2) == 3
+    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
+        R.size(1),
+        R.size(2),
+    )
+    if len(tvec.size()) == 2:
+        tvec = tvec.unsqueeze(2)
+    assert (
+        tvec.size(1) == 3 and tvec.size(2) == 1
+    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
+        tvec.size(1),
+        tvec.size(2),
+    )
+    # Check batch size.
+    batch_size = camera_matrix.size(0)
+    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
+        batch_size,
+        R.size(0),
+    )
+    assert (
+        tvec.size(0) == batch_size
+    ), "Expected tvec to have batch size %d. Has size %d." % (
+        batch_size,
+        tvec.size(0),
+    )
+    # Check image sizes.
+    image_w = image_size_wh[0, 0]
+    image_h = image_size_wh[0, 1]
+    assert torch.all(
+        image_size_wh[:, 0] == image_w
+    ), "All images in a batch must have the same width!"
+    assert torch.all(
+        image_size_wh[:, 1] == image_h
+    ), "All images in a batch must have the same height!"
+    # Focal length.
+    fx = camera_matrix[:, 0, 0].unsqueeze(1)
+    fy = camera_matrix[:, 1, 1].unsqueeze(1)
+    # Check that we introduce less than 1% error by averaging the focal lengths.
+    fx_y = fx / fy
+    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
+        LOGGER.warning(
+            "Pulsar only supports a single focal length. For converting OpenCV "
+            "focal lengths, we average them for x and y directions. "
" + "The focal lengths for x and y you provided differ by more than 1%, " + "which means this could introduce a noticeable error." + ) + f = (fx + fy) / 2 + # Normalize f into normalized device coordinates. + focal_length_px = f / image_w + # Transfer into focal_length and sensor_width. + focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device) + focal_length = focal_length[None, :].repeat(batch_size, 1) + sensor_width = focal_length / focal_length_px + # Principal point. + cx = camera_matrix[:, 0, 2].unsqueeze(1) + cy = camera_matrix[:, 1, 2].unsqueeze(1) + # Transfer principal point offset into centered offset. + cx = -(cx - image_w / 2) + cy = cy - image_h / 2 + # Concatenate to final vector. + param = torch.cat([focal_length, sensor_width, cx, cy], dim=1) + R_trans = R.permute(0, 2, 1) + cam_pos = -torch.bmm(R_trans, tvec).squeeze(2) + cam_rot = matrix_to_rotation_6d(R_trans) + cam_params = torch.cat([cam_pos, cam_rot, param], dim=1) + return cam_params + + +def pulsar_from_cameras_projection( + cameras: PerspectiveCameras, + image_size: torch.Tensor, +) -> torch.Tensor: + """ + Convert PyTorch3D `PerspectiveCameras` to Pulsar style camera parameters. + + Note: + * Pulsar does NOT support different focal lengths for x and y. + For conversion, we use the average of fx and fy. + * The Pulsar renderer MUST use a left-handed coordinate system for this + mapping to work. + * The resulting image will be vertically flipped - which has to be + addressed AFTER rendering by the user. + + Args: + cameras: A batch of `N` cameras in the PyTorch3D convention. + image_size: A tensor of shape `(N, 2)` containing the sizes of the images + (height, width) attached to each camera. + + Returns: + cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar + convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width, + c_x, c_y). 
+ """ + opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size) + return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size) diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png index cb6b5cec..34c15cfd 100644 Binary files a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png index cb6b5cec..34c15cfd 100644 Binary files a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png index 507fc204..53ec18de 100644 Binary files a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png index 507fc204..53ec18de 100644 Binary files a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png differ diff --git a/tests/test_camera_conversions.py b/tests/test_camera_conversions.py index cacf3487..b6841e2a 100644 --- a/tests/test_camera_conversions.py +++ b/tests/test_camera_conversions.py @@ -12,10 +12,12 @@ import numpy as np import torch from common_testing import TestCaseMixin, get_tests_dir from pytorch3d.ops import eyes +from pytorch3d.renderer.points.pulsar import Renderer as PulsarRenderer from pytorch3d.transforms import so3_exp_map, so3_log_map from pytorch3d.utils import ( cameras_from_opencv_projection, opencv_from_cameras_projection, + pulsar_from_opencv_projection, ) @@ -111,6 +113,9 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase): [105.0, 105.0], [120.0, 120.0], ] + # These values are in y, x format, but they should be in x, y format. + # The tests work like this because they only test for consistency, + # but this format is misleading. principal_point = [ [240, 320], [240.5, 320.3], @@ -160,3 +165,80 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase): self.assertClose(R, R_i) self.assertClose(tvec, tvec_i) self.assertClose(camera_matrix, camera_matrix_i) + + def test_pulsar_conversion(self): + """ + Tests that the cameras converted from opencv to pulsar convention + return correct projections of random 3D points. The check is done + against a set of results precomputed using `cv2.projectPoints` function. 
+ """ + image_size = [[480, 640]] + R = [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ], + [ + [0.1968, -0.6663, -0.7192], + [0.7138, -0.4055, 0.5710], + [-0.6721, -0.6258, 0.3959], + ], + ] + tvec = [ + [10.0, 10.0, 3.0], + [-0.0, -0.0, 20.0], + ] + focal_length = [ + [100.0, 100.0], + [10.0, 10.0], + ] + principal_point = [ + [320, 240], + [320, 240], + ] + + principal_point, focal_length, R, tvec, image_size = [ + torch.FloatTensor(x) + for x in (principal_point, focal_length, R, tvec, image_size) + ] + camera_matrix = eyes(dim=3, N=2) + camera_matrix[:, 0, 0] = focal_length[:, 0] + camera_matrix[:, 1, 1] = focal_length[:, 1] + camera_matrix[:, :2, 2] = principal_point + rvec = so3_log_map(R) + pts = torch.tensor( + [[[0.0, 0.0, 120.0]], [[0.0, 0.0, 120.0]]], dtype=torch.float32 + ) + radii = torch.tensor([[1e-5], [1e-5]], dtype=torch.float32) + col = torch.zeros((2, 1, 1), dtype=torch.float32) + + # project the 3D points with the opencv projection function + pts_proj_opencv = cv2_project_points(pts, rvec, tvec, camera_matrix) + pulsar_cam = pulsar_from_opencv_projection( + R, tvec, camera_matrix, image_size, znear=100.0 + ) + pulsar_rend = PulsarRenderer( + 640, 480, 1, right_handed_system=False, n_channels=1 + ) + rendered = torch.flip( + pulsar_rend( + pts, + col, + radii, + pulsar_cam, + 1e-5, + max_depth=150.0, + min_depth=100.0, + ), + dims=(1,), + ) + for batch_id in range(2): + point_pos = torch.where(rendered[batch_id] == rendered[batch_id].min()) + point_pos = point_pos[1][0], point_pos[0][0] + self.assertLess( + torch.abs(point_pos[0] - pts_proj_opencv[batch_id, 0, 0]), 2 + ) + self.assertLess( + torch.abs(point_pos[1] - pts_proj_opencv[batch_id, 0, 1]), 2 + )