Numerical stability of ePnP.

Summary: lg-zhang found the problem with the quadratic part of the ePnP implementation (n262385). It was caused by a coefficient returned from the linear equation solver being exactly 0.0, which made `sign()` return 0, a case I had not anticipated. I also made sure we avoid division by zero by clamping all relevant denominators.

Reviewed By: nikhilaravi, lg-zhang

Differential Revision: D21531200

fbshipit-source-id: 9eb2fa9d4f4f8f5f411d4cf1cffcc44b365b7e51
This commit is contained in:
Roman Shapovalov 2020-05-15 01:35:00 -07:00 committed by Facebook GitHub Bot
parent a0e14cae1e
commit a8377f1f06
2 changed files with 15 additions and 10 deletions

View File

@ -104,7 +104,7 @@ def _null_space(m, kernel_dim):
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
def _reproj_error(y_hat, y, weight):
def _reproj_error(y_hat, y, weight, eps=1e-9):
""" Projects estimated 3D points and computes the reprojection error
Args:
y_hat: a batch of predicted 2D points in homogeneous coordinates
@ -114,7 +114,7 @@ def _reproj_error(y_hat, y, weight):
Returns:
Optionally weighted RMSE of difference between y and y_hat.
"""
y_hat = y_hat / y_hat[..., 2:]
y_hat = y_hat / torch.clamp(y_hat[..., 2:], eps)
dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5
return oputil.wmean(dist, weight)[:, 0, 0]
@ -155,6 +155,7 @@ def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-
R, T, s = points_alignment.corresponding_points_alignment(
x_world, x_cam, weight, estimate_scale=True
)
s = s.clamp(eps)
x_cam = x_cam / s[:, None, None]
T = T / s[:, None]
x_w_rotated = torch.matmul(x_world, R) + T[:, None, :]
@ -219,7 +220,11 @@ def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx):
return torch.matmul(torch.pinverse(lhs), rhs[:, :, None])
def _find_null_space_coords_1(kernel_dsts, cw_dst):
def _binary_sign(t):
    """ Two-valued sign: +1 where `t >= 0` and -1 elsewhere, never 0.

    Unlike `sign()`, which returns 0 for an input of exactly 0.0, this maps
    zero entries to +1, so multiplying by the result cannot zero out a value.

    Args:
        t: values to take the sign of (presumably a torch tensor, since
            `.to(t)` is called on the comparison result)
    Returns:
        The elementwise result of `(t >= 0).to(t) * 2.0 - 1.0`.
    """
    # The comparison mask is converted via `.to(t)` to match `t`, then the
    # {0, 1} values are mapped affinely onto {-1, +1}.
    return (t >= 0).to(t) * 2.0 - 1.0
def _find_null_space_coords_1(kernel_dsts, cw_dst, eps=1e-9):
""" Solves case 1 from the paper [1]; solve for 4 coefficients:
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
^ ^ ^ ^
@ -235,8 +240,8 @@ def _find_null_space_coords_1(kernel_dsts, cw_dst):
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6])
beta = beta * beta[:, :1, :].sign()
return beta / (beta[:, :1, :] ** 0.5)
beta = beta * _binary_sign(beta[:, :1, :])
return beta / torch.clamp(beta[:, :1, :] ** 0.5, eps)
def _find_null_space_coords_2(kernel_dsts, cw_dst):
@ -255,7 +260,7 @@ def _find_null_space_coords_2(kernel_dsts, cw_dst):
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1])
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :])
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
).float()
@ -263,7 +268,7 @@ def _find_null_space_coords_2(kernel_dsts, cw_dst):
return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1)
def _find_null_space_coords_3(kernel_dsts, cw_dst):
def _find_null_space_coords_3(kernel_dsts, cw_dst, eps=1e-9):
""" Solves case 3 from the paper; solve for 5 coefficients:
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
^ ^ ^ ^ ^
@ -279,11 +284,11 @@ def _find_null_space_coords_3(kernel_dsts, cw_dst):
"""
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7])
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :])
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
).float()
coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :]
coord_2 = beta[:, 3:4, :] / torch.clamp(coord_0[:, :1, :], eps)
return torch.cat(
(coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1

View File

@ -51,7 +51,7 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
return t.norm(dim=-1)
self.assertNormsClose(
T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg
T, sol.T[:, None, :], rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
)
self.assertNormsClose(
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg