mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Efficient PnP.
Summary: Efficient PnP algorithm to fit 2D to 3D correspondences under perspective assumption. Benchmarked both variants of nullspace and pick one; SVD takes 7 times longer in the 100K points case. Reviewed By: davnov134, gkioxari Differential Revision: D20095754 fbshipit-source-id: 2b4519729630e6373820880272f674829eaed073
This commit is contained in:
parent
7788a38050
commit
04d8bf6a43
401
pytorch3d/ops/perspective_n_points.py
Normal file
401
pytorch3d/ops/perspective_n_points.py
Normal file
@ -0,0 +1,401 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
"""
|
||||
This file contains Efficient PnP algorithm for Perspective-n-Points problem.
|
||||
It finds a camera position (defined by rotation `R` and translation `T`) that
|
||||
minimises re-projection error between the given 3D points `x` and
|
||||
the corresponding uncalibrated 2D points `y`.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.ops import points_alignment, utils as oputil
|
||||
|
||||
|
||||
class EpnpSolution(NamedTuple):
|
||||
x_cam: torch.Tensor
|
||||
R: torch.Tensor
|
||||
T: torch.Tensor
|
||||
err_2d: torch.Tensor
|
||||
err_3d: torch.Tensor
|
||||
|
||||
|
||||
def _define_control_points(x, weight, storage_opts=None):
|
||||
"""
|
||||
Returns control points that define barycentric coordinates
|
||||
Args:
|
||||
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
|
||||
weight: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)`. `None` means equal weights.
|
||||
storage_opts: dict of keyword arguments to the tensor constructor.
|
||||
"""
|
||||
storage_opts = storage_opts or {}
|
||||
x_mean = oputil.wmean(x, weight)
|
||||
x_std = oputil.wmean((x - x_mean) ** 2, weight) ** 0.5
|
||||
c_world = F.pad(torch.eye(3, **storage_opts), (0, 0, 0, 1), value=0.0).expand_as(
|
||||
x[:, :4, :]
|
||||
)
|
||||
return c_world * x_std + x_mean
|
||||
|
||||
|
||||
def _compute_alphas(x, c_world):
|
||||
"""
|
||||
Computes barycentric coordinates of x in the frame c_world.
|
||||
Args:
|
||||
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
|
||||
c_world: control points in world coordinates.
|
||||
"""
|
||||
x = F.pad(x, (0, 1), value=1.0)
|
||||
c = F.pad(c_world, (0, 1), value=1.0)
|
||||
return torch.matmul(x, torch.inverse(c)) # B x N x 4
|
||||
|
||||
|
||||
def _build_M(y, alphas, weight):
|
||||
""" Returns the matrix defining the reprojection equations.
|
||||
Args:
|
||||
y: projected points in camera coordinates of size B x N x 2
|
||||
alphas: barycentric coordinates of size B x N x 4
|
||||
weight: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)`. `None` means equal weights.
|
||||
"""
|
||||
bs, n, _ = y.size()
|
||||
|
||||
# prepend t with the column of v's
|
||||
def prepad(t, v):
|
||||
return F.pad(t, (1, 0), value=v)
|
||||
|
||||
# outer left-multiply by alphas
|
||||
def lm_alphas(t):
|
||||
return torch.matmul(alphas[..., None], t).reshape(bs, n, 12)
|
||||
|
||||
M = torch.cat(
|
||||
(
|
||||
lm_alphas(
|
||||
prepad(prepad(-y[:, :, 0, None, None], 0.0), 1.0)
|
||||
), # u constraints
|
||||
lm_alphas(
|
||||
prepad(prepad(-y[:, :, 1, None, None], 1.0), 0.0)
|
||||
), # v constraints
|
||||
),
|
||||
dim=-1,
|
||||
).reshape(bs, -1, 12)
|
||||
|
||||
if weight is not None:
|
||||
M = M * weight.repeat(1, 2)[:, :, None]
|
||||
|
||||
return M
|
||||
|
||||
|
||||
def _null_space(m, kernel_dim):
|
||||
""" Finds the null space (kernel) basis of the matrix
|
||||
Args:
|
||||
m: the batch of input matrices, B x N x 12
|
||||
kernel_dim: number of dimensions to approximate the kernel
|
||||
Returns:
|
||||
* a batch of null space basis vectors
|
||||
of size B x 4 x 3 x kernel_dim
|
||||
* a batch of spectral values where near-0s correspond to actual
|
||||
kernel vectors, of size B x kernel_dim
|
||||
"""
|
||||
mTm = torch.bmm(m.transpose(1, 2), m)
|
||||
s, v = torch.symeig(mTm, eigenvectors=True)
|
||||
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
|
||||
|
||||
|
||||
def _reproj_error(y_hat, y, weight):
|
||||
""" Projects estimated 3D points and computes the reprojection error
|
||||
Args:
|
||||
y_hat: a batch of predicted 2D points in homogeneous coordinates
|
||||
y: a batch of ground-truth 2D points
|
||||
weight: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)`. `None` means equal weights.
|
||||
Returns:
|
||||
Optionally weighted RMSE of difference between y and y_hat.
|
||||
"""
|
||||
y_hat = y_hat / y_hat[..., 2:]
|
||||
dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5
|
||||
return oputil.wmean(dist, weight)[:, 0, 0]
|
||||
|
||||
|
||||
def _algebraic_error(x_w_rotated, x_cam, weight):
|
||||
""" Computes the residual of Umeyama in 3D.
|
||||
Args:
|
||||
x_w_rotated: The given 3D points rotated with the predicted camera.
|
||||
x_cam: the lifted 2D points y
|
||||
weight: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)`. `None` means equal weights.
|
||||
Returns:
|
||||
Optionally weighted MSE of difference between x_w_rotated and x_cam.
|
||||
"""
|
||||
dist = ((x_w_rotated - x_cam) ** 2).sum(dim=-1, keepdim=True)
|
||||
return oputil.wmean(dist, weight)[:, 0, 0]
|
||||
|
||||
|
||||
def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-9):
|
||||
""" Given a solution, adjusts the scale and flip
|
||||
Args:
|
||||
c_cam: control points in camera coordinates
|
||||
alphas: barycentric coordinates of the points
|
||||
x_world: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
|
||||
y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
|
||||
weights: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)`. `None` means equal weights.
|
||||
eps: epsilon to threshold negative `z` values
|
||||
"""
|
||||
# position of reference points in camera coordinates
|
||||
x_cam = torch.matmul(alphas, c_cam)
|
||||
|
||||
x_cam = x_cam * (1.0 - 2.0 * (oputil.wmean(x_cam[..., 2:], weight) < 0).float())
|
||||
if torch.any(x_cam[..., 2:] < -eps):
|
||||
neg_rate = oputil.wmean((x_cam[..., 2:] < 0).float(), weight, dim=(0, 1)).item()
|
||||
warnings.warn("\nEPnP: %2.2f%% points have z<0." % (neg_rate * 100.0))
|
||||
|
||||
R, T, s = points_alignment.corresponding_points_alignment(
|
||||
x_world, x_cam, weight, estimate_scale=True
|
||||
)
|
||||
x_cam = x_cam / s[:, None, None]
|
||||
T = T / s[:, None]
|
||||
x_w_rotated = torch.matmul(x_world, R) + T[:, None, :]
|
||||
err_2d = _reproj_error(x_w_rotated, y, weight)
|
||||
err_3d = _algebraic_error(x_w_rotated, x_cam, weight)
|
||||
|
||||
return EpnpSolution(x_cam, R, T, err_2d, err_3d)
|
||||
|
||||
|
||||
def _gen_pairs(input, dim=-2, reducer=lambda l, r: ((l - r) ** 2).sum(dim=-1)):
|
||||
""" Generates all pairs of different rows and then applies the reducer
|
||||
Args:
|
||||
input: a tensor
|
||||
dim: a dimension to generate pairs across
|
||||
reducer: a function of generated pair of rows to apply (beyond just concat)
|
||||
Returns:
|
||||
for default args, for A x B x C input, will output A x (B choose 2)
|
||||
"""
|
||||
n = input.size()[dim]
|
||||
range = torch.arange(n)
|
||||
idx = torch.combinations(range).to(input).long()
|
||||
left = input.index_select(dim, idx[:, 0])
|
||||
right = input.index_select(dim, idx[:, 1])
|
||||
return reducer(left, right)
|
||||
|
||||
|
||||
def _kernel_vec_distances(v):
|
||||
""" Computes the coefficients for linearisation of the quadratic system
|
||||
to match all pairwise distances between 4 control points (dim=1).
|
||||
The last dimension corresponds to the coefficients for quadratic terms
|
||||
Bij = Bi * Bj, where Bi and Bj correspond to kernel vectors.
|
||||
Arg:
|
||||
v: tensor of B x 4 x 3 x D, where D is dim(kernel), usually 4
|
||||
Returns:
|
||||
a tensor of B x 6 x [(D choose 2) + D];
|
||||
for D=4, the last dim means [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34].
|
||||
"""
|
||||
dv = _gen_pairs(v, dim=-3, reducer=lambda l, r: l - r) # B x 6 x 3 x D
|
||||
|
||||
# we should take dot-product of all (i,j), i < j, with coeff 2
|
||||
rows_2ij = 2.0 * _gen_pairs(dv, dim=-1, reducer=lambda l, r: (l * r).sum(dim=-2))
|
||||
# this should produce B x 6 x (D choose 2) tensor
|
||||
|
||||
# we should take dot-product of all (i,i)
|
||||
rows_ii = (dv ** 2).sum(dim=-2)
|
||||
# this should produce B x 6 x D tensor
|
||||
|
||||
return torch.cat((rows_ii, rows_2ij), dim=-1)
|
||||
|
||||
|
||||
def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx):
|
||||
""" Solves an over-determined linear system for selected LHS columns.
|
||||
A batched version of `torch.lstsq`.
|
||||
Args:
|
||||
rhs: right-hand side vectors
|
||||
lhs: left-hand side matrices
|
||||
lhs_col_idx: a slice of columns in lhs
|
||||
Returns:
|
||||
a least-squares solution for lhs * X = rhs
|
||||
"""
|
||||
lhs = lhs.index_select(-1, torch.tensor(lhs_col_idx, device=lhs.device).long())
|
||||
return torch.matmul(torch.pinverse(lhs), rhs[:, :, None])
|
||||
|
||||
|
||||
def _find_null_space_coords_1(kernel_dsts, cw_dst):
|
||||
""" Solves case 1 from the paper [1]; solve for 4 coefficients:
|
||||
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
|
||||
^ ^ ^ ^
|
||||
Args:
|
||||
kernel_dsts: distances between kernel vectors
|
||||
cw_dst: distances between control points
|
||||
Returns:
|
||||
coefficients to weight kernel vectors
|
||||
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
|
||||
EPnP: An Accurate O(n) solution to the PnP problem.
|
||||
International Journal of Computer Vision.
|
||||
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
|
||||
"""
|
||||
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6])
|
||||
|
||||
beta = beta * beta[:, :1, :].sign()
|
||||
return beta / (beta[:, :1, :] ** 0.5)
|
||||
|
||||
|
||||
def _find_null_space_coords_2(kernel_dsts, cw_dst):
|
||||
""" Solves case 2 from the paper; solve for 3 coefficients:
|
||||
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
|
||||
^ ^ ^
|
||||
Args:
|
||||
kernel_dsts: distances between kernel vectors
|
||||
cw_dst: distances between control points
|
||||
Returns:
|
||||
coefficients to weight kernel vectors
|
||||
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
|
||||
EPnP: An Accurate O(n) solution to the PnP problem.
|
||||
International Journal of Computer Vision.
|
||||
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
|
||||
"""
|
||||
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1])
|
||||
|
||||
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
|
||||
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
|
||||
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
|
||||
).float()
|
||||
|
||||
return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1)
|
||||
|
||||
|
||||
def _find_null_space_coords_3(kernel_dsts, cw_dst):
|
||||
""" Solves case 3 from the paper; solve for 5 coefficients:
|
||||
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
|
||||
^ ^ ^ ^ ^
|
||||
Args:
|
||||
kernel_dsts: distances between kernel vectors
|
||||
cw_dst: distances between control points
|
||||
Returns:
|
||||
coefficients to weight kernel vectors
|
||||
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
|
||||
EPnP: An Accurate O(n) solution to the PnP problem.
|
||||
International Journal of Computer Vision.
|
||||
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
|
||||
"""
|
||||
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7])
|
||||
|
||||
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
|
||||
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
|
||||
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
|
||||
).float()
|
||||
coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :]
|
||||
|
||||
return torch.cat(
|
||||
(coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1
|
||||
)
|
||||
|
||||
|
||||
def efficient_pnp(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
weights: Optional[torch.Tensor] = None,
|
||||
skip_quadratic_eq: bool = False,
|
||||
) -> EpnpSolution:
|
||||
"""
|
||||
Implements Efficient PnP algorithm [1] for Perspective-n-Points problem:
|
||||
finds a camera position (defined by rotation `R` and translation `T`) that
|
||||
minimizes re-projection error between the given 3D points `x` and
|
||||
the corresponding uncalibrated 2D points `y`, i.e. solves
|
||||
|
||||
`y[i] = Proj(x[i] R[i] + T[i])`
|
||||
|
||||
in the least-squares sense, where `i` are indices within the batch, and
|
||||
`Proj` is the perspective projection operator: `Proj([x y z]) = [x/z y/z]`.
|
||||
In the noise-less case, 4 points are enough to find the solution as long
|
||||
as they are not co-planar.
|
||||
|
||||
Args:
|
||||
x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
|
||||
y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
|
||||
weights: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)`. `None` means equal weights.
|
||||
skip_quadratic_eq: If True, assumes the solution space for the
|
||||
linear system is one-dimensional, i.e. takes the scaled eigenvector
|
||||
that corresponds to the smallest eigenvalue as a solution.
|
||||
If False, finds the candidate coordinates in the potentially
|
||||
4D null space by approximately solving the systems of quadratic
|
||||
equations. The best candidate is chosen by examining the 2D
|
||||
re-projection error. While this option finds a better solution,
|
||||
especially when the number of points is small or perspective
|
||||
distortions are low (the points are far away), it may be more
|
||||
difficult to back-propagate through.
|
||||
|
||||
Returns:
|
||||
`EpnpSolution` namedtuple containing elements:
|
||||
**x_cam**: Batch of transformed points `x` that is used to find
|
||||
the camera parameters, of shape `(minibatch, num_points, 3)`.
|
||||
In the general (noisy) case, they are not exactly equal to
|
||||
`x[i] R[i] + T[i]` but are some affine transform of `x[i]`s.
|
||||
**R**: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
|
||||
**T**: Batch of translation vectors of shape `(minibatch, 3)`.
|
||||
**err_2d**: Batch of mean 2D re-projection errors of shape
|
||||
`(minibatch,)`. Specifically, if `yhat` is the re-projection for
|
||||
the `i`-th batch element, it returns `sum_j norm(yhat_j - y_j)`
|
||||
where `j` iterates over points and `norm` denotes the L2 norm.
|
||||
**err_3d**: Batch of mean algebraic errors of shape `(minibatch,)`.
|
||||
Specifically, those are squared distances between `x_world` and
|
||||
estimated points on the rays defined by `y`.
|
||||
|
||||
[1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
|
||||
EPnP: An Accurate O(n) solution to the PnP problem.
|
||||
International Journal of Computer Vision.
|
||||
https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
|
||||
"""
|
||||
# define control points in a world coordinate system (centered on the 3d
|
||||
# points centroid); 4 x 3
|
||||
# TODO: more stable when initialised with the center and eigenvectors!
|
||||
c_world = _define_control_points(
|
||||
x.detach(), weights, storage_opts={"dtype": x.dtype, "device": x.device}
|
||||
)
|
||||
|
||||
# find the linear combination of the control points to represent the 3d points
|
||||
alphas = _compute_alphas(x, c_world)
|
||||
|
||||
M = _build_M(y, alphas, weights)
|
||||
|
||||
# Compute kernel M
|
||||
kernel, spectrum = _null_space(M, 4)
|
||||
|
||||
c_world_distances = _gen_pairs(c_world)
|
||||
kernel_dsts = _kernel_vec_distances(kernel)
|
||||
|
||||
betas = (
|
||||
[]
|
||||
if skip_quadratic_eq
|
||||
else [
|
||||
fnsc(kernel_dsts, c_world_distances)
|
||||
for fnsc in [
|
||||
_find_null_space_coords_1,
|
||||
_find_null_space_coords_2,
|
||||
_find_null_space_coords_3,
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
c_cam_variants = [kernel] + [
|
||||
torch.matmul(kernel, beta[:, None, :, :]) for beta in betas
|
||||
]
|
||||
|
||||
solutions = [
|
||||
_compute_norm_sign_scaling_factor(c_cam[..., 0], alphas, x, y, weights)
|
||||
for c_cam in c_cam_variants
|
||||
]
|
||||
|
||||
sol_zipped = EpnpSolution(*(torch.stack(list(col)) for col in zip(*solutions)))
|
||||
best = torch.argmin(sol_zipped.err_2d, dim=0)
|
||||
|
||||
def gather1d(source, idx):
|
||||
# reduces the dim=1 by picking the slices in a 1D tensor idx
|
||||
# in other words, it is batched index_select.
|
||||
return source.gather(
|
||||
0,
|
||||
idx.reshape(1, -1, *([1] * (len(source.shape) - 2))).expand_as(source[:1]),
|
||||
)[0]
|
||||
|
||||
return EpnpSolution(*[gather1d(sol_col, best) for sol_col in sol_zipped])
|
25
tests/bm_perspective_n_points.py
Normal file
25
tests/bm_perspective_n_points.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import itertools
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_perspective_n_points import TestPerspectiveNPoints
|
||||
|
||||
|
||||
def bm_perspective_n_points() -> None:
|
||||
case_grid = {
|
||||
"batch_size": [1, 10, 100],
|
||||
"num_pts": [100, 100000],
|
||||
"skip_q": [False, True],
|
||||
}
|
||||
|
||||
test_cases = itertools.product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||
|
||||
test = TestPerspectiveNPoints()
|
||||
benchmark(
|
||||
test.case_with_gaussian_points,
|
||||
"PerspectiveNPoints",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from copy import deepcopy
|
||||
|
@ -1,12 +1,15 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
TensorOrArray = Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class TestCaseMixin(unittest.TestCase):
|
||||
def assertSeparate(self, tensor1, tensor2) -> None:
|
||||
"""
|
||||
@ -28,10 +31,11 @@ class TestCaseMixin(unittest.TestCase):
|
||||
ptrs = [i.storage().data_ptr() for i in tensor_list]
|
||||
self.assertCountEqual(ptrs, set(ptrs))
|
||||
|
||||
def assertClose(
|
||||
def assertNormsClose(
|
||||
self,
|
||||
input,
|
||||
other,
|
||||
input: TensorOrArray,
|
||||
other: TensorOrArray,
|
||||
norm_fn: Callable[[TensorOrArray], TensorOrArray],
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
@ -39,7 +43,60 @@ class TestCaseMixin(unittest.TestCase):
|
||||
msg: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that two tensors or arrays are the same shape and close.
|
||||
Verifies that two tensors or arrays have the same shape and are close
|
||||
given absolute and relative tolerance; raises AssertionError otherwise.
|
||||
A custom norm function is computed before comparison. If no such pre-
|
||||
processing needed, pass `torch.abs` or, equivalently, call `assertClose`.
|
||||
Args:
|
||||
input, other: two tensors or two arrays.
|
||||
norm_fn: The function evaluates
|
||||
`all(norm_fn(input - other) <= atol + rtol * norm_fn(other))`.
|
||||
norm_fn is a tensor -> tensor function; the output has:
|
||||
* all entries non-negative,
|
||||
* shape defined by the input shape only.
|
||||
rtol, atol, equal_nan: as for torch.allclose.
|
||||
msg: message in case the assertion is violated.
|
||||
Note:
|
||||
Optional arguments here are all keyword-only, to avoid confusion
|
||||
with msg arguments on other assert functions.
|
||||
"""
|
||||
|
||||
self.assertEqual(np.shape(input), np.shape(other))
|
||||
|
||||
diff = norm_fn(input - other)
|
||||
other_ = norm_fn(other)
|
||||
|
||||
# We want to generalise allclose(input, output), which is essentially
|
||||
# all(diff <= atol + rtol * other)
|
||||
# but with a sophisticated handling non-finite values.
|
||||
# We work that around by calling allclose() with the following arguments:
|
||||
# allclose(diff + other_, other_). This computes what we want because
|
||||
# all(|diff + other_ - other_| <= atol + rtol * |other_|) ==
|
||||
# all(|norm_fn(input - other)| <= atol + rtol * |norm_fn(other)|) ==
|
||||
# all(norm_fn(input - other) <= atol + rtol * norm_fn(other)).
|
||||
|
||||
backend = torch if torch.is_tensor(input) else np
|
||||
close = backend.allclose(
|
||||
diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
|
||||
self.assertTrue(close, msg)
|
||||
|
||||
def assertClose(
|
||||
self,
|
||||
input: TensorOrArray,
|
||||
other: TensorOrArray,
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
msg: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Verifies that two tensors or arrays have the same shape and are close
|
||||
given absolute and relative tolerance, i.e. checks
|
||||
`all(|input - other| <= atol + rtol * |other|)`;
|
||||
raises AssertionError otherwise.
|
||||
Args:
|
||||
input, other: two tensors or two arrays.
|
||||
rtol, atol, equal_nan: as for torch.allclose.
|
||||
@ -51,10 +108,9 @@ class TestCaseMixin(unittest.TestCase):
|
||||
|
||||
self.assertEqual(np.shape(input), np.shape(other))
|
||||
|
||||
if torch.is_tensor(input):
|
||||
close = torch.allclose(
|
||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
else:
|
||||
close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
backend = torch if torch.is_tensor(input) else np
|
||||
close = backend.allclose(
|
||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
|
||||
self.assertTrue(close, msg)
|
||||
|
56
tests/test_common_testing.py
Normal file
56
tests/test_common_testing.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def test_all_close(self):
|
||||
device = torch.device("cuda:0")
|
||||
n_points = 20
|
||||
noise_std = 1e-3
|
||||
msg = "tratata"
|
||||
|
||||
# test absolute tolerance
|
||||
x = torch.rand(n_points, 3, device=device)
|
||||
x_noise = x + noise_std * torch.rand(n_points, 3, device=device)
|
||||
assert torch.allclose(x, x_noise, atol=10 * noise_std)
|
||||
assert not torch.allclose(x, x_noise, atol=0.1 * noise_std)
|
||||
self.assertClose(x, x_noise, atol=10 * noise_std)
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
self.assertClose(x, x_noise, atol=0.1 * noise_std, msg=msg)
|
||||
self.assertTrue(msg in str(context.exception))
|
||||
|
||||
# test numpy
|
||||
def to_np(t):
|
||||
return t.data.cpu().numpy()
|
||||
|
||||
self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std)
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg)
|
||||
self.assertTrue(msg in str(context.exception))
|
||||
|
||||
# test relative tolerance
|
||||
assert torch.allclose(x, x_noise, rtol=100 * noise_std)
|
||||
assert not torch.allclose(x, x_noise, rtol=noise_std)
|
||||
self.assertClose(x, x_noise, rtol=100 * noise_std)
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
self.assertClose(x, x_noise, rtol=noise_std, msg=msg)
|
||||
self.assertTrue(msg in str(context.exception))
|
||||
|
||||
# test norm aggregation
|
||||
# if one of the spatial dimensions is small, norm aggregation helps
|
||||
x_noise[:, 0] = x_noise[:, 0] - x[:, 0]
|
||||
x[:, 0] = 0.0
|
||||
assert not torch.allclose(x, x_noise, rtol=100 * noise_std)
|
||||
self.assertNormsClose(
|
||||
x, x_noise, rtol=100 * noise_std, norm_fn=lambda t: t.norm(dim=-1)
|
||||
)
|
131
tests/test_perspective_n_points.py
Normal file
131
tests/test_perspective_n_points.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import perspective_n_points
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
|
||||
|
||||
def reproj_error(x_world, y, R, T, weight=None):
|
||||
# applies the affine transform, projects, and computes the reprojection error
|
||||
y_hat = torch.matmul(x_world, R) + T[:, None, :]
|
||||
y_hat = y_hat / y_hat[..., 2:]
|
||||
if weight is None:
|
||||
weight = y.new_ones((1, 1))
|
||||
return (((weight[:, :, None] * (y - y_hat[..., :2])) ** 2).sum(dim=-1) ** 0.5).mean(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
|
||||
class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
|
||||
sol = perspective_n_points.efficient_pnp(
|
||||
x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
|
||||
)
|
||||
|
||||
err_2d = reproj_error(x_world, y, sol.R, sol.T)
|
||||
R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
|
||||
R_quat = rotation_conversions.matrix_to_quaternion(R)
|
||||
|
||||
num_pts = x_world.shape[-2]
|
||||
# quadratic part is more stable with fewer points
|
||||
num_pts_thresh = 5 if skip_q else 4
|
||||
if check_output and num_pts > num_pts_thresh:
|
||||
assert_msg = (
|
||||
f"test_perspective_n_points assertion failure for "
|
||||
f"n_points={num_pts}, "
|
||||
f"skip_quadratic={skip_q}, "
|
||||
f"no noise."
|
||||
)
|
||||
|
||||
self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
|
||||
self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)
|
||||
|
||||
def norm_fn(t):
|
||||
return t.norm(dim=-1)
|
||||
|
||||
self.assertNormsClose(
|
||||
T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
self.assertNormsClose(
|
||||
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
|
||||
if print_stats:
|
||||
torch.set_printoptions(precision=5, sci_mode=False)
|
||||
for err_2d, err_3d, R_gt, T_gt in zip(
|
||||
sol.err_2d,
|
||||
sol.err_3d,
|
||||
torch.cat((sol.R, R), dim=-1),
|
||||
torch.stack((sol.T, T[:, 0, :]), dim=-1),
|
||||
):
|
||||
print("2D Error: %1.4f" % err_2d.item())
|
||||
print("3D Error: %1.4f" % err_3d.item())
|
||||
print("R_hat | R_gt\n", R_gt)
|
||||
print("T_hat | T_gt\n", T_gt)
|
||||
|
||||
def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False):
|
||||
x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
|
||||
x_cam[:, :2] *= x_cam[:, 2:] # unproject
|
||||
|
||||
R = rotation_conversions.random_rotations(16).to(y)
|
||||
T = torch.randn_like(R[:, :1, :])
|
||||
x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
|
||||
|
||||
if print_stats:
|
||||
print("Run without noise")
|
||||
|
||||
if benchmark: # return curried call
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def result():
|
||||
self._run_and_print(x_world, y, R, T, False, skip_q)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return result
|
||||
|
||||
self._run_and_print(x_world, y, R, T, print_stats, skip_q, check_output=True)
|
||||
|
||||
# in the noisy case, there are no guarantees, so we check it doesn't crash
|
||||
if print_stats:
|
||||
print("Run with noise")
|
||||
x_world += torch.randn_like(x_world) * 0.1
|
||||
self._run_and_print(x_world, y, R, T, print_stats, skip_q)
|
||||
|
||||
def case_with_gaussian_points(
|
||||
self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False
|
||||
):
|
||||
return self._testcase_from_2d(
|
||||
torch.randn((num_pts, 2)).cuda() / 3.0,
|
||||
print_stats=print_stats,
|
||||
benchmark=benchmark,
|
||||
skip_q=skip_q,
|
||||
)
|
||||
|
||||
def test_perspective_n_points(self, print_stats=False):
|
||||
if print_stats:
|
||||
print("RUN ON A DENSE GRID")
|
||||
u = torch.linspace(-1.0, 1.0, 20)
|
||||
v = torch.linspace(-1.0, 1.0, 15)
|
||||
for skip_q in [False, True]:
|
||||
self._testcase_from_2d(
|
||||
torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q
|
||||
)
|
||||
|
||||
for num_pts in range(6, 3, -1):
|
||||
for skip_q in [False, True]:
|
||||
if print_stats:
|
||||
print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}")
|
||||
|
||||
self.case_with_gaussian_points(
|
||||
num_pts=num_pts,
|
||||
print_stats=print_stats,
|
||||
benchmark=False,
|
||||
skip_q=skip_q,
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user