mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Efficient PnP.
Summary: Efficient PnP algorithm to fit 2D to 3D correspondences under perspective assumption. Benchmarked both variants of nullspace and pick one; SVD takes 7 times longer in the 100K points case. Reviewed By: davnov134, gkioxari Differential Revision: D20095754 fbshipit-source-id: 2b4519729630e6373820880272f674829eaed073
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7788a38050
commit
04d8bf6a43
25
tests/bm_perspective_n_points.py
Normal file
25
tests/bm_perspective_n_points.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import itertools
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_perspective_n_points import TestPerspectiveNPoints
|
||||
|
||||
|
||||
def bm_perspective_n_points() -> None:
|
||||
case_grid = {
|
||||
"batch_size": [1, 10, 100],
|
||||
"num_pts": [100, 100000],
|
||||
"skip_q": [False, True],
|
||||
}
|
||||
|
||||
test_cases = itertools.product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||
|
||||
test = TestPerspectiveNPoints()
|
||||
benchmark(
|
||||
test.case_with_gaussian_points,
|
||||
"PerspectiveNPoints",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
TensorOrArray = Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class TestCaseMixin(unittest.TestCase):
|
||||
def assertSeparate(self, tensor1, tensor2) -> None:
|
||||
"""
|
||||
@@ -28,10 +31,11 @@ class TestCaseMixin(unittest.TestCase):
|
||||
ptrs = [i.storage().data_ptr() for i in tensor_list]
|
||||
self.assertCountEqual(ptrs, set(ptrs))
|
||||
|
||||
def assertClose(
|
||||
def assertNormsClose(
|
||||
self,
|
||||
input,
|
||||
other,
|
||||
input: TensorOrArray,
|
||||
other: TensorOrArray,
|
||||
norm_fn: Callable[[TensorOrArray], TensorOrArray],
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
@@ -39,7 +43,60 @@ class TestCaseMixin(unittest.TestCase):
|
||||
msg: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that two tensors or arrays are the same shape and close.
|
||||
Verifies that two tensors or arrays have the same shape and are close
|
||||
given absolute and relative tolerance; raises AssertionError otherwise.
|
||||
A custom norm function is computed before comparison. If no such pre-
|
||||
processing needed, pass `torch.abs` or, equivalently, call `assertClose`.
|
||||
Args:
|
||||
input, other: two tensors or two arrays.
|
||||
norm_fn: The function evaluates
|
||||
`all(norm_fn(input - other) <= atol + rtol * norm_fn(other))`.
|
||||
norm_fn is a tensor -> tensor function; the output has:
|
||||
* all entries non-negative,
|
||||
* shape defined by the input shape only.
|
||||
rtol, atol, equal_nan: as for torch.allclose.
|
||||
msg: message in case the assertion is violated.
|
||||
Note:
|
||||
Optional arguments here are all keyword-only, to avoid confusion
|
||||
with msg arguments on other assert functions.
|
||||
"""
|
||||
|
||||
self.assertEqual(np.shape(input), np.shape(other))
|
||||
|
||||
diff = norm_fn(input - other)
|
||||
other_ = norm_fn(other)
|
||||
|
||||
# We want to generalise allclose(input, output), which is essentially
|
||||
# all(diff <= atol + rtol * other)
|
||||
# but with a sophisticated handling non-finite values.
|
||||
# We work that around by calling allclose() with the following arguments:
|
||||
# allclose(diff + other_, other_). This computes what we want because
|
||||
# all(|diff + other_ - other_| <= atol + rtol * |other_|) ==
|
||||
# all(|norm_fn(input - other)| <= atol + rtol * |norm_fn(other)|) ==
|
||||
# all(norm_fn(input - other) <= atol + rtol * norm_fn(other)).
|
||||
|
||||
backend = torch if torch.is_tensor(input) else np
|
||||
close = backend.allclose(
|
||||
diff + other_, other_, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
|
||||
self.assertTrue(close, msg)
|
||||
|
||||
def assertClose(
|
||||
self,
|
||||
input: TensorOrArray,
|
||||
other: TensorOrArray,
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False,
|
||||
msg: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Verifies that two tensors or arrays have the same shape and are close
|
||||
given absolute and relative tolerance, i.e. checks
|
||||
`all(|input - other| <= atol + rtol * |other|)`;
|
||||
raises AssertionError otherwise.
|
||||
Args:
|
||||
input, other: two tensors or two arrays.
|
||||
rtol, atol, equal_nan: as for torch.allclose.
|
||||
@@ -51,10 +108,9 @@ class TestCaseMixin(unittest.TestCase):
|
||||
|
||||
self.assertEqual(np.shape(input), np.shape(other))
|
||||
|
||||
if torch.is_tensor(input):
|
||||
close = torch.allclose(
|
||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
else:
|
||||
close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
backend = torch if torch.is_tensor(input) else np
|
||||
close = backend.allclose(
|
||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
|
||||
self.assertTrue(close, msg)
|
||||
|
||||
56
tests/test_common_testing.py
Normal file
56
tests/test_common_testing.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def test_all_close(self):
|
||||
device = torch.device("cuda:0")
|
||||
n_points = 20
|
||||
noise_std = 1e-3
|
||||
msg = "tratata"
|
||||
|
||||
# test absolute tolerance
|
||||
x = torch.rand(n_points, 3, device=device)
|
||||
x_noise = x + noise_std * torch.rand(n_points, 3, device=device)
|
||||
assert torch.allclose(x, x_noise, atol=10 * noise_std)
|
||||
assert not torch.allclose(x, x_noise, atol=0.1 * noise_std)
|
||||
self.assertClose(x, x_noise, atol=10 * noise_std)
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
self.assertClose(x, x_noise, atol=0.1 * noise_std, msg=msg)
|
||||
self.assertTrue(msg in str(context.exception))
|
||||
|
||||
# test numpy
|
||||
def to_np(t):
|
||||
return t.data.cpu().numpy()
|
||||
|
||||
self.assertClose(to_np(x), to_np(x_noise), atol=10 * noise_std)
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
self.assertClose(to_np(x), to_np(x_noise), atol=0.1 * noise_std, msg=msg)
|
||||
self.assertTrue(msg in str(context.exception))
|
||||
|
||||
# test relative tolerance
|
||||
assert torch.allclose(x, x_noise, rtol=100 * noise_std)
|
||||
assert not torch.allclose(x, x_noise, rtol=noise_std)
|
||||
self.assertClose(x, x_noise, rtol=100 * noise_std)
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
self.assertClose(x, x_noise, rtol=noise_std, msg=msg)
|
||||
self.assertTrue(msg in str(context.exception))
|
||||
|
||||
# test norm aggregation
|
||||
# if one of the spatial dimensions is small, norm aggregation helps
|
||||
x_noise[:, 0] = x_noise[:, 0] - x[:, 0]
|
||||
x[:, 0] = 0.0
|
||||
assert not torch.allclose(x, x_noise, rtol=100 * noise_std)
|
||||
self.assertNormsClose(
|
||||
x, x_noise, rtol=100 * noise_std, norm_fn=lambda t: t.norm(dim=-1)
|
||||
)
|
||||
131
tests/test_perspective_n_points.py
Normal file
131
tests/test_perspective_n_points.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import perspective_n_points
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
|
||||
|
||||
def reproj_error(x_world, y, R, T, weight=None):
|
||||
# applies the affine transform, projects, and computes the reprojection error
|
||||
y_hat = torch.matmul(x_world, R) + T[:, None, :]
|
||||
y_hat = y_hat / y_hat[..., 2:]
|
||||
if weight is None:
|
||||
weight = y.new_ones((1, 1))
|
||||
return (((weight[:, :, None] * (y - y_hat[..., :2])) ** 2).sum(dim=-1) ** 0.5).mean(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
|
||||
class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
|
||||
sol = perspective_n_points.efficient_pnp(
|
||||
x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
|
||||
)
|
||||
|
||||
err_2d = reproj_error(x_world, y, sol.R, sol.T)
|
||||
R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
|
||||
R_quat = rotation_conversions.matrix_to_quaternion(R)
|
||||
|
||||
num_pts = x_world.shape[-2]
|
||||
# quadratic part is more stable with fewer points
|
||||
num_pts_thresh = 5 if skip_q else 4
|
||||
if check_output and num_pts > num_pts_thresh:
|
||||
assert_msg = (
|
||||
f"test_perspective_n_points assertion failure for "
|
||||
f"n_points={num_pts}, "
|
||||
f"skip_quadratic={skip_q}, "
|
||||
f"no noise."
|
||||
)
|
||||
|
||||
self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
|
||||
self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)
|
||||
|
||||
def norm_fn(t):
|
||||
return t.norm(dim=-1)
|
||||
|
||||
self.assertNormsClose(
|
||||
T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
self.assertNormsClose(
|
||||
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
|
||||
if print_stats:
|
||||
torch.set_printoptions(precision=5, sci_mode=False)
|
||||
for err_2d, err_3d, R_gt, T_gt in zip(
|
||||
sol.err_2d,
|
||||
sol.err_3d,
|
||||
torch.cat((sol.R, R), dim=-1),
|
||||
torch.stack((sol.T, T[:, 0, :]), dim=-1),
|
||||
):
|
||||
print("2D Error: %1.4f" % err_2d.item())
|
||||
print("3D Error: %1.4f" % err_3d.item())
|
||||
print("R_hat | R_gt\n", R_gt)
|
||||
print("T_hat | T_gt\n", T_gt)
|
||||
|
||||
def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False):
|
||||
x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
|
||||
x_cam[:, :2] *= x_cam[:, 2:] # unproject
|
||||
|
||||
R = rotation_conversions.random_rotations(16).to(y)
|
||||
T = torch.randn_like(R[:, :1, :])
|
||||
x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
|
||||
|
||||
if print_stats:
|
||||
print("Run without noise")
|
||||
|
||||
if benchmark: # return curried call
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def result():
|
||||
self._run_and_print(x_world, y, R, T, False, skip_q)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return result
|
||||
|
||||
self._run_and_print(x_world, y, R, T, print_stats, skip_q, check_output=True)
|
||||
|
||||
# in the noisy case, there are no guarantees, so we check it doesn't crash
|
||||
if print_stats:
|
||||
print("Run with noise")
|
||||
x_world += torch.randn_like(x_world) * 0.1
|
||||
self._run_and_print(x_world, y, R, T, print_stats, skip_q)
|
||||
|
||||
def case_with_gaussian_points(
|
||||
self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False
|
||||
):
|
||||
return self._testcase_from_2d(
|
||||
torch.randn((num_pts, 2)).cuda() / 3.0,
|
||||
print_stats=print_stats,
|
||||
benchmark=benchmark,
|
||||
skip_q=skip_q,
|
||||
)
|
||||
|
||||
def test_perspective_n_points(self, print_stats=False):
|
||||
if print_stats:
|
||||
print("RUN ON A DENSE GRID")
|
||||
u = torch.linspace(-1.0, 1.0, 20)
|
||||
v = torch.linspace(-1.0, 1.0, 15)
|
||||
for skip_q in [False, True]:
|
||||
self._testcase_from_2d(
|
||||
torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q
|
||||
)
|
||||
|
||||
for num_pts in range(6, 3, -1):
|
||||
for skip_q in [False, True]:
|
||||
if print_stats:
|
||||
print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}")
|
||||
|
||||
self.case_with_gaussian_points(
|
||||
num_pts=num_pts,
|
||||
print_stats=print_stats,
|
||||
benchmark=False,
|
||||
skip_q=skip_q,
|
||||
)
|
||||
Reference in New Issue
Block a user