Efficient PnP.

Summary:
Efficient PnP algorithm to fit 2D-to-3D correspondences under the perspective assumption.

Benchmarked both variants of the null-space computation and picked one; SVD takes 7 times longer in the 100K-points case.

Reviewed By: davnov134, gkioxari

Differential Revision: D20095754

fbshipit-source-id: 2b4519729630e6373820880272f674829eaed073
Roman Shapovalov authored on 2020-04-17 07:42:16 -07:00; committed by Facebook GitHub Bot
commit 04d8bf6a43 (parent 7788a38050)
6 changed files with 680 additions and 12 deletions
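
A minimal usage sketch of the new solver (not part of the diff; the import path follows the tests below, and the data is synthetic and illustrative):

import torch
from pytorch3d.ops import perspective_n_points

# points with positive depth and their exact perspective projections,
# so the solver should recover the identity camera (R ~ I, T ~ 0)
x = torch.randn(2, 10, 3)
x[..., 2] = x[..., 2].abs() + 3.0  # keep points in front of the camera
y = x[..., :2] / x[..., 2:]        # Proj([x y z]) = [x/z y/z]
sol = perspective_n_points.efficient_pnp(x, y)
print(sol.R.shape, sol.T.shape)    # torch.Size([2, 3, 3]) torch.Size([2, 3])
print(sol.err_2d)                  # close to zero in this noiseless case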


@ -0,0 +1,401 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This file contains the Efficient PnP algorithm for the Perspective-n-Point problem.
It finds a camera position (defined by rotation `R` and translation `T`) that
minimises re-projection error between the given 3D points `x` and
the corresponding uncalibrated 2D points `y`.
"""
import warnings
from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from pytorch3d.ops import points_alignment, utils as oputil
class EpnpSolution(NamedTuple):
x_cam: torch.Tensor
R: torch.Tensor
T: torch.Tensor
err_2d: torch.Tensor
err_3d: torch.Tensor
def _define_control_points(x, weight, storage_opts=None):
"""
Returns the control points that define the barycentric coordinates.
Args:
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
weight: Batch of non-negative weights of
shape `(minibatch, num_point)`. `None` means equal weights.
storage_opts: dict of keyword arguments to the tensor constructor.
"""
storage_opts = storage_opts or {}
x_mean = oputil.wmean(x, weight)
x_std = oputil.wmean((x - x_mean) ** 2, weight) ** 0.5
c_world = F.pad(torch.eye(3, **storage_opts), (0, 0, 0, 1), value=0.0).expand_as(
x[:, :4, :]
)
return c_world * x_std + x_mean
def _compute_alphas(x, c_world):
"""
Computes barycentric coordinates of x in the frame c_world.
Args:
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
c_world: control points in world coordinates.
"""
x = F.pad(x, (0, 1), value=1.0)
c = F.pad(c_world, (0, 1), value=1.0)
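    # solve for alphas from [x | 1] = alphas @ [c | 1]: the barycentric
    # coordinates sum to 1 and reproduce x as affine combinations
    # of the 4 control points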
return torch.matmul(x, torch.inverse(c)) # B x N x 4
def _build_M(y, alphas, weight):
""" Returns the matrix defining the reprojection equations.
Args:
y: projected points in camera coordinates of size B x N x 2
alphas: barycentric coordinates of size B x N x 4
weight: Batch of non-negative weights of
shape `(minibatch, num_point)`. `None` means equal weights.
"""
bs, n, _ = y.size()
# prepend t with the column of v's
def prepad(t, v):
return F.pad(t, (1, 0), value=v)
# outer left-multiply by alphas
def lm_alphas(t):
return torch.matmul(alphas[..., None], t).reshape(bs, n, 12)
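    # Each point contributes two rows. The u-row enforces
    # sum_j alpha_j * (c_j_x - u * c_j_z) = 0, i.e. the coefficient triple
    # [1, 0, -u] outer-multiplied by the 4 alphas (the 12 unknowns are the
    # control point coordinates in the camera frame); the v-row is the
    # analogous constraint with [0, 1, -v].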
M = torch.cat(
(
lm_alphas(
prepad(prepad(-y[:, :, 0, None, None], 0.0), 1.0)
), # u constraints
lm_alphas(
prepad(prepad(-y[:, :, 1, None, None], 1.0), 0.0)
), # v constraints
),
dim=-1,
).reshape(bs, -1, 12)
if weight is not None:
M = M * weight.repeat(1, 2)[:, :, None]
return M
def _null_space(m, kernel_dim):
""" Finds the null space (kernel) basis of the matrix
Args:
m: the batch of input matrices, B x N x 12
kernel_dim: number of dimensions to approximate the kernel
Returns:
* a batch of null space basis vectors
of size B x 4 x 3 x kernel_dim
* a batch of spectral values where near-0s correspond to actual
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = torch.symeig(mTm, eigenvectors=True)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
def _reproj_error(y_hat, y, weight):
""" Projects estimated 3D points and computes the reprojection error
Args:
y_hat: a batch of predicted 2D points in homogeneous coordinates
y: a batch of ground-truth 2D points
weight: Batch of non-negative weights of
shape `(minibatch, num_point)`. `None` means equal weights.
Returns:
Optionally weighted mean of per-point Euclidean distances between y and y_hat.
"""
y_hat = y_hat / y_hat[..., 2:]
dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5
return oputil.wmean(dist, weight)[:, 0, 0]
def _algebraic_error(x_w_rotated, x_cam, weight):
""" Computes the residual of Umeyama in 3D.
Args:
x_w_rotated: The given 3D points rotated with the predicted camera.
x_cam: the lifted 2D points y
weight: Batch of non-negative weights of
shape `(minibatch, num_point)`. `None` means equal weights.
Returns:
Optionally weighted MSE of difference between x_w_rotated and x_cam.
"""
dist = ((x_w_rotated - x_cam) ** 2).sum(dim=-1, keepdim=True)
return oputil.wmean(dist, weight)[:, 0, 0]
def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-9):
""" Given a solution, adjusts the scale and flip
Args:
c_cam: control points in camera coordinates
alphas: barycentric coordinates of the points
x_world: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
weight: Batch of non-negative weights of
shape `(minibatch, num_point)`. `None` means equal weights.
eps: epsilon to threshold negative `z` values
"""
# position of reference points in camera coordinates
x_cam = torch.matmul(alphas, c_cam)
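    # flip the solution if the weighted mean depth is negative: a valid
    # camera must observe the points in front of it (positive z)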
x_cam = x_cam * (1.0 - 2.0 * (oputil.wmean(x_cam[..., 2:], weight) < 0).float())
if torch.any(x_cam[..., 2:] < -eps):
neg_rate = oputil.wmean((x_cam[..., 2:] < 0).float(), weight, dim=(0, 1)).item()
warnings.warn("\nEPnP: %2.2f%% points have z<0." % (neg_rate * 100.0))
R, T, s = points_alignment.corresponding_points_alignment(
x_world, x_cam, weight, estimate_scale=True
)
x_cam = x_cam / s[:, None, None]
T = T / s[:, None]
x_w_rotated = torch.matmul(x_world, R) + T[:, None, :]
err_2d = _reproj_error(x_w_rotated, y, weight)
err_3d = _algebraic_error(x_w_rotated, x_cam, weight)
return EpnpSolution(x_cam, R, T, err_2d, err_3d)
def _gen_pairs(input, dim=-2, reducer=lambda l, r: ((l - r) ** 2).sum(dim=-1)):
""" Generates all pairs of different rows and then applies the reducer
Args:
input: a tensor
dim: a dimension to generate pairs across
reducer: a function of generated pair of rows to apply (beyond just concat)
Returns:
with default args, an `A x B x C` input produces an `A x (B choose 2)` output
"""
n = input.size()[dim]
range = torch.arange(n)
idx = torch.combinations(range).to(input).long()
left = input.index_select(dim, idx[:, 0])
right = input.index_select(dim, idx[:, 1])
return reducer(left, right)
def _kernel_vec_distances(v):
""" Computes the coefficients for linearisation of the quadratic system
to match all pairwise distances between 4 control points (dim=1).
The last dimension corresponds to the coefficients for quadratic terms
Bij = Bi * Bj, where Bi and Bj correspond to kernel vectors.
Args:
v: tensor of B x 4 x 3 x D, where D is dim(kernel), usually 4
Returns:
a tensor of B x 6 x [(D choose 2) + D];
for D=4, the last dim means [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34].
"""
dv = _gen_pairs(v, dim=-3, reducer=lambda l, r: l - r) # B x 6 x 3 x D
    # dot products of all pairs (i, j), i < j, taken with coefficient 2;
    # produces a B x 6 x (D choose 2) tensor
    rows_2ij = 2.0 * _gen_pairs(dv, dim=-1, reducer=lambda l, r: (l * r).sum(dim=-2))
    # dot products of all (i, i); produces a B x 6 x D tensor
    rows_ii = (dv ** 2).sum(dim=-2)
return torch.cat((rows_ii, rows_2ij), dim=-1)
def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx):
""" Solves an over-determined linear system for selected LHS columns.
A batched version of `torch.lstsq`.
Args:
rhs: right-hand side vectors
lhs: left-hand side matrices
lhs_col_idx: a slice of columns in lhs
Returns:
a least-squares solution for lhs * X = rhs
"""
lhs = lhs.index_select(-1, torch.tensor(lhs_col_idx, device=lhs.device).long())
return torch.matmul(torch.pinverse(lhs), rhs[:, :, None])
def _find_null_space_coords_1(kernel_dsts, cw_dst):
""" Solves case 1 from the paper [1]; solve for 4 coefficients:
        [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
          ^               ^   ^   ^
Args:
kernel_dsts: distances between kernel vectors
cw_dst: distances between control points
Returns:
coefficients to weight kernel vectors
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6])
beta = beta * beta[:, :1, :].sign()
return beta / (beta[:, :1, :] ** 0.5)
def _find_null_space_coords_2(kernel_dsts, cw_dst):
""" Solves case 2 from the paper; solve for 3 coefficients:
        [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
          ^   ^           ^
Args:
kernel_dsts: distances between kernel vectors
cw_dst: distances between control points
Returns:
coefficients to weight kernel vectors
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1])
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
).float()
return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1)
def _find_null_space_coords_3(kernel_dsts, cw_dst):
""" Solves case 3 from the paper; solve for 5 coefficients:
        [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
          ^   ^           ^   ^       ^
Args:
kernel_dsts: distances between kernel vectors
cw_dst: distances between control points
Returns:
coefficients to weight kernel vectors
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7])
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
).float()
coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :]
return torch.cat(
(coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1
)
def efficient_pnp(
x: torch.Tensor,
y: torch.Tensor,
weights: Optional[torch.Tensor] = None,
skip_quadratic_eq: bool = False,
) -> EpnpSolution:
"""
Implements Efficient PnP algorithm [1] for Perspective-n-Points problem:
finds a camera position (defined by rotation `R` and translation `T`) that
minimizes re-projection error between the given 3D points `x` and
the corresponding uncalibrated 2D points `y`, i.e. solves
`y[i] = Proj(x[i] R[i] + T[i])`
in the least-squares sense, where `i` are indices within the batch, and
`Proj` is the perspective projection operator: `Proj([x y z]) = [x/z y/z]`.
In the noise-less case, 4 points are enough to find the solution as long
as they are not co-planar.
Args:
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
weights: Batch of non-negative weights of
shape `(minibatch, num_point)`. `None` means equal weights.
skip_quadratic_eq: If True, assumes the solution space for the
linear system is one-dimensional, i.e. takes the scaled eigenvector
that corresponds to the smallest eigenvalue as a solution.
If False, finds the candidate coordinates in the potentially
4D null space by approximately solving the systems of quadratic
equations. The best candidate is chosen by examining the 2D
re-projection error. While this option finds a better solution,
especially when the number of points is small or perspective
distortions are low (the points are far away), it may be more
difficult to back-propagate through.
Returns:
`EpnpSolution` namedtuple containing elements:
**x_cam**: Batch of transformed points `x` that is used to find
the camera parameters, of shape `(minibatch, num_points, 3)`.
In the general (noisy) case, they are not exactly equal to
`x[i] R[i] + T[i]` but are some affine transform of `x[i]`s.
**R**: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
**T**: Batch of translation vectors of shape `(minibatch, 3)`.
    **err_2d**: Batch of mean 2D re-projection errors of shape
        `(minibatch,)`. Specifically, if `yhat` is the re-projection for
        the `i`-th batch element, this is the (weighted) mean over points
        `j` of `norm(yhat_j - y_j)`, where `norm` denotes the L2 norm.
    **err_3d**: Batch of mean algebraic errors of shape `(minibatch,)`.
        Specifically, those are the mean squared distances between the
        transformed `x` and the estimated points on the rays defined by `y`.
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
# define control points in a world coordinate system (centered on the 3d
# points centroid); 4 x 3
# TODO: more stable when initialised with the center and eigenvectors!
c_world = _define_control_points(
x.detach(), weights, storage_opts={"dtype": x.dtype, "device": x.device}
)
# find the linear combination of the control points to represent the 3d points
alphas = _compute_alphas(x, c_world)
M = _build_M(y, alphas, weights)
# compute the approximate null space (kernel) of M
kernel, spectrum = _null_space(M, 4)
c_world_distances = _gen_pairs(c_world)
kernel_dsts = _kernel_vec_distances(kernel)
betas = (
[]
if skip_quadratic_eq
else [
fnsc(kernel_dsts, c_world_distances)
for fnsc in [
_find_null_space_coords_1,
_find_null_space_coords_2,
_find_null_space_coords_3,
]
]
)
c_cam_variants = [kernel] + [
torch.matmul(kernel, beta[:, None, :, :]) for beta in betas
]
solutions = [
_compute_norm_sign_scaling_factor(c_cam[..., 0], alphas, x, y, weights)
for c_cam in c_cam_variants
]
sol_zipped = EpnpSolution(*(torch.stack(list(col)) for col in zip(*solutions)))
best = torch.argmin(sol_zipped.err_2d, dim=0)
def gather1d(source, idx):
        # selects, for each element of dim=1, the dim=0 slice given by idx;
        # in other words, it is a batched index_select
return source.gather(
0,
idx.reshape(1, -1, *([1] * (len(source.shape) - 2))).expand_as(source[:1]),
)[0]
return EpnpSolution(*[gather1d(sol_col, best) for sol_col in sol_zipped])


@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from test_perspective_n_points import TestPerspectiveNPoints
def bm_perspective_n_points() -> None:
case_grid = {
"batch_size": [1, 10, 100],
"num_pts": [100, 100000],
"skip_q": [False, True],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
test = TestPerspectiveNPoints()
benchmark(
test.case_with_gaussian_points,
"PerspectiveNPoints",
kwargs_list,
warmup_iters=1,
)
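
To run it (a sketch, not part of the diff; assumes this file is importable as `bm_perspective_n_points` and a CUDA device is available, since the test cases allocate GPU tensors):

from bm_perspective_n_points import bm_perspective_n_points

bm_perspective_n_points()  # prints fvcore benchmark timings per case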


@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 from copy import deepcopy


@ -1,12 +1,15 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 import unittest
-from typing import Optional
+from typing import Callable, Optional, Union
 import numpy as np
 import torch
+TensorOrArray = Union[torch.Tensor, np.ndarray]
 class TestCaseMixin(unittest.TestCase):
     def assertSeparate(self, tensor1, tensor2) -> None:
         """
@ -28,10 +31,11 @@ class TestCaseMixin(unittest.TestCase):
         ptrs = [i.storage().data_ptr() for i in tensor_list]
         self.assertCountEqual(ptrs, set(ptrs))
-    def assertClose(
+    def assertNormsClose(
         self,
-        input,
-        other,
+        input: TensorOrArray,
+        other: TensorOrArray,
+        norm_fn: Callable[[TensorOrArray], TensorOrArray],
         *,
         rtol: float = 1e-05,
         atol: float = 1e-08,
@ -39,7 +43,60 @@ class TestCaseMixin(unittest.TestCase):
         msg: Optional[str] = None,
     ) -> None:
         """
-        Verify that two tensors or arrays are the same shape and close.
+        Verifies that two tensors or arrays have the same shape and are close
+        given absolute and relative tolerance; raises AssertionError otherwise.
+        A custom norm function is applied before comparison. If no such
+        pre-processing is needed, pass `torch.abs` or, equivalently, call
+        `assertClose`.
+        Args:
+            input, other: two tensors or two arrays.
+            norm_fn: The function evaluates
+                `all(norm_fn(input - other) <= atol + rtol * norm_fn(other))`.
+                norm_fn is a tensor -> tensor function; the output has:
+                * all entries non-negative,
+                * shape defined by the input shape only.
+            rtol, atol, equal_nan: as for torch.allclose.
+            msg: message in case the assertion is violated.
+        Note:
+            Optional arguments here are all keyword-only, to avoid confusion
+            with msg arguments on other assert functions.
+        """
+        self.assertEqual(np.shape(input), np.shape(other))
+        diff = norm_fn(input - other)
+        other_ = norm_fn(other)
+        # We want to generalise allclose(input, other), which is essentially
+        #   all(diff <= atol + rtol * other)
+        # but with sophisticated handling of non-finite values.
+        # We work around that by calling allclose() with the following
+        # arguments: allclose(diff + other_, other_). This computes what we
+        # want because
+        #   all(|diff + other_ - other_| <= atol + rtol * |other_|) ==
+        #   all(|norm_fn(input - other)| <= atol + rtol * |norm_fn(other)|) ==
+        #   all(norm_fn(input - other) <= atol + rtol * norm_fn(other)).
+        backend = torch if torch.is_tensor(input) else np
+        close = backend.allclose(
+            diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan
+        )
+        self.assertTrue(close, msg)
+    def assertClose(
+        self,
+        input: TensorOrArray,
+        other: TensorOrArray,
+        *,
+        rtol: float = 1e-05,
+        atol: float = 1e-08,
+        equal_nan: bool = False,
+        msg: Optional[str] = None,
+    ) -> None:
+        """
+        Verifies that two tensors or arrays have the same shape and are close
+        given absolute and relative tolerance, i.e. checks
+        `all(|input - other| <= atol + rtol * |other|)`;
+        raises AssertionError otherwise.
         Args:
             input, other: two tensors or two arrays.
             rtol, atol, equal_nan: as for torch.allclose.
@ -51,10 +108,9 @@ class TestCaseMixin(unittest.TestCase):
         self.assertEqual(np.shape(input), np.shape(other))
-        if torch.is_tensor(input):
-            close = torch.allclose(
-                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
-            )
-        else:
-            close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
+        backend = torch if torch.is_tensor(input) else np
+        close = backend.allclose(
+            input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
+        )
         self.assertTrue(close, msg)


@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_all_close(self):
device = torch.device("cuda:0")
n_points = 20
noise_std = 1e-3
msg = "tratata"
# test absolute tolerance
x = torch.rand(n_points, 3, device=device)
x_noise = x + noise_std * torch.rand(n_points, 3, device=device)
assert torch.allclose(x, x_noise, atol=10 * noise_std)
assert not torch.allclose(x, x_noise, atol=0.1 * noise_std)
self.assertClose(x, x_noise, atol=10 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(x, x_noise, atol=0.1 * noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
# test numpy
def to_np(t):
return t.data.cpu().numpy()
self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
# test relative tolerance
assert torch.allclose(x, x_noise, rtol=100 * noise_std)
assert not torch.allclose(x, x_noise, rtol=noise_std)
self.assertClose(x, x_noise, rtol=100 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(x, x_noise, rtol=noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
# test norm aggregation
# if one of the spatial dimensions is small, norm aggregation helps
x_noise[:, 0] = x_noise[:, 0] - x[:, 0]
x[:, 0] = 0.0
assert not torch.allclose(x, x_noise, rtol=100 * noise_std)
self.assertNormsClose(
x, x_noise, rtol=100 * noise_std, norm_fn=lambda t: t.norm(dim=-1)
)


@ -0,0 +1,131 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import perspective_n_points
from pytorch3d.transforms import rotation_conversions
def reproj_error(x_world, y, R, T, weight=None):
# applies the affine transform, projects, and computes the reprojection error
y_hat = torch.matmul(x_world, R) + T[:, None, :]
y_hat = y_hat / y_hat[..., 2:]
if weight is None:
weight = y.new_ones((1, 1))
return (((weight[:, :, None] * (y - y_hat[..., :2])) ** 2).sum(dim=-1) ** 0.5).mean(
dim=-1
)
class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
sol = perspective_n_points.efficient_pnp(
x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
)
err_2d = reproj_error(x_world, y, sol.R, sol.T)
R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
R_quat = rotation_conversions.matrix_to_quaternion(R)
num_pts = x_world.shape[-2]
# quadratic part is more stable with fewer points
num_pts_thresh = 5 if skip_q else 4
if check_output and num_pts > num_pts_thresh:
assert_msg = (
f"test_perspective_n_points assertion failure for "
f"n_points={num_pts}, "
f"skip_quadratic={skip_q}, "
f"no noise."
)
self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)
def norm_fn(t):
return t.norm(dim=-1)
self.assertNormsClose(
T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg
)
self.assertNormsClose(
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
)
if print_stats:
torch.set_printoptions(precision=5, sci_mode=False)
for err_2d, err_3d, R_gt, T_gt in zip(
sol.err_2d,
sol.err_3d,
torch.cat((sol.R, R), dim=-1),
torch.stack((sol.T, T[:, 0, :]), dim=-1),
):
print("2D Error: %1.4f" % err_2d.item())
print("3D Error: %1.4f" % err_3d.item())
print("R_hat | R_gt\n", R_gt)
print("T_hat | T_gt\n", T_gt)
def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False):
x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
x_cam[:, :2] *= x_cam[:, 2:] # unproject
R = rotation_conversions.random_rotations(16).to(y)
T = torch.randn_like(R[:, :1, :])
x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
if print_stats:
print("Run without noise")
if benchmark: # return curried call
torch.cuda.synchronize()
def result():
self._run_and_print(x_world, y, R, T, False, skip_q)
torch.cuda.synchronize()
return result
self._run_and_print(x_world, y, R, T, print_stats, skip_q, check_output=True)
# in the noisy case, there are no guarantees, so we check it doesn't crash
if print_stats:
print("Run with noise")
x_world += torch.randn_like(x_world) * 0.1
self._run_and_print(x_world, y, R, T, print_stats, skip_q)
def case_with_gaussian_points(
self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False
):
return self._testcase_from_2d(
torch.randn((num_pts, 2)).cuda() / 3.0,
print_stats=print_stats,
benchmark=benchmark,
skip_q=skip_q,
)
def test_perspective_n_points(self, print_stats=False):
if print_stats:
print("RUN ON A DENSE GRID")
u = torch.linspace(-1.0, 1.0, 20)
v = torch.linspace(-1.0, 1.0, 15)
for skip_q in [False, True]:
self._testcase_from_2d(
torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q
)
for num_pts in range(6, 3, -1):
for skip_q in [False, True]:
if print_stats:
print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}")
self.case_with_gaussian_points(
num_pts=num_pts,
print_stats=print_stats,
benchmark=False,
skip_q=skip_q,
)