Camera alignment

Summary:
Adds the `corresponding_cameras_alignment` function, which estimates a similarity transformation between two sets of cameras.

The function is essential for computing camera errors in SfM pipelines: reconstructed cameras are defined only up to a similarity transform, so they have to be aligned to the ground-truth cameras before per-camera errors can be measured.
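
For reference, a minimal usage sketch (the camera setup and noise level below are illustrative and not part of this diff):

```
import torch
from pytorch3d.ops import corresponding_cameras_alignment
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations

# cameras_src plays the role of a (noisy) reconstruction,
# cameras_tgt the role of the ground truth.
R = random_rotations(10)
T = torch.randn(10, 3)
cameras_tgt = PerspectiveCameras(R=R, T=T)
cameras_src = PerspectiveCameras(R=R, T=T + 1e-2 * torch.randn(10, 3))

# Estimate the similarity transform (R_A, T_A, s_A) and apply it to cameras_src.
cameras_aligned = corresponding_cameras_alignment(
    cameras_src, cameras_tgt, estimate_scale=True, mode="extrinsics"
)

# After alignment, camera-center distances can serve as an SfM error metric.
center_error = (
    cameras_aligned.get_camera_center() - cameras_tgt.get_camera_center()
).norm(dim=1).mean()
```

Benchmark results: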

```
Benchmark                                                   Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
CORRESPONDING_CAMERAS_ALIGNMENT_10_centers_False                32219           36211             16
CORRESPONDING_CAMERAS_ALIGNMENT_10_centers_True                 32429           36063             16
CORRESPONDING_CAMERAS_ALIGNMENT_10_extrinsics_False              5548            8782             91
CORRESPONDING_CAMERAS_ALIGNMENT_10_extrinsics_True               6153            9752             82
CORRESPONDING_CAMERAS_ALIGNMENT_100_centers_False               33344           40398             16
CORRESPONDING_CAMERAS_ALIGNMENT_100_centers_True                34528           37095             15
CORRESPONDING_CAMERAS_ALIGNMENT_100_extrinsics_False             5576            7187             90
CORRESPONDING_CAMERAS_ALIGNMENT_100_extrinsics_True              6256            9166             80
CORRESPONDING_CAMERAS_ALIGNMENT_1000_centers_False              32020           37247             16
CORRESPONDING_CAMERAS_ALIGNMENT_1000_centers_True               32776           37644             16
CORRESPONDING_CAMERAS_ALIGNMENT_1000_extrinsics_False            5336            8795             94
CORRESPONDING_CAMERAS_ALIGNMENT_1000_extrinsics_True             6266            9929             80
--------------------------------------------------------------------------------
```

Reviewed By: shapovalov

Differential Revision: D22946415

fbshipit-source-id: 8caae7ee365b304d8aa1f8133cf0dd92c35bc0dd
Author: David Novotny
Date: 2020-09-03 13:26:13 -07:00 (committed by Facebook GitHub Bot)
Parent: 14f015d8bf
Commit: 316b77782e
6 changed files with 482 additions and 65 deletions


@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .cameras_alignment import corresponding_cameras_alignment
from .cubify import cubify
from .graph_conv import GraphConv
from .interp_face_attrs import interpolate_face_attributes


@ -0,0 +1,215 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import TYPE_CHECKING
import torch
from .. import ops
if TYPE_CHECKING:
from pytorch3d.renderer.cameras import CamerasBase
def corresponding_cameras_alignment(
cameras_src: "CamerasBase",
cameras_tgt: "CamerasBase",
estimate_scale: bool = True,
mode: str = "extrinsics",
eps: float = 1e-9,
) -> "CamerasBase":
"""
.. warning::
The `corresponding_cameras_alignment` API is experimental
and subject to change!
Estimates a single similarity transformation between two sets of cameras
`cameras_src` and `cameras_tgt` and returns an aligned version of
`cameras_src`.
Given source cameras [(R_1, T_1), (R_2, T_2), ..., (R_N, T_N)] and
target cameras [(R_1', T_1'), (R_2', T_2'), ..., (R_N', T_N')],
where (R_i, T_i) is a 2-tuple of the camera rotation and translation matrix
respectively, the algorithm finds a global rotation, translation and scale
(R_A, T_A, s_A) which aligns all source cameras with the target cameras
such that the following holds:
Under the change of coordinates using a similarity transform
(R_A, T_A, s_A) a 3D point X' is mapped to X with:
```
X = (X' R_A + T_A) / s_A
```
Then, for all cameras `i`, we assume that the following holds:
```
X R_i + T_i = s' (X' R_i' + T_i'),
```
i.e. an adjusted point X' is mapped by a camera (R_i', T_i')
to the same point as imaged from camera (R_i, T_i) after resolving
the scale ambiguity with a global scalar factor s'.
Substituting for X above gives rise to the following:
```
(X' R_A + T_A) / s_A R_i + T_i = s' (X' R_i' + T_i') // · s_A
(X' R_A + T_A) R_i + T_i s_A = (s' s_A) (X' R_i' + T_i')
s' := 1 / s_A # without loss of generality
(X' R_A + T_A) R_i + T_i s_A = X' R_i' + T_i'
X' R_A R_i + T_A R_i + T_i s_A = X' R_i' + T_i'
^^^^^^^ ^^^^^^^^^^^^^^^^^
~= R_i' ~= T_i'
```
i.e. after estimating R_A, T_A, s_A, the aligned source cameras have
extrinsics:
`cameras_src_align = (R_A R_i, T_A R_i + T_i s_A) ~= (R_i', T_i')`
We support two ways `R_A, T_A, s_A` can be estimated:
1) `mode=='centers'`
Estimates the similarity alignment between camera centers using
Umeyama's algorithm (see `pytorch3d.ops.corresponding_points_alignment`
for details) and transforms camera extrinsics accordingly.
2) `mode=='extrinsics'`
Defines the alignment problem as a system
of the following equations:
```
for all i:
[ R_A 0 ] x [ R_i 0 ] = [ R_i' 0 ]
[ T_A^T 1 ] [ (s_A T_i^T) 1 ] [ T_i' 1 ]
```
`R_A, T_A` and `s_A` are then obtained by solving the
system in the least squares sense.
The estimated camera transformation is a true similarity transform, i.e.
it cannot be a reflection.
Args:
cameras_src: `N` cameras to be aligned.
cameras_tgt: `N` target cameras.
estimate_scale: Controls whether the alignment transform is rigid
(`estimate_scale=False`), or a similarity (`estimate_scale=True`).
`s_A` is set to `1` if `estimate_scale==False`.
mode: Controls the alignment algorithm.
Can be one of `'centers'` or `'extrinsics'`. Please refer to the
description above for details.
eps: A scalar for clamping to avoid dividing by zero.
Active when `estimate_scale==True`.
Returns:
cameras_src_aligned: `cameras_src` after applying the alignment transform.
"""
if cameras_src.R.shape[0] != cameras_tgt.R.shape[0]:
raise ValueError(
"cameras_src and cameras_tgt have to contain the same number of cameras!"
)
if mode == "centers":
align_fun = _align_camera_centers
elif mode == "extrinsics":
align_fun = _align_camera_extrinsics
else:
raise ValueError("mode has to be one of (centers, extrinsics)")
align_t_R, align_t_T, align_t_s = align_fun(
cameras_src, cameras_tgt, estimate_scale=estimate_scale, eps=eps
)
# create a new cameras object and set the R and T accordingly
cameras_src_aligned = cameras_src.clone()
cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
cameras_src_aligned.T = (
torch.bmm(
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1), cameras_src.R
)[:, 0]
+ cameras_src.T * align_t_s
)
return cameras_src_aligned
def _align_camera_centers(
cameras_src: "CamerasBase",
cameras_tgt: "CamerasBase",
estimate_scale: bool = True,
eps: float = 1e-9,
):
"""
Use Umeyama's algorithm to align the camera centers.
"""
centers_src = cameras_src.get_camera_center()
centers_tgt = cameras_tgt.get_camera_center()
align_t = ops.corresponding_points_alignment(
centers_src[None],
centers_tgt[None],
estimate_scale=estimate_scale,
allow_reflection=False,
eps=eps,
)
# the camera transform is the inverse of the estimated transform between centers
align_t_R = align_t.R.permute(0, 2, 1)
align_t_T = -(torch.bmm(align_t.T[:, None], align_t_R))[:, 0]
align_t_s = align_t.s[0]
return align_t_R, align_t_T, align_t_s
def _align_camera_extrinsics(
cameras_src: "CamerasBase",
cameras_tgt: "CamerasBase",
estimate_scale: bool = True,
eps: float = 1e-9,
):
"""
Get the global rotation R_A with svd of cov(RR^T):
```
R_A R_i = R_i' for all i
R_A [R_1 R_2 ... R_N] = [R_1' R_2' ... R_N']
U, _, V = svd([R_1 R_2 ... R_N]^T [R_1' R_2' ... R_N'])
R_A = (U V^T)^T
```
"""
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
U, _, V = torch.svd(RRcov)
align_t_R = V @ U.t()
"""
The translation + scale `T_A` and `s_A` is computed by finding
a translation and scaling that aligns two tensors `A, B`
defined as follows:
```
T_A R_i + s_A T_i = T_i' ; for all i // · R_i^T
s_A T_i R_i^T + T_A = T_i' R_i^T ; for all i
^^^^^^^^^ ^^^^^^^^^^
A_i B_i
A_i := T_i R_i^T
A = [A_1 A_2 ... A_N]
B_i := T_i' R_i^T
B = [B_1 B_2 ... B_N]
```
The scale s_A can be retrieved by matching the correlations of
the points sets A and B:
```
s_A = (A-mean(A))*(B-mean(B)).sum() / ((A-mean(A))**2).sum()
```
The translation `T_A` is then defined as:
```
T_A = mean(B) - mean(A) * s_A
```
"""
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
Amu = A.mean(0, keepdim=True)
Bmu = B.mean(0, keepdim=True)
if estimate_scale and A.shape[0] > 1:
# get the scaling component by matching covariances
# of centered A and centered B
Ac = A - Amu
Bc = B - Bmu
align_t_s = (Ac * Bc).mean() / (Ac ** 2).mean().clamp(eps)
else:
# set the scale to identity
align_t_s = 1.0
# get the translation as the difference between the means of A and B
align_t_T = Bmu - align_t_s * Amu
return align_t_R, align_t_T, align_t_s
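
As a sanity check on the derivation in the docstring above, the following sketch (not part of the diff; it mirrors the construction used in `test_cameras_alignment.py` further down) applies a known `(R_A, T_A, s_A)` to a batch of cameras and checks that `mode='extrinsics'` recovers the target extrinsics:

```
import torch
from pytorch3d.ops import corresponding_cameras_alignment
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.transforms.rotation_conversions import random_rotations

torch.manual_seed(0)
N = 10
cameras_src = PerspectiveCameras(R=random_rotations(N), T=torch.randn(N, 3))

# Ground-truth alignment (R_A, T_A, s_A).
R_A = random_rotations(1)[0]
T_A = torch.randn(3)
s_A = torch.randn(1).exp()

# Build targets via R_i' = R_A R_i, T_i' = (T_A R_i + T_i) s_A
# (the same construction as in the test case below).
cameras_tgt = cameras_src.clone()
cameras_tgt.R = torch.bmm(R_A[None].expand_as(cameras_src.R), cameras_src.R)
cameras_tgt.T = (
    torch.bmm(T_A[None, None].repeat(N, 1, 1), cameras_src.R)[:, 0] + cameras_src.T
) * s_A

cameras_aligned = corresponding_cameras_alignment(
    cameras_src, cameras_tgt, estimate_scale=True, mode="extrinsics"
)

# Both residuals should be close to zero (the noiseless test below uses atol=3e-4).
print((cameras_aligned.R - cameras_tgt.R).abs().max())
print((cameras_aligned.T - cameras_tgt.T).abs().max())
```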


@ -13,8 +13,8 @@ from .utils import TensorProperties, convert_to_tensors_and_broadcast
# Default values for rotation and translation matrices.
r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3)
t = np.expand_dims(np.zeros(3), axis=0) # (1, 3)
_R = torch.eye(3)[None] # (1, 3, 3)
_T = torch.zeros(1, 3) # (1, 3)
class CamerasBase(TensorProperties):
@ -280,8 +280,8 @@ def OpenGLPerspectiveCameras(
aspect_ratio=1.0,
fov=60.0,
degrees: bool = True,
R=r,
T=t,
R=_R,
T=_T,
device="cpu",
):
"""
@ -331,8 +331,8 @@ class FoVPerspectiveCameras(CamerasBase):
aspect_ratio=1.0,
fov=60.0,
degrees: bool = True,
R=r,
T=t,
R=_R,
T=_T,
device="cpu",
):
"""
@ -436,7 +436,7 @@ class FoVPerspectiveCameras(CamerasBase):
P[:, 2, 2] = z_sign * zfar / (zfar - znear)
P[:, 2, 3] = -(zfar * znear) / (zfar - znear)
# Transpose the projection matrix as PyTorch3d transforms use row vectors.
# Transpose the projection matrix as PyTorch3D transforms use row vectors.
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
@ -494,8 +494,8 @@ def OpenGLOrthographicCameras(
left=-1.0,
right=1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=r,
T=t,
R=_R,
T=_T,
device="cpu",
):
"""
@ -540,8 +540,8 @@ class FoVOrthographicCameras(CamerasBase):
max_x=1.0,
min_x=-1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=r,
T=t,
R=_R,
T=_T,
device="cpu",
):
"""
@ -688,7 +688,7 @@ we assume the parameters are in screen space.
def SfMPerspectiveCameras(
focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
):
"""
SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
@ -747,8 +747,8 @@ class PerspectiveCameras(CamerasBase):
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
R=_R,
T=_T,
device="cpu",
image_size=((-1, -1),),
):
@ -848,7 +848,7 @@ class PerspectiveCameras(CamerasBase):
def SfMOrthographicCameras(
focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
):
"""
SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
@ -906,8 +906,8 @@ class OrthographicCameras(CamerasBase):
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
R=_R,
T=_T,
device="cpu",
image_size=((-1, -1),),
):
@ -1109,7 +1109,7 @@ def _get_sfm_calibration_matrix(
################################################
def get_world_to_view_transform(R=r, T=t) -> Transform3d:
def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
"""
This function returns a Transform3d representing the transformation
matrix to go from world space to view space by applying a rotation and


@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from test_cameras_alignment import TestCamerasAlignment
def bm_cameras_alignment() -> None:
case_grid = {
"batch_size": [10, 100, 1000],
"mode": ["centers", "extrinsics"],
"estimate_scale": [False, True],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark(
TestCamerasAlignment.corresponding_cameras_alignment,
"CORRESPONDING_CAMERAS_ALIGNMENT",
kwargs_list,
warmup_iters=1,
)
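
The table in the summary was generated from this entry point via fvcore's `benchmark` helper. A hypothetical driver (not part of the diff; it needs a CUDA device, since the benchmarked static method places the cameras on `cuda:0`) would simply call the function defined here:

```
# Hypothetical: append to the benchmark module and run it directly.
if __name__ == "__main__":
    bm_cameras_alignment()
```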


@ -26,6 +26,7 @@
# SOFTWARE.
import math
import typing
import unittest
import numpy as np
@ -47,6 +48,7 @@ from pytorch3d.renderer.cameras import (
look_at_view_transform,
)
from pytorch3d.transforms import Transform3d
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exponential_map
@ -132,6 +134,51 @@ def ndc_to_screen_points_naive(points, imsize):
return torch.stack((x, y, z), dim=2)
def init_random_cameras(
cam_type: typing.Type[CamerasBase], batch_size: int, random_z: bool = False
):
cam_params = {}
T = torch.randn(batch_size, 3) * 0.03
if not random_z:
T[:, 2] = 4
R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
cam_params = {"R": R, "T": T}
if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
if cam_type == OpenGLPerspectiveCameras:
cam_params["fov"] = torch.rand(batch_size) * 60 + 30
cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
else:
cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
if cam_type == FoVPerspectiveCameras:
cam_params["fov"] = torch.rand(batch_size) * 60 + 30
cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
else:
cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
elif cam_type in (
SfMOrthographicCameras,
SfMPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
else:
raise ValueError(str(cam_type))
return cam_type(**cam_params)
class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
@ -410,7 +457,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
def test_get_camera_center(self, batch_size=10):
T = torch.randn(batch_size, 3)
R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
R = random_rotations(batch_size)
for cam_type in (
OpenGLPerspectiveCameras,
OpenGLOrthographicCameras,
@ -426,48 +473,6 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
self.assertTrue(torch.allclose(C, C_, atol=1e-05))
@staticmethod
def init_random_cameras(cam_type: CamerasBase, batch_size: int):
cam_params = {}
T = torch.randn(batch_size, 3) * 0.03
T[:, 2] = 4
R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
cam_params = {"R": R, "T": T}
if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
if cam_type == OpenGLPerspectiveCameras:
cam_params["fov"] = torch.rand(batch_size) * 60 + 30
cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
else:
cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
if cam_type == FoVPerspectiveCameras:
cam_params["fov"] = torch.rand(batch_size) * 60 + 30
cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
else:
cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
elif cam_type in (
SfMOrthographicCameras,
SfMPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
else:
raise ValueError(str(cam_type))
return cam_type(**cam_params)
@staticmethod
def init_equiv_cameras_ndc_screen(cam_type: CamerasBase, batch_size: int):
T = torch.randn(batch_size, 3) * 0.03
@ -508,7 +513,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
PerspectiveCameras,
):
# init the cameras
cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(batch_size, num_points, 3) * 0.3
# xyz in camera coordinates
@ -572,7 +577,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
):
# init the cameras
cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(batch_size, num_points, 3) * 0.3
# image size
@ -618,7 +623,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
OrthographicCameras,
PerspectiveCameras,
):
cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
cameras = init_random_cameras(cam_type, batch_size)
cameras = cameras.to(torch.device("cpu"))
cameras_clone = cameras.clone()
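
`init_random_cameras` is now a module-level helper in `test_cameras.py`, so the benchmark and the new alignment tests can import it directly (see the imports below). A small illustrative use, assuming it is run from the tests folder:

```
# Illustrative only: build a random batch of cameras with the shared helper.
from test_cameras import init_random_cameras
from pytorch3d.renderer.cameras import FoVPerspectiveCameras

cameras = init_random_cameras(FoVPerspectiveCameras, batch_size=8, random_z=True)
print(cameras.R.shape, cameras.T.shape)  # torch.Size([8, 3, 3]) torch.Size([8, 3])
```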


@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import corresponding_cameras_alignment
from pytorch3d.renderer.cameras import (
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
)
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exponential_map, so3_relative_angle
from test_cameras import init_random_cameras
class TestCamerasAlignment(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_corresponding_cameras_alignment(self):
"""
Checks the corresponding_cameras_alignment function.
"""
device = torch.device("cuda:0")
# try few different random setups
for _ in range(3):
for estimate_scale in (True, False):
# init true alignment transform
R_align_gt = random_rotations(1, device=device)[0]
T_align_gt = torch.randn(3, dtype=torch.float32, device=device)
# init true scale
if estimate_scale:
s_align_gt = torch.randn(
1, dtype=torch.float32, device=device
).exp()
else:
s_align_gt = torch.tensor(1.0, dtype=torch.float32, device=device)
for cam_type in (
SfMOrthographicCameras,
OpenGLPerspectiveCameras,
OpenGLOrthographicCameras,
SfMPerspectiveCameras,
):
# try well-determined and underdetermined cases
for batch_size in (10, 4, 3, 2, 1):
# get random cameras
cameras = init_random_cameras(
cam_type, batch_size, random_z=True
).to(device)
# try all alignment modes
for mode in ("extrinsics", "centers"):
# try different noise levels
for add_noise in (0.0, 0.01, 1e-4):
self._corresponding_cameras_alignment_test_case(
cameras,
R_align_gt,
T_align_gt,
s_align_gt,
estimate_scale,
mode,
add_noise,
)
def _corresponding_cameras_alignment_test_case(
self,
cameras,
R_align_gt,
T_align_gt,
s_align_gt,
estimate_scale,
mode,
add_noise,
):
batch_size = cameras.R.shape[0]
# get target camera centers
R_new = torch.bmm(R_align_gt[None].expand_as(cameras.R), cameras.R)
T_new = (
torch.bmm(T_align_gt[None, None].repeat(batch_size, 1, 1), cameras.R)[:, 0]
+ cameras.T
) * s_align_gt
if add_noise != 0.0:
R_new = torch.bmm(
R_new, so3_exponential_map(torch.randn_like(T_new) * add_noise)
)
T_new += torch.randn_like(T_new) * add_noise
# create new cameras from R_new and T_new
cameras_tgt = cameras.clone()
cameras_tgt.R = R_new
cameras_tgt.T = T_new
# align cameras and cameras_tgt
cameras_aligned = corresponding_cameras_alignment(
cameras, cameras_tgt, estimate_scale=estimate_scale, mode=mode
)
if batch_size <= 2 and mode == "centers":
# underdetermined case - check only the center alignment error
# since the rotation and translation are ambiguous here
self.assertClose(
cameras_aligned.get_camera_center(),
cameras_tgt.get_camera_center(),
atol=max(add_noise * 7.0, 1e-4),
)
else:
def _rmse(a):
return (torch.norm(a, dim=1, p=2) ** 2).mean().sqrt()
if add_noise != 0.0:
# in a noisy case check mean rotation/translation error for
# extrinsic alignment and root mean center error for center alignment
if mode == "centers":
self.assertNormsClose(
cameras_aligned.get_camera_center(),
cameras_tgt.get_camera_center(),
_rmse,
atol=max(add_noise * 10.0, 1e-4),
)
elif mode == "extrinsics":
angle_err = so3_relative_angle(
cameras_aligned.R, cameras_tgt.R
).mean()
self.assertClose(
angle_err, torch.zeros_like(angle_err), atol=add_noise * 10.0
)
self.assertNormsClose(
cameras_aligned.T, cameras_tgt.T, _rmse, atol=add_noise * 7.0
)
else:
raise ValueError(mode)
else:
# compare the rotations and translations of cameras
self.assertClose(cameras_aligned.R, cameras_tgt.R, atol=3e-4)
self.assertClose(cameras_aligned.T, cameras_tgt.T, atol=3e-4)
# compare the centers
self.assertClose(
cameras_aligned.get_camera_center(),
cameras_tgt.get_camera_center(),
atol=3e-4,
)
@staticmethod
def corresponding_cameras_alignment(
batch_size: int, estimate_scale: bool, mode: str, cam_type=SfMPerspectiveCameras
):
device = torch.device("cuda:0")
cameras_src, cameras_tgt = [
init_random_cameras(cam_type, batch_size, random_z=True).to(device)
for _ in range(2)
]
torch.cuda.synchronize()
def compute_corresponding_cameras_alignment():
corresponding_cameras_alignment(
cameras_src, cameras_tgt, estimate_scale=estimate_scale, mode=mode
)
torch.cuda.synchronize()
return compute_corresponding_cameras_alignment