From a8377f1f069e0099bc82aafd04d2546ee65ede87 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Fri, 15 May 2020 01:35:00 -0700 Subject: [PATCH] Numerical stability of ePnP. Summary: lg-zhang found the problem with the quadratic part of ePnP implementation: n262385 . It was caused by a coefficient returned from the linear equation solver being equal to exactly 0.0, which caused `sign()` to return 0, something I had not anticipated. I also made sure we avoid division by zero by clamping all relevant denominators. Reviewed By: nikhilaravi, lg-zhang Differential Revision: D21531200 fbshipit-source-id: 9eb2fa9d4f4f8f5f411d4cf1cffcc44b365b7e51 --- pytorch3d/ops/perspective_n_points.py | 23 ++++++++++++++--------- tests/test_perspective_n_points.py | 2 +- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/pytorch3d/ops/perspective_n_points.py b/pytorch3d/ops/perspective_n_points.py index 34b99861..5452f58d 100644 --- a/pytorch3d/ops/perspective_n_points.py +++ b/pytorch3d/ops/perspective_n_points.py @@ -104,7 +104,7 @@ def _null_space(m, kernel_dim): return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim] -def _reproj_error(y_hat, y, weight): +def _reproj_error(y_hat, y, weight, eps=1e-9): """ Projects estimated 3D points and computes the reprojection error Args: y_hat: a batch of predicted 2D points in homogeneous coordinates @@ -114,7 +114,7 @@ def _reproj_error(y_hat, y, weight): Returns: Optionally weighted RMSE of difference between y and y_hat. """ - y_hat = y_hat / y_hat[..., 2:] + y_hat = y_hat / torch.clamp(y_hat[..., 2:], eps) dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5 return oputil.wmean(dist, weight)[:, 0, 0] @@ -155,6 +155,7 @@ def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e- R, T, s = points_alignment.corresponding_points_alignment( x_world, x_cam, weight, estimate_scale=True ) + s = s.clamp(eps) x_cam = x_cam / s[:, None, None] T = T / s[:, None] x_w_rotated = torch.matmul(x_world, R) + T[:, None, :] @@ -219,7 +220,11 @@ def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx): return torch.matmul(torch.pinverse(lhs), rhs[:, :, None]) -def _find_null_space_coords_1(kernel_dsts, cw_dst): +def _binary_sign(t): + return (t >= 0).to(t) * 2.0 - 1.0 + + +def _find_null_space_coords_1(kernel_dsts, cw_dst, eps=1e-9): """ Solves case 1 from the paper [1]; solve for 4 coefficients: [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] ^ ^ ^ ^ @@ -235,8 +240,8 @@ def _find_null_space_coords_1(kernel_dsts, cw_dst): """ beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6]) - beta = beta * beta[:, :1, :].sign() - return beta / (beta[:, :1, :] ** 0.5) + beta = beta * _binary_sign(beta[:, :1, :]) + return beta / torch.clamp(beta[:, :1, :] ** 0.5, eps) def _find_null_space_coords_2(kernel_dsts, cw_dst): @@ -255,7 +260,7 @@ def _find_null_space_coords_2(kernel_dsts, cw_dst): """ beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1]) - coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign() + coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :]) coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * ( (beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0) ).float() @@ -263,7 +268,7 @@ def _find_null_space_coords_2(kernel_dsts, cw_dst): return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1) -def _find_null_space_coords_3(kernel_dsts, cw_dst): +def _find_null_space_coords_3(kernel_dsts, cw_dst, eps=1e-9): """ Solves case 3 from the paper; solve for 5 coefficients: [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] ^ ^ ^ ^ ^ @@ -279,11 +284,11 @@ def _find_null_space_coords_3(kernel_dsts, cw_dst): """ beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7]) - coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign() + coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :]) coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * ( (beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0) ).float() - coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :] + coord_2 = beta[:, 3:4, :] / torch.clamp(coord_0[:, :1, :], eps) return torch.cat( (coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1 diff --git a/tests/test_perspective_n_points.py b/tests/test_perspective_n_points.py index feaefe66..c46dbf75 100644 --- a/tests/test_perspective_n_points.py +++ b/tests/test_perspective_n_points.py @@ -51,7 +51,7 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase): return t.norm(dim=-1) self.assertNormsClose( - T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg + T, sol.T[:, None, :], rtol=3e-3, norm_fn=norm_fn, msg=assert_msg ) self.assertNormsClose( R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg