Fix returning a proper rotation in levelling; supporting batches and default centroid

Summary:
`get_rotation_to_best_fit_xy` is useful to expose externally, however there was a bug (which we probably did not care about for our use case): it could return a rotation matrix with det(R) == −1.
The diff fixes that, and also makes centroid optional (it can be computed from points).

Reviewed By: bottler

Differential Revision: D39926791

fbshipit-source-id: 5120c7892815b829f3ddcc23e93d4a5ec0ca0013
This commit is contained in:
Roman Shapovalov
2022-09-29 11:56:14 -07:00
committed by Facebook GitHub Bot
parent de98c9cc2f
commit 74bbd6fd76
2 changed files with 45 additions and 10 deletions

View File

@@ -12,8 +12,9 @@ from pytorch3d.implicitron.tools.circle_fitting import (
_signed_area,
fit_circle_in_2d,
fit_circle_in_3d,
get_rotation_to_best_fit_xy,
)
from pytorch3d.transforms import random_rotation
from pytorch3d.transforms import random_rotation, random_rotations
from tests.common_testing import TestCaseMixin
@@ -28,6 +29,32 @@ class TestCircleFitting(TestCaseMixin, unittest.TestCase):
"""
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
def test_plane_levelling(self):
device = torch.device("cuda:0")
B = 16
N = 1024
random = torch.randn((B, N, 3), device=device)
# first, check that we always return a vaild rotation
rot = get_rotation_to_best_fit_xy(random)
self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
# then, check the result is what we expect
z_squeeze = 0.1
random[..., -1] *= z_squeeze
rot_gt = random_rotations(B, device=device)
rotated = random @ rot_gt.transpose(-1, -2)
rot_hat = get_rotation_to_best_fit_xy(rotated)
self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
# covariance matrix of the levelled points is by design diag(1, 1, z_squeeze²)
self.assertClose(
(rotated @ rot_hat)[..., -1].std(dim=-1),
torch.ones_like(rot_hat[:, 0, 0]) * z_squeeze,
rtol=0.1,
)
def test_simple_3d(self):
device = torch.device("cuda:0")
for _ in range(7):