mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: lg-zhang found a problem in the quadratic part of the ePnP implementation (n262385). It was caused by a coefficient returned from the linear equation solver being exactly 0.0, which made `sign()` return 0 — a case I had not anticipated. I also made sure we avoid division by zero by clamping all relevant denominators. Reviewed By: nikhilaravi, lg-zhang Differential Revision: D21531200 fbshipit-source-id: 9eb2fa9d4f4f8f5f411d4cf1cffcc44b365b7e51
132 lines
4.6 KiB
Python
132 lines
4.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
from common_testing import TestCaseMixin
|
|
from pytorch3d.ops import perspective_n_points
|
|
from pytorch3d.transforms import rotation_conversions
|
|
|
|
|
|
def reproj_error(x_world, y, R, T, weight=None):
    """
    Return the mean (optionally weighted) 2D reprojection error per batch element.

    World points are moved into the camera frame with ``x_world @ R + T``,
    divided by their third (depth) coordinate, and the Euclidean distance of
    the first two coordinates to ``y`` is averaged over the last point axis.

    Args:
        x_world: world-space points; assumed (B, N, 3) — TODO confirm.
        y: observed 2D points, broadcast against the projected (..., 2) points.
        R: matrices right-multiplied onto ``x_world``.
        T: translations, indexed as ``T[:, None, :]`` (so assumed (B, 3)).
        weight: optional per-point weights, indexed as ``weight[:, :, None]``;
            when omitted a (1, 1) tensor of ones is used.

    Returns:
        Tensor of per-point distances averaged over ``dim=-1``.
    """
    # rigid transform into the camera frame
    x_cam = torch.matmul(x_world, R) + T[:, None, :]
    # perspective projection: divide every coordinate by the depth
    x_proj = x_cam / x_cam[..., 2:]

    w = y.new_ones((1, 1)) if weight is None else weight
    residual = w[:, :, None] * (y - x_proj[..., :2])
    per_point_err = (residual ** 2).sum(dim=-1) ** 0.5
    return per_point_err.mean(dim=-1)
|
|
|
|
|
|
class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
    """
    Tests (and benchmark cases) for ``perspective_n_points.efficient_pnp``.

    Each case synthesizes a batch of random camera poses plus world points
    whose camera-space projections are a given set of 2D points, runs
    ``efficient_pnp`` with or without its quadratic-equation step, and (for
    noise-free inputs with enough points) checks the recovered pose.
    """

    def setUp(self) -> None:
        super().setUp()
        # fixed seed so the randomly generated poses / points are reproducible
        torch.manual_seed(42)

    def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
        """
        Solve one synthetic batch with ePnP, optionally verify and print it.

        Args:
            x_world: world-space points; indexed as ``x_world[:, :, :2]`` and
                ``x_world.shape[-2]`` is used as the point count, so assumed
                (B, N, 3).
            y: 2D points shared by the whole batch; expanded to the shape of
                ``x_world[:, :, :2]`` before being passed to ePnP.
            R: ground-truth rotation matrices used to generate ``x_world``.
            T: ground-truth translations; (B, 1, 3) as built by
                ``_testcase_from_2d`` (``T[:, 0, :]`` is taken when printing).
            print_stats: print per-element errors and estimated vs. GT pose.
            skip_q: forwarded as ``skip_quadratic_eq`` to ``efficient_pnp``.
            check_output: when True and there are more points than
                ``num_pts_thresh``, assert the solution is accurate. Intended
                for noise-free input only.
        """
        sol = perspective_n_points.efficient_pnp(
            x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
        )

        err_2d = reproj_error(x_world, y, sol.R, sol.T)
        R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R)
        R_quat = rotation_conversions.matrix_to_quaternion(R)

        num_pts = x_world.shape[-2]
        # quadratic part is more stable with fewer points
        num_pts_thresh = 5 if skip_q else 4
        if check_output and num_pts > num_pts_thresh:
            assert_msg = (
                f"test_perspective_n_points assertion failure for "
                f"n_points={num_pts}, "
                f"skip_quadratic={skip_q}, "
                f"no noise."
            )

            # the solver's own error report must agree with an independent
            # recomputation, and be (near) zero for noise-free data
            self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
            self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)

            def norm_fn(t):
                return t.norm(dim=-1)

            self.assertNormsClose(
                T, sol.T[:, None, :], rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
            )
            # rotations are compared as quaternions
            self.assertNormsClose(
                R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
            )

        if print_stats:
            torch.set_printoptions(precision=5, sci_mode=False)
            # each R_pair / T_pair holds the estimate and the ground truth
            # side by side (concatenated / stacked along the last dim)
            for elem_err_2d, elem_err_3d, R_pair, T_pair in zip(
                sol.err_2d,
                sol.err_3d,
                torch.cat((sol.R, R), dim=-1),
                torch.stack((sol.T, T[:, 0, :]), dim=-1),
            ):
                print("2D Error: %1.4f" % elem_err_2d.item())
                print("3D Error: %1.4f" % elem_err_3d.item())
                print("R_hat | R_gt\n", R_pair)
                print("T_hat | T_gt\n", T_pair)

    def _testcase_from_2d(
        self, y, print_stats, benchmark, skip_q=False, batch_size=16
    ):
        """
        Build a synthetic PnP problem whose image points are ``y`` and solve it.

        ``y`` is lifted to camera space with a random depth in [3.5, 5.5) per
        point, then mapped to world space with ``batch_size`` random rotations
        and Gaussian translations so that, for orthonormal ``R``,
        ``x_world @ R + T`` reproduces the camera-space points.

        Args:
            y: 2D points; a depth column is concatenated along dim 1, so
                assumed (N, 2).
            print_stats: forwarded to ``_run_and_print``.
            benchmark: if True, nothing is solved now; a closure that runs one
                noise-free solve and synchronizes CUDA is returned instead.
            skip_q: forwarded as the ``skip_q`` flag of ``_run_and_print``.
            batch_size: number of random poses in the batch. Defaults to 16,
                the value that was previously hard-coded here.

        Returns:
            The benchmark closure when ``benchmark`` is True, otherwise None.
        """
        x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
        x_cam[:, :2] *= x_cam[:, 2:]  # unproject

        R = rotation_conversions.random_rotations(batch_size).to(y)
        T = torch.randn_like(R[:, :1, :])
        # inverse of x_cam = x_world @ R + T (R^T is R's inverse for rotations)
        x_world = torch.matmul(x_cam - T, R.transpose(1, 2))

        if print_stats:
            print("Run without noise")

        if benchmark:  # return curried call
            torch.cuda.synchronize()

            def result():
                self._run_and_print(x_world, y, R, T, False, skip_q)
                # wait for the GPU so timing covers the actual work
                torch.cuda.synchronize()

            return result

        self._run_and_print(x_world, y, R, T, print_stats, skip_q, check_output=True)

        # in the noisy case, there are no guarantees, so we check it doesn't crash
        if print_stats:
            print("Run with noise")
        x_world += torch.randn_like(x_world) * 0.1
        self._run_and_print(x_world, y, R, T, print_stats, skip_q)

    def case_with_gaussian_points(
        self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False
    ):
        """
        Run (or, with ``benchmark=True``, return a closure running) a case on
        ``num_pts`` Gaussian 2D points (std 1/3) placed on CUDA, using
        ``batch_size`` random poses.
        """
        return self._testcase_from_2d(
            torch.randn((num_pts, 2)).cuda() / 3.0,
            print_stats=print_stats,
            benchmark=benchmark,
            skip_q=skip_q,
            # previously ignored: the batch was always 16 regardless of this
            batch_size=batch_size,
        )

    def test_perspective_n_points(self, print_stats=False):
        """
        Check ePnP on a dense 20x15 grid and on 6, 5 and 4 Gaussian points,
        both with and without the quadratic-equation step.
        """
        if print_stats:
            print("RUN ON A DENSE GRID")
        u = torch.linspace(-1.0, 1.0, 20)
        v = torch.linspace(-1.0, 1.0, 15)
        for skip_q in [False, True]:
            self._testcase_from_2d(
                torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q
            )

        for num_pts in range(6, 3, -1):
            for skip_q in [False, True]:
                if print_stats:
                    print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}")

                self.case_with_gaussian_points(
                    # keep the 16-pose batch this test has always used
                    batch_size=16,
                    num_pts=num_pts,
                    print_stats=print_stats,
                    benchmark=False,
                    skip_q=skip_q,
                )
|