# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional, Tuple, Union

import torch
from pytorch3d.common.datatypes import Device
from pytorch3d.renderer.cameras import _R, _T, CamerasBase

# Default parameters: unit focal length, centered principal point and
# no distortion. Each has a leading batch dimension of size 1.
_focal_length = torch.tensor(((1.0,),))
_principal_point = torch.tensor(((0.0, 0.0),))
_radial_params = torch.tensor(((0.0, 0.0, 0.0, 0.0, 0.0, 0.0),))
_tangential_params = torch.tensor(((0.0, 0.0),))
_thin_prism_params = torch.tensor(((0.0, 0.0, 0.0, 0.0),))


class FishEyeCameras(CamerasBase):
    """
    A class which extends the pinhole camera by considering radial, tangential
    and thin-prism distortion. For the fisheye camera model, k_0, k_1, ...,
    k_{num_radial - 1} are polynomial coefficients which model radial
    distortions. Two common types of radial distortions are barrel and
    pincushion radial distortions.

        a = x / z, b = y / z, r = (a*a + b*b)^(1/2)
        th = atan(r)
        [x_r]  = (th + k_0 * th^3 + k_1 * th^5 + ...) [a/r]
        [y_r]                                         [b/r]      [1]

    The tangential distortion parameters are p_0 and p_1. The primary cause is
    the lens assembly not being centered over and parallel to the image plane.

        tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]
                               [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0]   [2]
        where rd^2 = x_r^2 + y_r^2

    The thin-prism distortion is modeled with the s_0, s_1, s_2, s_3
    coefficients

        thinPrismDistortion = [s_0 * rd^2 + s_1 * rd^4]
                              [s_2 * rd^2 + s_3 * rd^4]                 [3]

    The projection

        proj = diag(f, f) * uvDistorted + [cu; cv]
        uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion [4]
                      [y_r]

    f is the focal length and cu, cv are the principal point coordinates
    along the x and y axes.
    """

    _FIELDS = (
        "focal_length",
        "principal_point",
        "R",
        "T",
        "radial_params",
        "tangential_params",
        "thin_prism_params",
        "world_coordinates",
        "use_radial",
        "use_tangential",
        "use_thin_prism",
        "device",
        "image_size",
    )

    def __init__(
        self,
        focal_length=_focal_length,
        principal_point=_principal_point,
        radial_params=_radial_params,
        tangential_params=_tangential_params,
        thin_prism_params=_thin_prism_params,
        R: torch.Tensor = _R,
        T: torch.Tensor = _T,
        world_coordinates: bool = False,
        use_radial: bool = True,
        use_tangential: bool = True,
        use_thin_prism: bool = True,
        device: Device = "cpu",
        image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
    ) -> None:
        """
        Args:
            focal_length: Focal length of the camera in world units.
                A tensor of shape (N, 1) for square pixels,
                where N is the number of transforms.
            principal_point: xy coordinates of the center of
                the principal point of the camera in pixels.
                A tensor of shape (N, 2).
            radial_params: parameters for radial distortions.
                A tensor of shape (N, num_radial).
            tangential_params: parameters for tangential distortions.
                A tensor of shape (N, 2).
            thin_prism_params: parameters for thin-prism distortions.
                A tensor of shape (N, 4).
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            world_coordinates: if True, project from world coordinates;
                otherwise from camera coordinates
            use_radial: radial distortion, defaults to True
            use_tangential: tangential distortion, defaults to True
            use_thin_prism: thin prism distortion, defaults to True
            device: torch.device or string
            image_size: (height, width) of image size.
                A tensor of shape (N, 2) or a list/tuple. Required for
                screen cameras.

        Raises:
            ValueError: if any entry of ``image_size`` is smaller than 1.
        """
        kwargs = {"image_size": image_size} if image_size is not None else {}
        super().__init__(
            device=device,
            R=R,
            T=T,
            **kwargs,  # pyre-ignore
        )
        if image_size is not None:
            if (self.image_size < 1).any():  # pyre-ignore
                raise ValueError("Image_size provided has invalid values")
        else:
            self.image_size = None

        self.device = device
        self.focal = focal_length.to(self.device)
        # `_FIELDS` and the constructor both use the name `focal_length`, so
        # expose the tensor under that name as well; otherwise lookups by
        # field name cannot find it.
        # NOTE(review): presumably CamerasBase resolves `_FIELDS` entries via
        # getattr (e.g. when indexing a batch) -- confirm against the base class.
        self.focal_length = self.focal
        self.principal_point = principal_point.to(self.device)
        self.radial_params = radial_params.to(self.device)
        self.tangential_params = tangential_params.to(self.device)
        self.thin_prism_params = thin_prism_params.to(self.device)
        self.R = R.to(self.device)
        self.T = T.to(self.device)
        self.world_coordinates = world_coordinates
        self.use_radial = use_radial
        self.use_tangential = use_tangential
        self.use_thin_prism = use_thin_prism
        # threshold below which a quantity is treated as zero
        self.epsilon = 1e-10
        # number of Newton iterations used when inverting the distortions
        self.num_distortion_iters = 50
        self.num_radial = radial_params.shape[-1]

    @staticmethod
    def _theta_sq_powers(theta_sq: torch.Tensor, num_radial: int) -> torch.Tensor:
        """
        Stack theta_sq^1, ..., theta_sq^num_radial along a new last dimension.

        Args:
            theta_sq: squared angle of shape (...)
            num_radial: number of radial coefficients

        Returns:
            tensor of shape (..., num_radial)
        """
        if num_radial == 0:
            return theta_sq.new_zeros(*theta_sq.shape, 0)
        return torch.stack(
            [theta_sq**k for k in range(1, num_radial + 1)], dim=-1
        )

    @staticmethod
    def _tangential_distortion(
        tangential_params, xr_yr: torch.Tensor, xr_yr_squared_norm: torch.Tensor
    ) -> torch.Tensor:
        """
        Tangential distortion term of eq 2.

        Args:
            tangential_params: (2)
            xr_yr: (..., 2)
            xr_yr_squared_norm: (..., 1)

        Returns:
            tensor of shape (..., 2)
        """
        temp = 2.0 * torch.sum(xr_yr * tangential_params, dim=-1, keepdim=True)
        return temp * xr_yr + xr_yr_squared_norm * tangential_params

    @staticmethod
    def _thin_prism_distortion(
        thin_prism_params, xr_yr_squared_norm: torch.Tensor
    ) -> torch.Tensor:
        """
        Thin-prism distortion term of eq 3.

        Args:
            thin_prism_params: (4)
            xr_yr_squared_norm: (..., 1)

        Returns:
            tensor of shape (..., 2)
        """
        radial_powers = torch.cat(
            [xr_yr_squared_norm, xr_yr_squared_norm * xr_yr_squared_norm], dim=-1
        )
        return torch.stack(
            [
                torch.sum(thin_prism_params[..., 0:2] * radial_powers, dim=-1),
                torch.sum(thin_prism_params[..., 2:4] * radial_powers, dim=-1),
            ],
            dim=-1,
        )

    @staticmethod
    def _squeeze_points_batch(
        outputs: torch.Tensor, points: torch.Tensor
    ) -> torch.Tensor:
        """
        Drop the points batch dimension from per-camera stacked results when
        the input had shape (1, P, K), so that the result is (N, P, C).
        Every other dimension is preserved, even when its size is 1.
        """
        if points.ndim == 3 and points.shape[0] == 1:
            return outputs.squeeze(1)
        return outputs

    def _project_points_batch(
        self,
        focal,
        principal_point,
        radial_params,
        tangential_params,
        thin_prism_params,
        points,
    ) -> torch.Tensor:
        """
        Takes in points in the local reference frame of the camera and projects
        them onto the image plane. Since this is a symmetric model, points with
        negative z are projected to the positive sphere,
        i.e. project(1, 1, -1) == project(-1, -1, 1)

        Args:
            focal: (1)
            principal_point: (2)
            radial_params: (num_radial)
            tangential_params: (2)
            thin_prism_params: (4)
            points in the camera coordinate frame: (..., 3). E.g., (P, 3)
                (1, P, 3) or (M, P, 3) where P is the number of points

        Returns:
            projected_points in the image plane: (..., 3). E.g., (P, 3) or
                (1, P, 3) or (M, P, 3). The last coordinate is always 1.
        """
        assert points.shape[-1] == 3, "points shape incorrect"
        ab = points[..., :2] / points[..., 2:]

        r = ab.norm(dim=-1)
        th = r.atan()
        theta_sq = th * th

        # compute radial distortions, eq 1
        theta_pow = self._theta_sq_powers(theta_sq, radial_params.shape[-1])
        th_radial = 1 + torch.sum(theta_pow * radial_params, dim=-1)

        # compute th/r, using the limit (1) for small values; the division is
        # done with a safe denominator so neither the value nor its gradient
        # is NaN at r == 0
        small_r = r.abs() < self.epsilon
        unit = torch.ones_like(r)
        th_divr = torch.where(small_r, unit, th / torch.where(small_r, unit, r))

        # the distorted coordinates -- except for focal length and principal
        # point -- start with the radial term
        coeff = th_radial * th_divr
        xr_yr = coeff[..., None] * ab
        xr_yr_squared_norm = torch.pow(xr_yr, 2).sum(dim=-1, keepdim=True)

        uv_distorted = xr_yr if self.use_radial else ab

        # The distortion terms are added out of place: uv_distorted may alias
        # xr_yr, which autograd still needs unmodified.
        # compute tangential distortions, eq 2
        if self.use_tangential:
            uv_distorted = uv_distorted + self._tangential_distortion(
                tangential_params, xr_yr, xr_yr_squared_norm
            )

        # compute thin-prism distortions, eq 3
        if self.use_thin_prism:
            uv_distorted = uv_distorted + self._thin_prism_distortion(
                thin_prism_params, xr_yr_squared_norm
            )

        # return value: distorted points on the uv plane, eq 4
        projected_points = focal * uv_distorted + principal_point
        ones = projected_points.new_ones(*projected_points.shape[:-1], 1)
        return torch.cat([projected_points, ones], dim=-1)

    def check_input(self, points: torch.Tensor, batch_size: int):
        """
        Check if the shapes are broadcastable between points and transforms.
        Accept points of shape (P, 3) or (1, P, 3) or (M, P, 3). The batch_size
        for transforms should be 1 when points take (M, P, 3). The batch_size
        can be 1 or N when points take shape (P, 3).

        Args:
            points: tensor of shape (P, 3) or (1, P, 3) or (M, P, 3)
            batch_size: number of transforms

        Returns:
            Boolean value if the input shapes are compatible.
        """
        if points.ndim < 1 or points.ndim > 3:
            return False
        # the last dimension always holds the xyz coordinates
        if points.shape[-1] != 3:
            return False
        if points.ndim == 3 and points.shape[0] > 1 and batch_size > 1:
            return False
        return True

    def transform_points(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
        Transform input points from camera space to image space.

        Args:
            points: tensor of (..., 3). E.g., (P, 3) or (1, P, 3), (M, P, 3)
            eps: tiny number to avoid zero division; only forwarded to the
                world-to-view transform when `world_coordinates` is set

        Returns:
            torch.Tensor
            with a single transform, the output has the same shape as points
            with N > 1 transforms and points of shape (P, 3) or (1, P, 3),
            the output is (N, P, 3), where P is the number of points

        Raises:
            ValueError: if the shapes of points and transforms are incompatible.
        """
        # project from world space to camera space
        if self.world_coordinates:
            world_to_view_transform = self.get_world_to_view_transform(
                R=self.R, T=self.T
            )
            points = world_to_view_transform.transform_points(
                points.to(self.device), eps=eps
            )
        else:
            points = points.to(self.device)

        # project from camera space to image space
        N = len(self.radial_params)
        if not self.check_input(points, N):
            msg = (
                "Expected points of (P, 3) with batch_size 1 or N, or shape "
                "(M, P, 3) with batch_size 1; got points of shape %r and "
                "batch_size %r"
            )
            raise ValueError(msg % (points.shape, N))

        if N == 1:
            return self._project_points_batch(
                self.focal[0],
                self.principal_point[0],
                self.radial_params[0],
                self.tangential_params[0],
                self.thin_prism_params[0],
                points,
            )

        outputs = torch.stack(
            [
                self._project_points_batch(
                    self.focal[i],
                    self.principal_point[i],
                    self.radial_params[i],
                    self.tangential_params[i],
                    self.thin_prism_params[i],
                    points,
                )
                for i in range(N)
            ],
            dim=0,
        )
        # only drop the (1,) points batch dimension; a bare squeeze() would
        # also remove a legitimate P == 1 dimension
        return self._squeeze_points_batch(outputs, points)

    def _unproject_points_batch(
        self,
        focal,
        principal_point,
        radial_params,
        tangential_params,
        thin_prism_params,
        xy: torch.Tensor,
    ) -> torch.Tensor:
        """
        Unproject image-plane points for a single set of camera parameters.

        NOTE(review): the radial model is always inverted here, regardless of
        `use_radial` -- confirm whether that flag should be honored.

        Args:
            focal: (1)
            principal_point: (2)
            radial_params: (num_radial)
            tangential_params: (2)
            thin_prism_params: (4)
            xy: (..., 2)

        Returns:
            point3d_est: (..., 3), with the last coordinate equal to 1
        """
        assert xy.shape[-1] == 2, "xy_depth shape incorrect"
        uv_distorted = (xy - principal_point) / focal

        # get xr_yr from uvDistorted
        xr_yr = self._compute_xr_yr_from_uv_distorted(
            tangential_params, thin_prism_params, uv_distorted
        )
        xr_yr_norm = torch.norm(xr_yr, dim=-1)

        # find theta
        theta = self._get_theta_from_norm_xr_yr(radial_params, xr_yr_norm)

        # tan(theta) / |xr_yr| is 0/0 at the principal point; xr_yr is zero
        # there, so any finite scale yields the correct (0, 0, 1)
        on_axis = xr_yr_norm < self.epsilon
        unit = torch.ones_like(xr_yr_norm)
        scale = torch.where(
            on_axis, unit, theta.tan() / torch.where(on_axis, unit, xr_yr_norm)
        )

        # get the point coordinates:
        xy_est = scale[..., None] * xr_yr
        return torch.cat([xy_est, xy_est.new_ones(*xy_est.shape[:-1], 1)], dim=-1)

    def unproject_points(
        self,
        xy_depth: torch.Tensor,
        world_coordinates: bool = True,
        scaled_depth_input: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """
        Takes in 3-point ``xy_depth`` in the image plane of the camera and
        unprojects it into the reference frame of the camera.
        This function is the inverse of ``transform_points``. In particular it
        holds that

            X = unproject(project(X))
        and
            x = project(unproject(s*x))

        Args:
            xy_depth: points in the image plane of shape (..., 3). E.g.,
                (P, 3) or (1, P, 3) or (M, P, 3). Only the first two
                coordinates are used.
            world_coordinates: currently ignored; the output is always in the
                camera frame
            scaled_depth_input: currently ignored

        Returns:
            unprojected_points in the camera frame with z = 1
            with a single transform, the output has the same shape as xy_depth
            with N > 1 transforms and points of shape (P, 3) or (1, P, 3),
            the output is (N, P, 3), where P is the number of points
        """
        xy_depth = xy_depth.to(self.device)
        xy = xy_depth[..., 0:2]
        N = len(self.radial_params)
        if N == 1:
            return self._unproject_points_batch(
                self.focal[0],
                self.principal_point[0],
                self.radial_params[0],
                self.tangential_params[0],
                self.thin_prism_params[0],
                xy,
            )

        outputs = torch.stack(
            [
                self._unproject_points_batch(
                    self.focal[i],
                    self.principal_point[i],
                    self.radial_params[i],
                    self.tangential_params[i],
                    self.thin_prism_params[i],
                    xy,
                )
                for i in range(N)
            ],
            dim=0,
        )
        # only drop the (1,) points batch dimension, never the P dimension
        return self._squeeze_points_batch(outputs, xy_depth)

    def _compute_xr_yr_from_uv_distorted(
        self, tangential_params, thin_prism_params, uv_distorted: torch.Tensor
    ) -> torch.Tensor:
        """
        Helper function to compute the vector [x_r; y_r] from uvDistorted

        Args:
            tangential_params: (2)
            thin_prism_params: (4)
            uv_distorted: (..., 2), E.g., (P, 2), (1, P, 2), (M, P, 2)

        Returns:
            xr_yr: (..., 2)
        """
        # early exit if we're not using any tangential/ thin prism distortions
        if not self.use_tangential and not self.use_thin_prism:
            return uv_distorted

        xr_yr = uv_distorted
        # do Newton iterations to find xr_yr
        for _ in range(self.num_distortion_iters):
            xr_yr_squared_norm = torch.pow(xr_yr, 2).sum(dim=-1, keepdim=True)

            # compute the estimated uvDistorted. The terms are added out of
            # place: the estimate starts as an alias of xr_yr (and, in the
            # first iteration, of the target uv_distorted), so an in-place
            # update would corrupt both.
            uv_distorted_est = xr_yr
            if self.use_tangential:
                uv_distorted_est = uv_distorted_est + self._tangential_distortion(
                    tangential_params, xr_yr, xr_yr_squared_norm
                )
            if self.use_thin_prism:
                uv_distorted_est = uv_distorted_est + self._thin_prism_distortion(
                    thin_prism_params, xr_yr_squared_norm
                )

            # compute the derivative of uvDistorted wrt xr_yr
            duv_distorted_dxryr = self._compute_duv_distorted_dxryr(
                tangential_params, thin_prism_params, xr_yr, xr_yr_squared_norm[..., 0]
            )
            # compute correction:
            # note: the matrix duvDistorted_dxryr will be close to identity
            # (for reasonable values of tangential/thin prism distortions)
            correction = torch.linalg.solve(
                duv_distorted_dxryr, (uv_distorted - uv_distorted_est)[..., None]
            )
            xr_yr = xr_yr + correction[..., 0]
        return xr_yr

    def _get_theta_from_norm_xr_yr(
        self, radial_params, th_radial_desired
    ) -> torch.Tensor:
        """
        Helper function to compute the angle theta from the norm of the vector
        [x_r; y_r], by Newton iterations on
        th * (1 + k_0 * th^2 + k_1 * th^4 + ...) = th_radial_desired.

        Args:
            radial_params: k_0, k_1, ..., (num_radial)
            th_radial_desired: desired angle of shape (...), E.g., (P), (1, P),
                (M, P)

        Returns:
            th: angle theta (in radians) of shape (...), E.g., (P), (1, P), (M, P)
        """
        th = th_radial_desired
        num_radial = radial_params.shape[-1]
        # d/dth of k_i * th^(2i + 3) is (2i + 3) * k_i * th^(2i + 2)
        c = torch.tensor(
            [2.0 * i + 3 for i in range(num_radial)],
            dtype=th.dtype,
            device=th.device,
        )
        nudge = 10.0 * self.epsilon
        for _ in range(self.num_distortion_iters):
            # compute the theta polynomial and its derivative wrt theta
            theta_pow = self._theta_sq_powers(th * th, num_radial)
            th_radial = (1.0 + torch.sum(theta_pow * radial_params, dim=-1)) * th
            dthD_dth = 1.0 + torch.sum(c * radial_params * theta_pow, dim=-1)
            residual = th_radial_desired - th_radial

            # Newton step where the derivative is usable. The division uses a
            # safe denominator: masking after dividing by 0 would give NaN.
            nonzero_mask = dthD_dth.abs() > self.epsilon
            newton_step = residual / torch.where(
                nonzero_mask, dthD_dth, torch.ones_like(dthD_dth)
            )
            # if derivative is close to zero, apply small correction in the
            # appropriate direction to avoid numerical explosions
            forward = torch.full_like(residual, nudge)
            small_step = torch.where(residual * dthD_dth > 0.0, forward, -forward)

            # apply correction
            th = th + torch.where(nonzero_mask, newton_step, small_step)
            # revert to within 180 degrees FOV to avoid numerical overflow
            th = torch.where(
                th.abs() >= math.pi / 2.0,
                torch.full_like(th, 0.999 * math.pi / 2.0),
                th,
            )
        return th

    def _compute_duv_distorted_dxryr(
        self, tangential_params, thin_prism_params, xr_yr, xr_yr_squareNorm
    ) -> torch.Tensor:
        """
        Helper function, computes the Jacobian of uvDistorted wrt the vector
        [x_r; y_r]

        Args:
            tangential_params: (2)
            thin_prism_params: (4)
            xr_yr: (..., 2)
            xr_yr_squareNorm: (...), E.g., (P), (1, P), (M, P)

        Returns:
            duv_distorted_dxryr: (..., 2, 2) Jacobian
        """
        x = xr_yr[..., 0]
        y = xr_yr[..., 1]
        if self.use_tangential:
            p0 = tangential_params[..., 0]
            p1 = tangential_params[..., 1]
            j00 = 1.0 + 6.0 * x * p0 + 2.0 * y * p1
            offdiag = 2.0 * (x * p1 + y * p0)
            j01 = offdiag
            j10 = offdiag
            j11 = 1.0 + 6.0 * y * p1 + 2.0 * x * p0
        else:
            # without tangential distortion the Jacobian starts as identity
            j00 = torch.ones_like(x)
            j01 = torch.zeros_like(x)
            j10 = torch.zeros_like(x)
            j11 = torch.ones_like(x)

        if self.use_thin_prism:
            temp1 = 2.0 * (
                thin_prism_params[..., 0]
                + 2.0 * thin_prism_params[..., 1] * xr_yr_squareNorm
            )
            j00 = j00 + x * temp1
            j01 = j01 + y * temp1

            temp2 = 2.0 * (
                thin_prism_params[..., 2]
                + 2.0 * thin_prism_params[..., 3] * xr_yr_squareNorm
            )
            j10 = j10 + x * temp2
            j11 = j11 + y * temp2

        # assemble the rows into a (..., 2, 2) matrix
        return torch.stack(
            [torch.stack([j00, j01], dim=-1), torch.stack([j10, j11], dim=-1)],
            dim=-2,
        )

    def in_ndc(self):
        """Always returns True for this camera."""
        return True

    def is_perspective(self):
        """Always returns False for this camera."""
        return False
unittest.TestCase): ) self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4)) + @staticmethod + def unproject_points(cam_type, batch_size=50, num_points=100): + """ + Checks that an unprojection of a randomly projected point cloud + stays the same. + """ + + def run_cameras(): + # init the cameras + cameras = init_random_cameras(cam_type, batch_size) + # xyz - the ground truth point cloud + xyz = torch.randn(num_points, 3) * 0.3 + xyz = cameras.unproject_points(xyz, scaled_depth_input=True) + + return run_cameras + def test_project_points_screen(self, batch_size=50, num_points=100): """ Checks that an unprojection of a randomly projected point cloud @@ -643,6 +665,24 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase): # we set atol to 1e-4, remember that screen points are in [0, W]x[0, H] space self.assertClose(xyz_project_screen, xyz_project_screen_naive, atol=1e-4) + @staticmethod + def transform_points(cam_type, batch_size=50, num_points=100): + """ + Checks that an unprojection of a randomly projected point cloud + stays the same. + """ + + def run_cameras(): + # init the cameras + cameras = init_random_cameras(cam_type, batch_size) + # xyz - the ground truth point cloud + xy = torch.randn(num_points, 2) * 2.0 - 1.0 + z = torch.randn(num_points, 1) * 3.0 + 1.0 + xyz = torch.cat((xy, z), dim=-1) + xy = cameras.transform_points(xyz) + + return run_cameras + def test_equiv_project_points(self, batch_size=50, num_points=100): """ Checks that NDC and screen cameras project points to ndc correctly. 
############################################################
# FishEye Camera                                           #
############################################################


class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
    """Projection / unprojection tests for FishEyeCameras."""

    def setUpSimpleCase(self) -> tuple:
        """Return (focal, principal_point, p_3d) for a distortion-free camera."""
        super().setUp()
        focal = torch.tensor([[240]], dtype=torch.float32)
        # float32 so it can be concatenated with the other float parameters
        principal_point = torch.tensor([[320, 240]], dtype=torch.float32)
        p_3d = torch.tensor(
            [
                [2.0, 3.0, 1.0],
                [3.0, 2.0, 1.0],
            ],
            dtype=torch.float32,
        )
        return focal, principal_point, p_3d

    def setUpAriaCase(self) -> tuple:
        """
        Return (focal, principal_point, radial_params, tangential_params,
        thin_prism_params) for a camera with all distortions non-zero.
        """
        super().setUp()
        torch.manual_seed(42)
        focal = torch.tensor([[608.9255557152]], dtype=torch.float32)
        principal_point = torch.tensor(
            [[712.0114821205, 706.8666571177]], dtype=torch.float32
        )
        radial_params = torch.tensor(
            [
                [
                    0.3877090026,
                    -0.315613384,
                    -0.3434984955,
                    1.8565874201,
                    -2.1799372221,
                    0.7713834763,
                ],
            ],
            dtype=torch.float32,
        )
        tangential_params = torch.tensor(
            [[-0.0002747019, 0.0005228974]], dtype=torch.float32
        )
        thin_prism_params = torch.tensor(
            [
                [0.000134884, -0.000084822, -0.0009420014, -0.0001276838],
            ],
            dtype=torch.float32,
        )
        return (
            focal,
            principal_point,
            radial_params,
            tangential_params,
            thin_prism_params,
        )

    def setUpBatchCameras(self) -> FishEyeCameras:
        """
        Return a batch of two cameras: the distortion-free simple case
        followed by the Aria case.
        """
        super().setUp()
        focal, principal_point, p_3d = self.setUpSimpleCase()
        radial_params = torch.tensor(
            [
                [0, 0, 0, 0, 0, 0],
            ],
            dtype=torch.float32,
        )
        tangential_params = torch.tensor([[0, 0]], dtype=torch.float32)
        thin_prism_params = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32)
        (
            focal1,
            principal_point1,
            radial_params1,
            tangential_params1,
            thin_prism_params1,
        ) = self.setUpAriaCase()
        focal = torch.cat([focal, focal1], dim=0)
        principal_point = torch.cat([principal_point, principal_point1], dim=0)
        radial_params = torch.cat([radial_params, radial_params1], dim=0)
        tangential_params = torch.cat([tangential_params, tangential_params1], dim=0)
        thin_prism_params = torch.cat([thin_prism_params, thin_prism_params1], dim=0)

        cameras = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
            radial_params=radial_params,
            tangential_params=tangential_params,
            thin_prism_params=thin_prism_params,
        )

        return cameras

    def test_distortion_params_set_to_zeors(self):
        # test case 1: all distortion params are 0. Note that
        # setting radial_params to zeros is not equivalent to
        # disabling radial distortions, set use_radial=False does
        focal, principal_point, p_3d = self.setUpSimpleCase()
        cameras = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
        )
        uv_case1 = cameras.transform_points(p_3d)
        self.assertClose(
            uv_case1,
            torch.tensor(
                [[493.0993, 499.6489, 1.0], [579.6489, 413.0993, 1.0]],
            ),
        )
        # test case 2: equivalent of test case 1 by
        # disabling use_tangential and use_thin_prism
        cameras = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
            use_tangential=False,
            use_thin_prism=False,
        )
        uv_case2 = cameras.transform_points(p_3d)
        self.assertClose(uv_case2, uv_case1)

    def test_fisheye_against_perspective_cameras(self):
        # test case: check equivalence with PerspectiveCameras
        # by disabling all distortions
        focal, principal_point, p_3d = self.setUpSimpleCase()
        cameras = PerspectiveCameras(
            focal_length=focal,
            principal_point=principal_point,
        )
        P = cameras.get_projection_transform()
        uv_perspective = P.transform_points(p_3d)

        # disable all distortions
        cameras = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
            use_radial=False,
            use_tangential=False,
            use_thin_prism=False,
        )
        uv = cameras.transform_points(p_3d)
        self.assertClose(uv, uv_perspective)

    def test_project_shape_broadcasts(self):
        focal, principal_point, p_3d = self.setUpSimpleCase()
        # test case 1:
        # 1 transform with points of shape (P, 3) -> (P, 3)
        # 1 transform with points of shape (1, P, 3) -> (1, P, 3)
        # 1 transform with points of shape (M, P, 3) -> (M, P, 3)
        points = p_3d.repeat(1, 1, 1)
        cameras = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
            use_radial=False,
            use_tangential=False,
            use_thin_prism=False,
        )
        uv = cameras.transform_points(p_3d)
        uv_point_batch = cameras.transform_points(points)
        self.assertClose(uv_point_batch, uv.repeat(1, 1, 1))

        points = p_3d.repeat(3, 1, 1)
        uv_point_batch = cameras.transform_points(points)
        self.assertClose(uv_point_batch, uv.repeat(3, 1, 1))

        # test case 2
        # test with N transforms and points of shape (P, 3) -> (N, P, 3)
        # test with N transforms and points of shape (1, P, 3) -> (N, P, 3)

        # batch of two cameras: distortion-free case and Aria case
        cameras = self.setUpBatchCameras()
        p_3d = torch.tensor(
            [
                [2.0, 3.0, 1.0],
                [2.0, 3.0, 1.0],
                [3.0, 2.0, 1.0],
            ]
        )
        expected_res = torch.tensor(
            [
                [
                    [493.0993, 499.6489, 1.0],
                    [493.0993, 499.6489, 1.0],
                    [579.6489, 413.0993, 1.0],
                ],
                [
                    [1660.2700, 2128.2273, 1.0],
                    [1660.2700, 2128.2273, 1.0],
                    [2134.5815, 1650.9565, 1.0],
                ],
            ]
        )
        uv_point_batch = cameras.transform_points(p_3d)
        self.assertClose(uv_point_batch, expected_res)

        uv_point_batch = cameras.transform_points(p_3d.repeat(1, 1, 1))
        self.assertClose(uv_point_batch, expected_res)

    def test_cuda(self):
        """
        Test cuda device
        """
        # skip instead of erroring on machines without a GPU
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        focal, principal_point, p_3d = self.setUpSimpleCase()
        cameras_cuda = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
            device="cuda:0",
        )
        uv = cameras_cuda.transform_points(p_3d)
        expected_res = torch.tensor(
            [[493.0993, 499.6489, 1.0], [579.6489, 413.0993, 1.0]],
        )
        self.assertClose(uv, expected_res.to("cuda:0"))

        rep_3d = cameras_cuda.unproject_points(uv)
        self.assertClose(rep_3d, p_3d.to("cuda:0"))

    def test_unproject_shape_broadcasts(self):
        # test case 1:
        # 1 transform with points of (P, 3) -> (P, 3)
        # 1 transform with points of (M, P, 3) -> (M, P, 3)
        (
            focal,
            principal_point,
            radial_params,
            tangential_params,
            thin_prism_params,
        ) = self.setUpAriaCase()
        xy_depth = torch.tensor(
            [
                [2134.5814033, 1650.95653328, 1.0],
                [1074.25442904, 1159.52461285, 1.0],
            ]
        )
        cameras = FishEyeCameras(
            focal_length=focal,
            principal_point=principal_point,
            radial_params=radial_params,
            tangential_params=tangential_params,
            thin_prism_params=thin_prism_params,
        )
        rep_3d = cameras.unproject_points(xy_depth)
        expected_res = torch.tensor(
            [
                [3.0000, 2.0000, 1.0000],
                [0.666667, 0.833333, 1.0000],
            ],
        )
        self.assertClose(rep_3d, expected_res)

        rep_3d = cameras.unproject_points(xy_depth.repeat(3, 1, 1))
        self.assertClose(rep_3d, expected_res.repeat(3, 1, 1))

        # test case 2:
        # N transforms with points of (P, 3) -> (N, P, 3)
        # N transforms with points of (1, P, 3) -> (N, P, 3)
        cameras = FishEyeCameras(
            focal_length=focal.repeat(2, 1),
            principal_point=principal_point.repeat(2, 1),
            radial_params=radial_params.repeat(2, 1),
            tangential_params=tangential_params.repeat(2, 1),
            thin_prism_params=thin_prism_params.repeat(2, 1),
        )
        rep_3d = cameras.unproject_points(xy_depth)
        self.assertClose(rep_3d, expected_res.repeat(2, 1, 1))

    def test_unhandled_shape(self):
        """
        Test error handling when shape of transforms
        and points are not expected.
        """
        cameras = self.setUpBatchCameras()
        points = torch.rand(3, 3, 1)
        with self.assertRaises(ValueError):
            cameras.transform_points(points)

    def test_getitem(self):
        # Check get item returns an instance of the same class
        # with all the same keys
        cam = self.setUpBatchCameras()
        c0 = cam[0]
        self.assertTrue(isinstance(c0, FishEyeCameras))
        self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())