diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index baf7164a..4efdd0ea 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - +from .cameras_alignment import corresponding_cameras_alignment from .cubify import cubify from .graph_conv import GraphConv from .interp_face_attrs import interpolate_face_attributes diff --git a/pytorch3d/ops/cameras_alignment.py b/pytorch3d/ops/cameras_alignment.py new file mode 100644 index 00000000..55a67ecf --- /dev/null +++ b/pytorch3d/ops/cameras_alignment.py @@ -0,0 +1,215 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import TYPE_CHECKING + +import torch + +from .. import ops + + +if TYPE_CHECKING: + from pytorch3d.renderer.cameras import CamerasBase + + +def corresponding_cameras_alignment( + cameras_src: "CamerasBase", + cameras_tgt: "CamerasBase", + estimate_scale: bool = True, + mode: str = "extrinsics", + eps: float = 1e-9, +) -> "CamerasBase": + """ + .. warning:: + The `corresponding_cameras_alignment` API is experimental + and subject to change! + + Estimates a single similarity transformation between two sets of cameras + `cameras_src` and `cameras_tgt` and returns an aligned version of + `cameras_src`. + + Given source cameras [(R_1, T_1), (R_2, T_2), ..., (R_N, T_N)] and + target cameras [(R_1', T_1'), (R_2', T_2'), ..., (R_N', T_N')], + where (R_i, T_i) is a 2-tuple of the camera rotation and translation matrix + respectively, the algorithm finds a global rotation, translation and scale + (R_A, T_A, s_A) which aligns all source cameras with the target cameras + such that the following holds: + + Under the change of coordinates using a similarity transform + (R_A, T_A, s_A) a 3D point X' is mapped to X with: + ``` + X = (X' R_A + T_A) / s_A + ``` + Then, for all cameras `i`, we assume that the following holds: + ``` + X R_i + T_i = s' (X' R_i' + T_i'), + ``` + i.e. an adjusted point X' is mapped by a camera (R_i', T_i') + to the same point as imaged from camera (R_i, T_i) after resolving + the scale ambiguity with a global scalar factor s'. + + Substituting for X above gives rise to the following: + ``` + (X' R_A + T_A) / s_A R_i + T_i = s' (X' R_i' + T_i') // · s_A + (X' R_A + T_A) R_i + T_i s_A = (s' s_A) (X' R_i' + T_i') + s' := 1 / s_A # without loss of generality + (X' R_A + T_A) R_i + T_i s_A = X' R_i' + T_i' + X' R_A R_i + T_A R_i + T_i s_A = X' R_i' + T_i' + ^^^^^^^ ^^^^^^^^^^^^^^^^^ + ~= R_i' ~= T_i' + ``` + i.e. after estimating R_A, T_A, s_A, the aligned source cameras have + extrinsics: + `cameras_src_align = (R_A R_i, T_A R_i + T_i s_A) ~= (R_i', T_i')` + + We support two ways `R_A, T_A, s_A` can be estimated: + 1) `mode=='centers'` + Estimates the similarity alignment between camera centers using + Umeyama's algorithm (see `pytorch3d.ops.corresponding_points_alignment` + for details) and transforms camera extrinsics accordingly. + + 2) `mode=='extrinsics'` + Defines the alignment problem as a system + of the following equations: + ``` + for all i: + [ R_A 0 ] x [ R_i 0 ] = [ R_i' 0 ] + [ T_A^T 1 ] [ (s_A T_i^T) 1 ] [ T_i' 1 ] + ``` + `R_A, T_A` and `s_A` are then obtained by solving the + system in the least squares sense. + + The estimated camera transformation is a true similarity transform, i.e. + it cannot be a reflection. + + Args: + cameras_src: `N` cameras to be aligned. + cameras_tgt: `N` target cameras. 
+ estimate_scale: Controls whether the alignment transform is rigid + (`estimate_scale=False`), or a similarity (`estimate_scale=True`). + `s_A` is set to `1` if `estimate_scale==False`. + mode: Controls the alignment algorithm. + Can be either `'centers'` or `'extrinsics'`. Please refer to the + description above for details. + eps: A scalar for clamping to avoid dividing by zero. + Active when `estimate_scale==True`. + + Returns: + cameras_src_aligned: `cameras_src` after applying the alignment transform. + """ + + if cameras_src.R.shape[0] != cameras_tgt.R.shape[0]: + raise ValueError( + "cameras_src and cameras_tgt have to contain the same number of cameras!" + ) + + if mode == "centers": + align_fun = _align_camera_centers + elif mode == "extrinsics": + align_fun = _align_camera_extrinsics + else: + raise ValueError("mode has to be one of (centers, extrinsics)") + + align_t_R, align_t_T, align_t_s = align_fun( + cameras_src, cameras_tgt, estimate_scale=estimate_scale, eps=eps + ) + + # create a new cameras object and set the R and T accordingly + cameras_src_aligned = cameras_src.clone() + cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R) + cameras_src_aligned.T = ( + torch.bmm( + align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1), cameras_src.R + )[:, 0] + + cameras_src.T * align_t_s + ) + + return cameras_src_aligned + + +def _align_camera_centers( + cameras_src: "CamerasBase", + cameras_tgt: "CamerasBase", + estimate_scale: bool = True, + eps: float = 1e-9, +): + """ + Use Umeyama's algorithm to align the camera centers. + """ + centers_src = cameras_src.get_camera_center() + centers_tgt = cameras_tgt.get_camera_center() + align_t = ops.corresponding_points_alignment( + centers_src[None], + centers_tgt[None], + estimate_scale=estimate_scale, + allow_reflection=False, + eps=eps, + ) + # the camera transform is the inverse of the estimated transform between centers + align_t_R = align_t.R.permute(0, 2, 1) + align_t_T = -(torch.bmm(align_t.T[:, None], align_t_R))[:, 0] + align_t_s = align_t.s[0] + + return align_t_R, align_t_T, align_t_s + + +def _align_camera_extrinsics( + cameras_src: "CamerasBase", + cameras_tgt: "CamerasBase", + estimate_scale: bool = True, + eps: float = 1e-9, +): + """ + Get the global rotation R_A with svd of cov(RR^T): + ``` + R_A R_i = R_i' for all i + R_A [R_1 R_2 ... R_N] = [R_1' R_2' ... R_N'] + U, _, V = svd([R_1 R_2 ... R_N]^T [R_1' R_2' ... R_N']) + R_A = (U V^T)^T + ``` + """ + RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0) + U, _, V = torch.svd(RRcov) + align_t_R = V @ U.t() + + """ + The translation + scale `T_A` and `s_A` are computed by finding + a translation and scaling that aligns two tensors `A, B` + defined as follows: + ``` + T_A R_i + s_A T_i = T_i' ; for all i // · R_i^T + s_A T_i R_i^T + T_A = T_i' R_i^T ; for all i + ^^^^^^^^^ ^^^^^^^^^^ + A_i B_i + + A_i := T_i R_i^T + A = [A_1 A_2 ... A_N] + B_i := T_i' R_i^T + B = [B_1 B_2 ...
B_N] + ``` + The scale s_A can be retrieved by matching the correlations of + the point sets A and B: + ``` + s_A = (A-mean(A))*(B-mean(B)).sum() / ((A-mean(A))**2).sum() + ``` + The translation `T_A` is then defined as: + ``` + T_A = mean(B) - mean(A) * s_A + ``` + """ + A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0] + B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0] + Amu = A.mean(0, keepdim=True) + Bmu = B.mean(0, keepdim=True) + if estimate_scale and A.shape[0] > 1: + # get the scaling component by matching covariances + # of centered A and centered B + Ac = A - Amu + Bc = B - Bmu + align_t_s = (Ac * Bc).mean() / (Ac ** 2).mean().clamp(eps) + else: + # set the scale to identity + align_t_s = 1.0 + # get the translation as the difference between the means of A and B + align_t_T = Bmu - align_t_s * Amu + + return align_t_R, align_t_T, align_t_s diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index dd435509..8ea5d4d3 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -13,8 +13,8 @@ from .utils import TensorProperties, convert_to_tensors_and_broadcast # Default values for rotation and translation matrices. -r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3) -t = np.expand_dims(np.zeros(3), axis=0) # (1, 3) +_R = torch.eye(3)[None] # (1, 3, 3) +_T = torch.zeros(1, 3) # (1, 3) class CamerasBase(TensorProperties): @@ -280,8 +280,8 @@ def OpenGLPerspectiveCameras( aspect_ratio=1.0, fov=60.0, degrees: bool = True, - R=r, - T=t, + R=_R, + T=_T, device="cpu", ): """ @@ -331,8 +331,8 @@ class FoVPerspectiveCameras(CamerasBase): aspect_ratio=1.0, fov=60.0, degrees: bool = True, - R=r, - T=t, + R=_R, + T=_T, device="cpu", ): """ @@ -436,7 +436,7 @@ class FoVPerspectiveCameras(CamerasBase): P[:, 2, 2] = z_sign * zfar / (zfar - znear) P[:, 2, 3] = -(zfar * znear) / (zfar - znear) - # Transpose the projection matrix as PyTorch3d transforms use row vectors. + # Transpose the projection matrix as PyTorch3D transforms use row vectors. transform = Transform3d(device=self.device) transform._matrix = P.transpose(1, 2).contiguous() return transform @@ -494,8 +494,8 @@ def OpenGLOrthographicCameras( left=-1.0, right=1.0, scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) - R=r, - T=t, + R=_R, + T=_T, device="cpu", ): """ @@ -540,8 +540,8 @@ class FoVOrthographicCameras(CamerasBase): max_x=1.0, min_x=-1.0, scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) - R=r, - T=t, + R=_R, + T=_T, device="cpu", ): """ @@ -688,7 +688,7 @@ we assume the parameters are in screen space. def SfMPerspectiveCameras( - focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu" + focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu" ): """ SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead. @@ -747,8 +747,8 @@ class PerspectiveCameras(CamerasBase): self, focal_length=1.0, principal_point=((0.0, 0.0),), - R=r, - T=t, + R=_R, + T=_T, device="cpu", image_size=((-1, -1),), ): @@ -848,7 +848,7 @@ class PerspectiveCameras(CamerasBase): def SfMOrthographicCameras( - focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu" + focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu" ): """ SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
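As a quick sanity check on the closed form used in `_align_camera_extrinsics` above, the following standalone sketch (not part of the diff; tensor names such as `R_A_gt` are illustrative) synthesizes target extrinsics from a known similarity `(R_A, T_A, s_A)` under the row-vector convention `R_i' = R_A R_i`, `T_i' = T_A R_i + s_A T_i`, and recovers the transform with the same SVD and centered-moment steps. It assumes only `torch` and the repo's existing `random_rotations` helper.
```
import torch
from pytorch3d.transforms.rotation_conversions import random_rotations

torch.manual_seed(0)
N = 8
R_src = random_rotations(N)          # source rotations R_i, shape (N, 3, 3)
T_src = torch.randn(N, 3)            # source translations T_i
R_A_gt = random_rotations(1)[0]      # ground-truth alignment rotation R_A
T_A_gt = torch.randn(3)              # ground-truth alignment translation T_A
s_A_gt = torch.rand(1).exp().item()  # ground-truth positive scale s_A

# target extrinsics: R_i' = R_A R_i ;  T_i' = T_A R_i + s_A T_i
R_tgt = torch.matmul(R_A_gt, R_src)
T_tgt = torch.matmul(T_A_gt[None, None], R_src)[:, 0] + s_A_gt * T_src

# R_A: mean_i(R_i R_i'^T) equals R_A^T, so V U^T of its SVD recovers R_A
RRcov = torch.bmm(R_src, R_tgt.transpose(2, 1)).mean(0)
U, _, V = torch.svd(RRcov)
R_A_est = V @ U.t()

# s_A, T_A: with A_i = T_i R_i^T and B_i = T_i' R_i^T we have B_i = s_A A_i + T_A
A = torch.bmm(R_src, T_src[:, :, None])[:, :, 0]
B = torch.bmm(R_src, T_tgt[:, :, None])[:, :, 0]
Amu, Bmu = A.mean(0, keepdim=True), B.mean(0, keepdim=True)
s_A_est = ((A - Amu) * (B - Bmu)).mean() / ((A - Amu) ** 2).mean().clamp(1e-9)
T_A_est = Bmu - s_A_est * Amu

print(torch.allclose(R_A_est, R_A_gt, atol=1e-5))                # expected: True
print(torch.allclose(s_A_est, torch.tensor(s_A_gt), atol=1e-4))  # expected: True
print(torch.allclose(T_A_est[0], T_A_gt, atol=1e-4))             # expected: True
```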
@@ -906,8 +906,8 @@ class OrthographicCameras(CamerasBase): self, focal_length=1.0, principal_point=((0.0, 0.0),), - R=r, - T=t, + R=_R, + T=_T, device="cpu", image_size=((-1, -1),), ): @@ -1109,7 +1109,7 @@ def _get_sfm_calibration_matrix( ################################################ -def get_world_to_view_transform(R=r, T=t) -> Transform3d: +def get_world_to_view_transform(R=_R, T=_T) -> Transform3d: """ This function returns a Transform3d representing the transformation matrix to go from world space to view space by applying a rotation and diff --git a/tests/bm_cameras_alignment.py b/tests/bm_cameras_alignment.py new file mode 100644 index 00000000..6c65e3d2 --- /dev/null +++ b/tests/bm_cameras_alignment.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools +from fvcore.common.benchmark import benchmark +from test_cameras_alignment import TestCamerasAlignment + + +def bm_cameras_alignment() -> None: + + case_grid = { + "batch_size": [10, 100, 1000], + "mode": ["centers", "extrinsics"], + "estimate_scale": [False, True], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + benchmark( + TestCamerasAlignment.corresponding_cameras_alignment, + "CORRESPONDING_CAMERAS_ALIGNMENT", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 8755701b..b74cd74d 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -26,6 +26,7 @@ # SOFTWARE. import math +import typing import unittest import numpy as np @@ -47,6 +48,7 @@ from pytorch3d.renderer.cameras import ( look_at_view_transform, ) from pytorch3d.transforms import Transform3d +from pytorch3d.transforms.rotation_conversions import random_rotations from pytorch3d.transforms.so3 import so3_exponential_map @@ -132,6 +134,51 @@ def ndc_to_screen_points_naive(points, imsize): return torch.stack((x, y, z), dim=2) +def init_random_cameras( + cam_type: typing.Type[CamerasBase], batch_size: int, random_z: bool = False +): + cam_params = {} + T = torch.randn(batch_size, 3) * 0.03 + if not random_z: + T[:, 2] = 4 + R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0) + cam_params = {"R": R, "T": T} + if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras): + cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 + cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] + if cam_type == OpenGLPerspectiveCameras: + cam_params["fov"] = torch.rand(batch_size) * 60 + 30 + cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 + else: + cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9 + cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9 + elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras): + cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 + cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] + if cam_type == FoVPerspectiveCameras: + cam_params["fov"] = torch.rand(batch_size) * 60 + 30 + cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 + else: + cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9 + cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9 + cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9 + elif cam_type in ( + SfMOrthographicCameras, + 
SfMPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, + ): + cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1 + cam_params["principal_point"] = torch.randn((batch_size, 2)) + + else: + raise ValueError(str(cam_type)) + return cam_type(**cam_params) + + class TestCameraHelpers(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() @@ -410,7 +457,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase): def test_get_camera_center(self, batch_size=10): T = torch.randn(batch_size, 3) - R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0) + R = random_rotations(batch_size) for cam_type in ( OpenGLPerspectiveCameras, OpenGLOrthographicCameras, @@ -426,48 +473,6 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase): C_ = -torch.bmm(R, T[:, :, None])[:, :, 0] self.assertTrue(torch.allclose(C, C_, atol=1e-05)) - @staticmethod - def init_random_cameras(cam_type: CamerasBase, batch_size: int): - cam_params = {} - T = torch.randn(batch_size, 3) * 0.03 - T[:, 2] = 4 - R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0) - cam_params = {"R": R, "T": T} - if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras): - cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 - cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] - if cam_type == OpenGLPerspectiveCameras: - cam_params["fov"] = torch.rand(batch_size) * 60 + 30 - cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 - else: - cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9 - cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9 - elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras): - cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1 - cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"] - if cam_type == FoVPerspectiveCameras: - cam_params["fov"] = torch.rand(batch_size) * 60 + 30 - cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5 - else: - cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9 - cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9 - cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9 - elif cam_type in ( - SfMOrthographicCameras, - SfMPerspectiveCameras, - OrthographicCameras, - PerspectiveCameras, - ): - cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1 - cam_params["principal_point"] = torch.randn((batch_size, 2)) - - else: - raise ValueError(str(cam_type)) - return cam_type(**cam_params) - @staticmethod def init_equiv_cameras_ndc_screen(cam_type: CamerasBase, batch_size: int): T = torch.randn(batch_size, 3) * 0.03 @@ -508,7 +513,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase): PerspectiveCameras, ): # init the cameras - cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size) + cameras = init_random_cameras(cam_type, batch_size) # xyz - the ground truth point cloud xyz = torch.randn(batch_size, num_points, 3) * 0.3 # xyz in camera coordinates @@ -572,7 +577,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase): ): # init the cameras - cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size) + cameras = init_random_cameras(cam_type, batch_size) # xyz - the ground truth point cloud xyz = torch.randn(batch_size, num_points, 3) * 0.3 # image size @@ -618,7 +623,7 @@ class 
TestCamerasCommon(TestCaseMixin, unittest.TestCase): OrthographicCameras, PerspectiveCameras, ): - cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size) + cameras = init_random_cameras(cam_type, batch_size) cameras = cameras.to(torch.device("cpu")) cameras_clone = cameras.clone() diff --git a/tests/test_cameras_alignment.py b/tests/test_cameras_alignment.py new file mode 100644 index 00000000..0c7aa276 --- /dev/null +++ b/tests/test_cameras_alignment.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import numpy as np +import torch +from common_testing import TestCaseMixin +from pytorch3d.ops import corresponding_cameras_alignment +from pytorch3d.renderer.cameras import ( + OpenGLOrthographicCameras, + OpenGLPerspectiveCameras, + SfMOrthographicCameras, + SfMPerspectiveCameras, +) +from pytorch3d.transforms.rotation_conversions import random_rotations +from pytorch3d.transforms.so3 import so3_exponential_map, so3_relative_angle +from test_cameras import init_random_cameras + + +class TestCamerasAlignment(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + def test_corresponding_cameras_alignment(self): + """ + Checks the corresponding_cameras_alignment function. + """ + device = torch.device("cuda:0") + + # try few different random setups + for _ in range(3): + for estimate_scale in (True, False): + # init true alignment transform + R_align_gt = random_rotations(1, device=device)[0] + T_align_gt = torch.randn(3, dtype=torch.float32, device=device) + + # init true scale + if estimate_scale: + s_align_gt = torch.randn( + 1, dtype=torch.float32, device=device + ).exp() + else: + s_align_gt = torch.tensor(1.0, dtype=torch.float32, device=device) + + for cam_type in ( + SfMOrthographicCameras, + OpenGLPerspectiveCameras, + OpenGLOrthographicCameras, + SfMPerspectiveCameras, + ): + # try well-determined and underdetermined cases + for batch_size in (10, 4, 3, 2, 1): + # get random cameras + cameras = init_random_cameras( + cam_type, batch_size, random_z=True + ).to(device) + # try all alignment modes + for mode in ("extrinsics", "centers"): + # try different noise levels + for add_noise in (0.0, 0.01, 1e-4): + self._corresponding_cameras_alignment_test_case( + cameras, + R_align_gt, + T_align_gt, + s_align_gt, + estimate_scale, + mode, + add_noise, + ) + + def _corresponding_cameras_alignment_test_case( + self, + cameras, + R_align_gt, + T_align_gt, + s_align_gt, + estimate_scale, + mode, + add_noise, + ): + batch_size = cameras.R.shape[0] + + # get target camera centers + R_new = torch.bmm(R_align_gt[None].expand_as(cameras.R), cameras.R) + T_new = ( + torch.bmm(T_align_gt[None, None].repeat(batch_size, 1, 1), cameras.R)[:, 0] + + cameras.T + ) * s_align_gt + + if add_noise != 0.0: + R_new = torch.bmm( + R_new, so3_exponential_map(torch.randn_like(T_new) * add_noise) + ) + T_new += torch.randn_like(T_new) * add_noise + + # create new cameras from R_new and T_new + cameras_tgt = cameras.clone() + cameras_tgt.R = R_new + cameras_tgt.T = T_new + + # align cameras and cameras_tgt + cameras_aligned = corresponding_cameras_alignment( + cameras, cameras_tgt, estimate_scale=estimate_scale, mode=mode + ) + + if batch_size <= 2 and mode == "centers": + # underdetermined case - check only the center alignment error + # since the rotation and translation are ambiguous here + self.assertClose( + cameras_aligned.get_camera_center(), + 
cameras_tgt.get_camera_center(), + atol=max(add_noise * 7.0, 1e-4), + ) + + else: + + def _rmse(a): + return (torch.norm(a, dim=1, p=2) ** 2).mean().sqrt() + + if add_noise != 0.0: + # in a noisy case check mean rotation/translation error for + # extrinsic alignment and root mean center error for center alignment + if mode == "centers": + self.assertNormsClose( + cameras_aligned.get_camera_center(), + cameras_tgt.get_camera_center(), + _rmse, + atol=max(add_noise * 10.0, 1e-4), + ) + elif mode == "extrinsics": + angle_err = so3_relative_angle( + cameras_aligned.R, cameras_tgt.R + ).mean() + self.assertClose( + angle_err, torch.zeros_like(angle_err), atol=add_noise * 10.0 + ) + self.assertNormsClose( + cameras_aligned.T, cameras_tgt.T, _rmse, atol=add_noise * 7.0 + ) + else: + raise ValueError(mode) + + else: + # compare the rotations and translations of cameras + self.assertClose(cameras_aligned.R, cameras_tgt.R, atol=3e-4) + self.assertClose(cameras_aligned.T, cameras_tgt.T, atol=3e-4) + # compare the centers + self.assertClose( + cameras_aligned.get_camera_center(), + cameras_tgt.get_camera_center(), + atol=3e-4, + ) + + @staticmethod + def corresponding_cameras_alignment( + batch_size: int, estimate_scale: bool, mode: str, cam_type=SfMPerspectiveCameras + ): + device = torch.device("cuda:0") + cameras_src, cameras_tgt = [ + init_random_cameras(cam_type, batch_size, random_z=True).to(device) + for _ in range(2) + ] + + torch.cuda.synchronize() + + def compute_corresponding_cameras_alignment(): + corresponding_cameras_alignment( + cameras_src, cameras_tgt, estimate_scale=estimate_scale, mode=mode + ) + torch.cuda.synchronize() + + return compute_corresponding_cameras_alignment
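For context, here is a minimal end-to-end usage sketch of the new op (not part of the diff; it mirrors the noise-free construction in `TestCamerasAlignment` and runs on CPU): target cameras are built by applying a known similarity `(R_A, T_A, s_A)` to a random source batch, and `corresponding_cameras_alignment` should map the sources back onto them.
```
import torch
from pytorch3d.ops import corresponding_cameras_alignment
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations

torch.manual_seed(0)
N = 10
cameras_src = PerspectiveCameras(R=random_rotations(N), T=torch.randn(N, 3))

# misalign the sources by a known similarity transform (R_A, T_A, s_A)
R_A = random_rotations(1)
T_A = torch.randn(1, 3)
s_A = 2.0
cameras_tgt = cameras_src.clone()
cameras_tgt.R = torch.bmm(R_A.repeat(N, 1, 1), cameras_src.R)
cameras_tgt.T = (
    torch.bmm(T_A[:, None].repeat(N, 1, 1), cameras_src.R)[:, 0]
    + s_A * cameras_src.T
)

# recover the alignment and check that the aligned sources land on the targets
cameras_aligned = corresponding_cameras_alignment(
    cameras_src, cameras_tgt, estimate_scale=True, mode="extrinsics"
)
print(torch.allclose(cameras_aligned.R, cameras_tgt.R, atol=3e-4))  # expected: True
print(torch.allclose(cameras_aligned.T, cameras_tgt.T, atol=3e-4))  # expected: True
```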