Efficient PnP.

Summary:
Efficient PnP algorithm that fits 2D-to-3D correspondences under the perspective assumption.

Benchmarked both null-space variants and picked one: SVD takes 7 times longer in the 100K-point case.
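
A rough usage sketch of the new API (the module path matches the test imports below; the random inputs are illustrative only, so the recovered pose is meaningless):

import torch
from pytorch3d.ops import perspective_n_points

x = torch.randn(2, 10, 3)  # 3D points: (minibatch, num_points, 3)
y = torch.randn(2, 10, 2)  # corresponding uncalibrated 2D points
sol = perspective_n_points.efficient_pnp(x, y)  # weights=None -> equal weights
# sol is an EpnpSolution namedtuple of x_cam, R, T, err_2d, err_3d:
print(sol.R.shape, sol.T.shape, sol.err_2d.shape)  # (2, 3, 3) (2, 3) (2,)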

Reviewed By: davnov134, gkioxari

Differential Revision: D20095754

fbshipit-source-id: 2b4519729630e6373820880272f674829eaed073
Roman Shapovalov 2020-04-17 07:42:16 -07:00 committed by Facebook GitHub Bot
parent 7788a38050
commit 04d8bf6a43
6 changed files with 680 additions and 12 deletions

View File

@@ -0,0 +1,401 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This file contains the Efficient PnP algorithm for the Perspective-n-Points problem.
It finds a camera position (defined by rotation `R` and translation `T`) that
minimizes the re-projection error between the given 3D points `x` and
the corresponding uncalibrated 2D points `y`.
"""
import warnings
from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from pytorch3d.ops import points_alignment, utils as oputil
class EpnpSolution(NamedTuple):
x_cam: torch.Tensor
R: torch.Tensor
T: torch.Tensor
err_2d: torch.Tensor
err_3d: torch.Tensor
def _define_control_points(x, weight, storage_opts=None):
"""
Returns the control points that define the barycentric coordinate frame.
Args:
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
weight: Batch of non-negative weights of
shape `(minibatch, num_points)`. `None` means equal weights.
storage_opts: dict of keyword arguments to the tensor constructor.
"""
storage_opts = storage_opts or {}
x_mean = oputil.wmean(x, weight)
x_std = oputil.wmean((x - x_mean) ** 2, weight) ** 0.5
c_world = F.pad(torch.eye(3, **storage_opts), (0, 0, 0, 1), value=0.0).expand_as(
x[:, :4, :]
)
return c_world * x_std + x_mean
def _compute_alphas(x, c_world):
"""
Computes barycentric coordinates of x in the frame c_world.
Args:
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
c_world: control points in world coordinates.
"""
x = F.pad(x, (0, 1), value=1.0)
c = F.pad(c_world, (0, 1), value=1.0)
return torch.matmul(x, torch.inverse(c)) # B x N x 4
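# Note: the four control points span an affine frame (almost surely, for the
# control points built above), so the alphas reconstruct the input exactly:
#     torch.matmul(alphas, c_world) == x   (up to numerical precision)
# and, thanks to the appended homogeneous row, each row of alphas sums to 1.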
def _build_M(y, alphas, weight):
""" Returns the matrix defining the reprojection equations.
Args:
y: projected points in camera coordinates of size B x N x 2
alphas: barycentric coordinates of size B x N x 4
weight: Batch of non-negative weights of
shape `(minibatch, num_points)`. `None` means equal weights.
"""
bs, n, _ = y.size()
# prepend a column of v's to t
def prepad(t, v):
return F.pad(t, (1, 0), value=v)
# outer left-multiply by alphas
def lm_alphas(t):
return torch.matmul(alphas[..., None], t).reshape(bs, n, 12)
M = torch.cat(
(
lm_alphas(
prepad(prepad(-y[:, :, 0, None, None], 0.0), 1.0)
), # u constraints
lm_alphas(
prepad(prepad(-y[:, :, 1, None, None], 1.0), 0.0)
), # v constraints
),
dim=-1,
).reshape(bs, -1, 12)
if weight is not None:
M = M * weight.repeat(1, 2)[:, :, None]
return M
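# Structure of M: each point contributes two rows. With image coordinates
# y_i = (u_i, v_i) and camera-frame control points c_j = (x_j, y_j, z_j),
# the rows encode the reprojection constraints
#     sum_j alpha_ij * (x_j - u_i * z_j) = 0    (u-constraint)
#     sum_j alpha_ij * (y_j - v_i * z_j) = 0    (v-constraint)
# i.e. the coefficient triples built by prepad() are [1, 0, -u_i] and
# [0, 1, -v_i], scaled by alpha_ij and flattened into 12 columns.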
def _null_space(m, kernel_dim):
""" Finds the null space (kernel) basis of the matrix
Args:
m: the batch of input matrices, B x N x 12
kernel_dim: number of dimensions to approximate the kernel
Returns:
* a batch of null space basis vectors
of size B x 4 x 3 x kernel_dim
* a batch of spectral values where near-0s correspond to actual
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = torch.symeig(mTm, eigenvectors=True)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
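# Implementation note: the kernel is read off the eigendecomposition of the
# 12 x 12 matrix m^T m; torch.symeig returns eigenvalues in ascending order,
# so the first kernel_dim eigenvectors pair with the smallest spectral values.
# Per the summary above, this is ~7x faster than an SVD of m on 100K points.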
def _reproj_error(y_hat, y, weight):
""" Projects estimated 3D points and computes the reprojection error
Args:
y_hat: a batch of predicted 2D points in homogeneous coordinates
y: a batch of ground-truth 2D points
weight: Batch of non-negative weights of
shape `(minibatch, num_points)`. `None` means equal weights.
Returns:
Optionally weighted mean of the per-point L2 distances between y and y_hat.
"""
y_hat = y_hat / y_hat[..., 2:]
dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5
return oputil.wmean(dist, weight)[:, 0, 0]
def _algebraic_error(x_w_rotated, x_cam, weight):
""" Computes the residual of Umeyama in 3D.
Args:
x_w_rotated: The given 3D points rotated with the predicted camera.
x_cam: the lifted 2D points y
weight: Batch of non-negative weights of
shape `(minibatch, num_points)`. `None` means equal weights.
Returns:
Optionally weighted MSE of difference between x_w_rotated and x_cam.
"""
dist = ((x_w_rotated - x_cam) ** 2).sum(dim=-1, keepdim=True)
return oputil.wmean(dist, weight)[:, 0, 0]
def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-9):
""" Given a solution, adjusts the scale and flip
Args:
c_cam: control points in camera coordinates
alphas: barycentric coordinates of the points
x_world: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
weight: Batch of non-negative weights of
shape `(minibatch, num_points)`. `None` means equal weights.
eps: epsilon to threshold negative `z` values
"""
# position of reference points in camera coordinates
x_cam = torch.matmul(alphas, c_cam)
x_cam = x_cam * (1.0 - 2.0 * (oputil.wmean(x_cam[..., 2:], weight) < 0).float())
if torch.any(x_cam[..., 2:] < -eps):
neg_rate = oputil.wmean((x_cam[..., 2:] < 0).float(), weight, dim=(0, 1)).item()
warnings.warn("\nEPnP: %2.2f%% points have z<0." % (neg_rate * 100.0))
R, T, s = points_alignment.corresponding_points_alignment(
x_world, x_cam, weight, estimate_scale=True
)
x_cam = x_cam / s[:, None, None]
T = T / s[:, None]
x_w_rotated = torch.matmul(x_world, R) + T[:, None, :]
err_2d = _reproj_error(x_w_rotated, y, weight)
err_3d = _algebraic_error(x_w_rotated, x_cam, weight)
return EpnpSolution(x_cam, R, T, err_2d, err_3d)
def _gen_pairs(input, dim=-2, reducer=lambda l, r: ((l - r) ** 2).sum(dim=-1)):
""" Generates all pairs of different rows and then applies the reducer
Args:
input: a tensor
dim: a dimension to generate pairs across
reducer: a function of generated pair of rows to apply (beyond just concat)
Returns:
with the default args, an A x B x C input yields an A x (B choose 2) output
"""
n = input.size()[dim]
range = torch.arange(n)
idx = torch.combinations(range).to(input).long()
left = input.index_select(dim, idx[:, 0])
right = input.index_select(dim, idx[:, 1])
return reducer(left, right)
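# Example (default arguments): for a B x 4 x 3 batch of control points,
# torch.combinations(arange(4)) yields the 6 index pairs (i, j) with i < j,
# so the output is the B x 6 tensor of squared pairwise distances.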
def _kernel_vec_distances(v):
""" Computes the coefficients for linearisation of the quadratic system
to match all pairwise distances between 4 control points (dim=1).
The last dimension corresponds to the coefficients for quadratic terms
Bij = Bi * Bj, where Bi and Bj correspond to kernel vectors.
Args:
v: tensor of B x 4 x 3 x D, where D is dim(kernel), usually 4
Returns:
a tensor of B x 6 x [(D choose 2) + D];
for D=4, the last dim means [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34].
"""
dv = _gen_pairs(v, dim=-3, reducer=lambda l, r: l - r) # B x 6 x 3 x D
# we should take dot-product of all (i,j), i < j, with coeff 2
rows_2ij = 2.0 * _gen_pairs(dv, dim=-1, reducer=lambda l, r: (l * r).sum(dim=-2))
# this should produce B x 6 x (D choose 2) tensor
# we should take dot-product of all (i,i)
rows_ii = (dv ** 2).sum(dim=-2)
# this should produce B x 6 x D tensor
return torch.cat((rows_ii, rows_2ij), dim=-1)
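# Shape check: for the usual D = 4, the output concatenates D = 4 squared
# columns (B11 B22 B33 B44) with (D choose 2) = 6 cross columns (2 * Bij,
# i < j), giving the B x 6 x 10 tensor documented above.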
def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx):
""" Solves an over-determined linear system for selected LHS columns.
A batched version of `torch.lstsq`.
Args:
rhs: right-hand side vectors
lhs: left-hand side matrices
lhs_col_idx: a slice of columns in lhs
Returns:
a least-squares solution for lhs * X = rhs
"""
lhs = lhs.index_select(-1, torch.tensor(lhs_col_idx, device=lhs.device).long())
return torch.matmul(torch.pinverse(lhs), rhs[:, :, None])
def _find_null_space_coords_1(kernel_dsts, cw_dst):
""" Solves case 1 from the paper [1]; solve for 4 coefficients:
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
^ ^ ^ ^
Args:
kernel_dsts: distances between kernel vectors
cw_dst: distances between control points
Returns:
coefficients to weight kernel vectors
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6])
beta = beta * beta[:, :1, :].sign()
return beta / (beta[:, :1, :] ** 0.5)
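# Note: the solved vector is (B11, B12, B13, B14); flipping signs so that
# B11 > 0 and dividing by sqrt(B11) recovers (beta_1, ..., beta_4), since
# B1j / sqrt(B11) = beta_j. The remaining global scale (and flip) is fixed
# later by _compute_norm_sign_scaling_factor.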
def _find_null_space_coords_2(kernel_dsts, cw_dst):
""" Solves case 2 from the paper; solve for 3 coefficients:
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
^ ^ ^
Args:
kernel_dsts: distances between kernel vectors
cw_dst: distances between control points
Returns:
coefficients to weight kernel vectors
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1])
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
).float()
return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1)
def _find_null_space_coords_3(kernel_dsts, cw_dst):
""" Solves case 3 from the paper; solve for 5 coefficients:
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
^ ^ ^ ^ ^
Args:
kernel_dsts: distances between kernel vectors
cw_dst: distances between control points
Returns:
coefficients to weight kernel vectors
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7])
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
).float()
coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :]
return torch.cat(
(coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1
)
def efficient_pnp(
x: torch.Tensor,
y: torch.Tensor,
weights: Optional[torch.Tensor] = None,
skip_quadratic_eq: bool = False,
) -> EpnpSolution:
"""
Implements Efficient PnP algorithm [1] for Perspective-n-Points problem:
finds a camera position (defined by rotation `R` and translation `T`) that
minimizes re-projection error between the given 3D points `x` and
the corresponding uncalibrated 2D points `y`, i.e. solves
`y[i] = Proj(x[i] R[i] + T[i])`
in the least-squares sense, where `i` are indices within the batch, and
`Proj` is the perspective projection operator: `Proj([x y z]) = [x/z y/z]`.
In the noiseless case, 4 points are enough to find the solution as long
as they are not co-planar.
Args:
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
weights: Batch of non-negative weights of
shape `(minibatch, num_points)`. `None` means equal weights.
skip_quadratic_eq: If True, assumes the solution space for the
linear system is one-dimensional, i.e. takes the scaled eigenvector
that corresponds to the smallest eigenvalue as a solution.
If False, finds the candidate coordinates in the potentially
4D null space by approximately solving the systems of quadratic
equations. The best candidate is chosen by examining the 2D
re-projection error. While this option finds a better solution,
especially when the number of points is small or perspective
distortions are low (the points are far away), it may be more
difficult to back-propagate through.
Returns:
`EpnpSolution` namedtuple containing elements:
**x_cam**: Batch of transformed points `x` that is used to find
the camera parameters, of shape `(minibatch, num_points, 3)`.
In the general (noisy) case, they are not exactly equal to
`x[i] R[i] + T[i]` but are some affine transform of `x[i]`s.
**R**: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
**T**: Batch of translation vectors of shape `(minibatch, 3)`.
**err_2d**: Batch of mean 2D re-projection errors of shape
`(minibatch,)`. Specifically, if `yhat` is the re-projection for
the `i`-th batch element, it returns the (weighted) mean of
`norm(yhat_j - y_j)` over points `j`, where `norm` denotes the L2 norm.
**err_3d**: Batch of mean algebraic errors of shape `(minibatch,)`.
Specifically, those are squared distances between `x_world` and
estimated points on the rays defined by `y`.
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
EPnP: An Accurate O(n) solution to the PnP problem.
International Journal of Computer Vision.
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
"""
# define control points in a world coordinate system (centered on the 3d
# points centroid); 4 x 3
# TODO: more stable when initialised with the center and eigenvectors!
c_world = _define_control_points(
x.detach(), weights, storage_opts={"dtype": x.dtype, "device": x.device}
)
# find the linear combination of the control points to represent the 3d points
alphas = _compute_alphas(x, c_world)
M = _build_M(y, alphas, weights)
# compute the null space (kernel) of M
kernel, spectrum = _null_space(M, 4)
c_world_distances = _gen_pairs(c_world)
kernel_dsts = _kernel_vec_distances(kernel)
betas = (
[]
if skip_quadratic_eq
else [
fnsc(kernel_dsts, c_world_distances)
for fnsc in [
_find_null_space_coords_1,
_find_null_space_coords_2,
_find_null_space_coords_3,
]
]
)
c_cam_variants = [kernel] + [
torch.matmul(kernel, beta[:, None, :, :]) for beta in betas
]
solutions = [
_compute_norm_sign_scaling_factor(c_cam[..., 0], alphas, x, y, weights)
for c_cam in c_cam_variants
]
sol_zipped = EpnpSolution(*(torch.stack(list(col)) for col in zip(*solutions)))
best = torch.argmin(sol_zipped.err_2d, dim=0)
def gather1d(source, idx):
# reduces dim=0 by picking, for each batch element, the slice
# given by the 1D tensor idx; in other words, a batched index_select.
return source.gather(
0,
idx.reshape(1, -1, *([1] * (len(source.shape) - 2))).expand_as(source[:1]),
)[0]
return EpnpSolution(*[gather1d(sol_col, best) for sol_col in sol_zipped])

View File

@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from test_perspective_n_points import TestPerspectiveNPoints
def bm_perspective_n_points() -> None:
case_grid = {
"batch_size": [1, 10, 100],
"num_pts": [100, 100000],
"skip_q": [False, True],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
test = TestPerspectiveNPoints()
benchmark(
test.case_with_gaussian_points,
"PerspectiveNPoints",
kwargs_list,
warmup_iters=1,
)
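# Note: the benchmarked cases allocate CUDA tensors (case_with_gaussian_points
# calls .cuda() on its inputs), so a GPU is required; bm_perspective_n_points()
# then times every setting in case_grid through fvcore's benchmark helper,
# which prints the collected timings.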

View File

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from copy import deepcopy

View File

@@ -1,12 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
-from typing import Optional
from typing import Callable, Optional, Union
import numpy as np
import torch
TensorOrArray = Union[torch.Tensor, np.ndarray]
class TestCaseMixin(unittest.TestCase):
def assertSeparate(self, tensor1, tensor2) -> None:
"""
@@ -28,10 +31,11 @@ class TestCaseMixin(unittest.TestCase):
ptrs = [i.storage().data_ptr() for i in tensor_list]
self.assertCountEqual(ptrs, set(ptrs))
-def assertClose(
def assertNormsClose(
self,
-input,
-other,
input: TensorOrArray,
other: TensorOrArray,
norm_fn: Callable[[TensorOrArray], TensorOrArray],
*,
rtol: float = 1e-05,
atol: float = 1e-08,
@@ -39,7 +43,60 @@
msg: Optional[str] = None,
) -> None:
"""
-Verify that two tensors or arrays are the same shape and close.
Verifies that two tensors or arrays have the same shape and are close
given absolute and relative tolerance; raises AssertionError otherwise.
A custom norm function is applied before comparison. If no such
preprocessing is needed, pass `torch.abs` or, equivalently, call `assertClose`.
Args:
input, other: two tensors or two arrays.
norm_fn: the norm used for comparison; the assertion checks
`all(norm_fn(input - other) <= atol + rtol * norm_fn(other))`.
norm_fn is a tensor -> tensor function; the output has:
* all entries non-negative,
* shape defined by the input shape only.
rtol, atol, equal_nan: as for torch.allclose.
msg: message in case the assertion is violated.
Note:
Optional arguments here are all keyword-only, to avoid confusion
with msg arguments on other assert functions.
"""
self.assertEqual(np.shape(input), np.shape(other))
diff = norm_fn(input - other)
other_ = norm_fn(other)
# We want to generalise allclose(input, output), which is essentially
# all(diff <= atol + rtol * other)
# but with sophisticated handling of non-finite values.
# We work around that by calling allclose() with the following arguments:
# allclose(diff + other_, other_). This computes what we want because
# all(|diff + other_ - other_| <= atol + rtol * |other_|) ==
# all(|norm_fn(input - other)| <= atol + rtol * |norm_fn(other)|) ==
# all(norm_fn(input - other) <= atol + rtol * norm_fn(other)).
backend = torch if torch.is_tensor(input) else np
close = backend.allclose(
diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan
)
self.assertTrue(close, msg)
def assertClose(
self,
input: TensorOrArray,
other: TensorOrArray,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
msg: Optional[str] = None,
) -> None:
"""
Verifies that two tensors or arrays have the same shape and are close
given absolute and relative tolerance, i.e. checks
`all(|input - other| <= atol + rtol * |other|)`;
raises AssertionError otherwise.
Args:
input, other: two tensors or two arrays.
rtol, atol, equal_nan: as for torch.allclose.
@@ -51,10 +108,9 @@
self.assertEqual(np.shape(input), np.shape(other))
-if torch.is_tensor(input):
-close = torch.allclose(
-input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
-)
-else:
-close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
backend = torch if torch.is_tensor(input) else np
close = backend.allclose(
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
)
self.assertTrue(close, msg)

View File

@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_all_close(self):
device = torch.device("cuda:0")
n_points = 20
noise_std = 1e-3
msg = "tratata"
# test absolute tolerance
x = torch.rand(n_points, 3, device=device)
x_noise = x + noise_std * torch.rand(n_points, 3, device=device)
assert torch.allclose(x, x_noise, atol=10 * noise_std)
assert not torch.allclose(x, x_noise, atol=0.1 * noise_std)
self.assertClose(x, x_noise, atol=10 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(x, x_noise, atol=0.1 * noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
# test numpy
def to_np(t):
return t.data.cpu().numpy()
self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
# test relative tolerance
assert torch.allclose(x, x_noise, rtol=100 * noise_std)
assert not torch.allclose(x, x_noise, rtol=noise_std)
self.assertClose(x, x_noise, rtol=100 * noise_std)
with self.assertRaises(AssertionError) as context:
self.assertClose(x, x_noise, rtol=noise_std, msg=msg)
self.assertTrue(msg in str(context.exception))
# test norm aggregation
# if one of the spatial dimensions is small, norm aggregation helps
x_noise[:, 0] = x_noise[:, 0] - x[:, 0]
x[:, 0] = 0.0
assert not torch.allclose(x, x_noise, rtol=100 * noise_std)
self.assertNormsClose(
x, x_noise, rtol=100 * noise_std, norm_fn=lambda t: t.norm(dim=-1)
)

View File

@@ -0,0 +1,131 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import perspective_n_points
from pytorch3d.transforms import rotation_conversions
def reproj_error(x_world, y, R, T, weight=None):
# applies the affine transform, projects, and computes the reprojection error
y_hat = torch.matmul(x_world, R) + T[:, None, :]
y_hat = y_hat / y_hat[..., 2:]
if weight is None:
weight = y.new_ones((1, 1))
return (((weight[:, :, None] * (y - y_hat[..., :2])) ** 2).sum(dim=-1) ** 0.5).mean(
dim=-1
)
class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
sol = perspective_n_points.efficient_pnp(
x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
)
err_2d = reproj_error(x_world, y, sol.R, sol.T)
R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
R_quat = rotation_conversions.matrix_to_quaternion(R)
num_pts = x_world.shape[-2]
# quadratic part is more stable with fewer points
num_pts_thresh = 5 if skip_q else 4
if check_output and num_pts > num_pts_thresh:
assert_msg = (
f"test_perspective_n_points assertion failure for "
f"n_points={num_pts}, "
f"skip_quadratic={skip_q}, "
f"no noise."
)
self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)
def norm_fn(t):
return t.norm(dim=-1)
self.assertNormsClose(
T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg
)
self.assertNormsClose(
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
)
if print_stats:
torch.set_printoptions(precision=5, sci_mode=False)
for err_2d, err_3d, R_gt, T_gt in zip(
sol.err_2d,
sol.err_3d,
torch.cat((sol.R, R), dim=-1),
torch.stack((sol.T, T[:, 0, :]), dim=-1),
):
print("2D Error: %1.4f" % err_2d.item())
print("3D Error: %1.4f" % err_3d.item())
print("R_hat | R_gt\n", R_gt)
print("T_hat | T_gt\n", T_gt)
def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False):
x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
x_cam[:, :2] *= x_cam[:, 2:] # unproject
R = rotation_conversions.random_rotations(16).to(y)
T = torch.randn_like(R[:, :1, :])
x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
if print_stats:
print("Run without noise")
if benchmark: # return curried call
torch.cuda.synchronize()
def result():
self._run_and_print(x_world, y, R, T, False, skip_q)
torch.cuda.synchronize()
return result
self._run_and_print(x_world, y, R, T, print_stats, skip_q, check_output=True)
# in the noisy case, there are no guarantees, so we check it doesn't crash
if print_stats:
print("Run with noise")
x_world += torch.randn_like(x_world) * 0.1
self._run_and_print(x_world, y, R, T, print_stats, skip_q)
def case_with_gaussian_points(
self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False
):
return self._testcase_from_2d(
torch.randn((num_pts, 2)).cuda() / 3.0,
print_stats=print_stats,
benchmark=benchmark,
skip_q=skip_q,
)
def test_perspective_n_points(self, print_stats=False):
if print_stats:
print("RUN ON A DENSE GRID")
u = torch.linspace(-1.0, 1.0, 20)
v = torch.linspace(-1.0, 1.0, 15)
for skip_q in [False, True]:
self._testcase_from_2d(
torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q
)
for num_pts in range(6, 3, -1):
for skip_q in [False, True]:
if print_stats:
print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}")
self.case_with_gaussian_points(
num_pts=num_pts,
print_stats=print_stats,
benchmark=False,
skip_q=skip_q,
)