mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Efficient PnP weighting bug fix
Summary: There is a bug in efficient PnP that incorrectly weights points. This fixes it. The test does not pass for the previous version with the bug. Reviewed By: shapovalov Differential Revision: D22449357 fbshipit-source-id: f5a22081e91d25681a6a783cce2f5c6be429ca6a
This commit is contained in:
parent
2f3cd98725
commit
daf9eac801
@ -66,6 +66,10 @@ def _build_M(y, alphas, weight):
|
||||
def prepad(t, v):
|
||||
return F.pad(t, (1, 0), value=v)
|
||||
|
||||
if weight is not None:
|
||||
# weight the alphas in order to get a correctly weighted version of M
|
||||
alphas = alphas * weight[:, :, None]
|
||||
|
||||
# outer left-multiply by alphas
|
||||
def lm_alphas(t):
|
||||
return torch.matmul(alphas[..., None], t).reshape(bs, n, 12)
|
||||
@ -82,9 +86,6 @@ def _build_M(y, alphas, weight):
|
||||
dim=-1,
|
||||
).reshape(bs, -1, 12)
|
||||
|
||||
if weight is not None:
|
||||
M = M * weight.repeat(1, 2)[:, :, None]
|
||||
|
||||
return M
|
||||
|
||||
|
||||
|
@ -24,6 +24,21 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
|
||||
@classmethod
|
||||
def _generate_epnp_test_from_2d(cls, y):
|
||||
"""
|
||||
Instantiate random x_world, x_cam, R, T given a set of input
|
||||
2D projections y.
|
||||
"""
|
||||
batch_size = y.shape[0]
|
||||
x_cam = torch.cat((y, torch.rand_like(y[:, :, :1]) * 2.0 + 3.5), dim=2)
|
||||
x_cam[:, :, :2] *= x_cam[:, :, 2:] # unproject
|
||||
R = rotation_conversions.random_rotations(batch_size).to(y)
|
||||
T = torch.randn_like(R[:, :1, :])
|
||||
T[:, :, 2] = (T[:, :, 2] + 3.0).clamp(2.0)
|
||||
x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
|
||||
return x_cam, x_world, R, T
|
||||
|
||||
def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
|
||||
sol = perspective_n_points.efficient_pnp(
|
||||
x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
|
||||
@ -45,16 +60,16 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
|
||||
self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)
|
||||
self.assertTrue((err_2d < 5e-4).all(), msg=assert_msg)
|
||||
|
||||
def norm_fn(t):
|
||||
return t.norm(dim=-1)
|
||||
|
||||
self.assertNormsClose(
|
||||
T, sol.T[:, None, :], rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
|
||||
T, sol.T[:, None, :], rtol=4e-3, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
self.assertNormsClose(
|
||||
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
|
||||
R_quat, R_est_quat, rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
|
||||
if print_stats:
|
||||
@ -71,12 +86,9 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
print("T_hat | T_gt\n", T_gt)
|
||||
|
||||
def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False):
|
||||
x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
|
||||
x_cam[:, :2] *= x_cam[:, 2:] # unproject
|
||||
|
||||
R = rotation_conversions.random_rotations(16).to(y)
|
||||
T = torch.randn_like(R[:, :1, :])
|
||||
x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
|
||||
x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(
|
||||
y[None].repeat(16, 1, 1)
|
||||
)
|
||||
|
||||
if print_stats:
|
||||
print("Run without noise")
|
||||
@ -129,3 +141,45 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
benchmark=False,
|
||||
skip_q=skip_q,
|
||||
)
|
||||
|
||||
def test_weighted_perspective_n_points(self, batch_size=16, num_pts=200):
|
||||
# instantiate random x_world and y
|
||||
y = torch.randn((batch_size, num_pts, 2)).cuda() / 3.0
|
||||
x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(y)
|
||||
|
||||
# randomly drop 50% of the rows
|
||||
weights = (torch.rand_like(x_world[:, :, 0]) > 0.5).float()
|
||||
|
||||
# make sure we retain at least 6 points for each case
|
||||
weights[:, :6] = 1.0
|
||||
|
||||
# fill ignored y with trash to ensure that we get different
|
||||
# solution in case the weighting is wrong
|
||||
y = y + (1 - weights[:, :, None]) * 100.0
|
||||
|
||||
def norm_fn(t):
|
||||
return t.norm(dim=-1)
|
||||
|
||||
for skip_quadratic_eq in (True, False):
|
||||
# get the solution for the 0/1 weighted case
|
||||
sol = perspective_n_points.efficient_pnp(
|
||||
x_world, y, skip_quadratic_eq=skip_quadratic_eq, weights=weights
|
||||
)
|
||||
sol_R_quat = rotation_conversions.matrix_to_quaternion(sol.R)
|
||||
sol_T = sol.T
|
||||
|
||||
# check that running only on points with non-zero weights ends in the
|
||||
# same place as running the 0/1 weighted version
|
||||
for i in range(batch_size):
|
||||
ok = weights[i] > 0
|
||||
x_world_ok = x_world[i, ok][None]
|
||||
y_ok = y[i, ok][None]
|
||||
sol_ok = perspective_n_points.efficient_pnp(
|
||||
x_world_ok, y_ok, skip_quadratic_eq=False
|
||||
)
|
||||
R_est_quat_ok = rotation_conversions.matrix_to_quaternion(sol_ok.R)
|
||||
|
||||
self.assertNormsClose(sol_T[i], sol_ok.T[0], rtol=3e-3, norm_fn=norm_fn)
|
||||
self.assertNormsClose(
|
||||
sol_R_quat[i], R_est_quat_ok[0], rtol=3e-4, norm_fn=norm_fn
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user