mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Numerical stability of ePnP.
Summary: lg-zhang found the problem with the quadratic part of ePnP implementation: n262385 . It was caused by a coefficient returned from the linear equation solver being equal to exactly 0.0, which caused `sign()` to return 0, something I had not anticipated. I also made sure we avoid division by zero by clamping all relevant denominators. Reviewed By: nikhilaravi, lg-zhang Differential Revision: D21531200 fbshipit-source-id: 9eb2fa9d4f4f8f5f411d4cf1cffcc44b365b7e51
This commit is contained in:
parent
a0e14cae1e
commit
a8377f1f06
@ -104,7 +104,7 @@ def _null_space(m, kernel_dim):
|
||||
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
|
||||
|
||||
|
||||
def _reproj_error(y_hat, y, weight):
|
||||
def _reproj_error(y_hat, y, weight, eps=1e-9):
|
||||
""" Projects estimated 3D points and computes the reprojection error
|
||||
Args:
|
||||
y_hat: a batch of predicted 2D points in homogeneous coordinates
|
||||
@ -114,7 +114,7 @@ def _reproj_error(y_hat, y, weight):
|
||||
Returns:
|
||||
Optionally weighted RMSE of difference between y and y_hat.
|
||||
"""
|
||||
y_hat = y_hat / y_hat[..., 2:]
|
||||
y_hat = y_hat / torch.clamp(y_hat[..., 2:], eps)
|
||||
dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5
|
||||
return oputil.wmean(dist, weight)[:, 0, 0]
|
||||
|
||||
@ -155,6 +155,7 @@ def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-
|
||||
R, T, s = points_alignment.corresponding_points_alignment(
|
||||
x_world, x_cam, weight, estimate_scale=True
|
||||
)
|
||||
s = s.clamp(eps)
|
||||
x_cam = x_cam / s[:, None, None]
|
||||
T = T / s[:, None]
|
||||
x_w_rotated = torch.matmul(x_world, R) + T[:, None, :]
|
||||
@ -219,7 +220,11 @@ def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx):
|
||||
return torch.matmul(torch.pinverse(lhs), rhs[:, :, None])
|
||||
|
||||
|
||||
def _find_null_space_coords_1(kernel_dsts, cw_dst):
|
||||
def _binary_sign(t):
|
||||
return (t >= 0).to(t) * 2.0 - 1.0
|
||||
|
||||
|
||||
def _find_null_space_coords_1(kernel_dsts, cw_dst, eps=1e-9):
|
||||
""" Solves case 1 from the paper [1]; solve for 4 coefficients:
|
||||
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
|
||||
^ ^ ^ ^
|
||||
@ -235,8 +240,8 @@ def _find_null_space_coords_1(kernel_dsts, cw_dst):
|
||||
"""
|
||||
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6])
|
||||
|
||||
beta = beta * beta[:, :1, :].sign()
|
||||
return beta / (beta[:, :1, :] ** 0.5)
|
||||
beta = beta * _binary_sign(beta[:, :1, :])
|
||||
return beta / torch.clamp(beta[:, :1, :] ** 0.5, eps)
|
||||
|
||||
|
||||
def _find_null_space_coords_2(kernel_dsts, cw_dst):
|
||||
@ -255,7 +260,7 @@ def _find_null_space_coords_2(kernel_dsts, cw_dst):
|
||||
"""
|
||||
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1])
|
||||
|
||||
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
|
||||
coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :])
|
||||
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
|
||||
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
|
||||
).float()
|
||||
@ -263,7 +268,7 @@ def _find_null_space_coords_2(kernel_dsts, cw_dst):
|
||||
return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1)
|
||||
|
||||
|
||||
def _find_null_space_coords_3(kernel_dsts, cw_dst):
|
||||
def _find_null_space_coords_3(kernel_dsts, cw_dst, eps=1e-9):
|
||||
""" Solves case 3 from the paper; solve for 5 coefficients:
|
||||
[B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
|
||||
^ ^ ^ ^ ^
|
||||
@ -279,11 +284,11 @@ def _find_null_space_coords_3(kernel_dsts, cw_dst):
|
||||
"""
|
||||
beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7])
|
||||
|
||||
coord_0 = (beta[:, :1, :].abs() ** 0.5) * beta[:, 1:2, :].sign()
|
||||
coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :])
|
||||
coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
|
||||
(beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
|
||||
).float()
|
||||
coord_2 = beta[:, 3:4, :] / coord_0[:, :1, :]
|
||||
coord_2 = beta[:, 3:4, :] / torch.clamp(coord_0[:, :1, :], eps)
|
||||
|
||||
return torch.cat(
|
||||
(coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1
|
||||
|
@ -51,7 +51,7 @@ class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase):
|
||||
return t.norm(dim=-1)
|
||||
|
||||
self.assertNormsClose(
|
||||
T, sol.T[:, None, :], rtol=1e-2, norm_fn=norm_fn, msg=assert_msg
|
||||
T, sol.T[:, None, :], rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
|
||||
)
|
||||
self.assertNormsClose(
|
||||
R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
|
||||
|
Loading…
x
Reference in New Issue
Block a user