mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 03:40:34 +08:00
Weighted Umeyama.
Summary: 1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own. 2. Refactored to use the same code for the Pointclouds mask and passed weights. 3. Added test cases with random weights. 4. Fixed a bug in tests that calls the function with 0 points (fails randomly in Pytorch 1.3, will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421 ). Reviewed By: gkioxari Differential Revision: D20070293 fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e5b1d6d3a3
commit
e37085d999
@@ -6,6 +6,8 @@ import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
from pytorch3d.ops import points_alignment
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
@@ -35,7 +37,7 @@ def _apply_pcl_transformation(X, R, T, s=None):
|
||||
return X_t
|
||||
|
||||
|
||||
class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
@@ -171,6 +173,7 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
reflect=False,
|
||||
random_weights=False,
|
||||
):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
@@ -198,12 +201,27 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
# point cloud X
|
||||
X_t = _apply_pcl_transformation(X, R, T, s=s)
|
||||
|
||||
weights = None
|
||||
if random_weights:
|
||||
template = X.points_padded() if use_pointclouds else X
|
||||
weights = torch.rand_like(template[:, :, 0])
|
||||
weights = weights / weights.sum(dim=1, keepdim=True)
|
||||
# zero out some weights as zero weights are a common use case
|
||||
# this guarantees there are no zero weight
|
||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||
if use_pointclouds: # convert to List[Tensor]
|
||||
weights = [
|
||||
w[:npts]
|
||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_corresponding_points_alignment():
|
||||
points_alignment.corresponding_points_alignment(
|
||||
X,
|
||||
X_t,
|
||||
weights,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
@@ -230,26 +248,28 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
"""
|
||||
|
||||
# run this for several different point cloud sizes
|
||||
for n_points in (100, 3, 2, 1, 0):
|
||||
for n_points in (100, 3, 2, 1):
|
||||
# run this for several different dimensionalities
|
||||
for dim in torch.arange(2, 10):
|
||||
for dim in range(2, 10):
|
||||
# switches whether we should use the Pointclouds inputs
|
||||
use_point_clouds_cases = (
|
||||
(True, False) if dim == 3 and n_points > 3 else (False,)
|
||||
)
|
||||
for use_pointclouds in use_point_clouds_cases:
|
||||
for estimate_scale in (False, True):
|
||||
for reflect in (False, True):
|
||||
for allow_reflection in (False, True):
|
||||
self._test_single_corresponding_points_alignment(
|
||||
batch_size=10,
|
||||
n_points=n_points,
|
||||
dim=int(dim),
|
||||
use_pointclouds=use_pointclouds,
|
||||
estimate_scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
allow_reflection=allow_reflection,
|
||||
)
|
||||
for random_weights in (False, True,):
|
||||
for use_pointclouds in use_point_clouds_cases:
|
||||
for estimate_scale in (False, True):
|
||||
for reflect in (False, True):
|
||||
for allow_reflection in (False, True):
|
||||
self._test_single_corresponding_points_alignment(
|
||||
batch_size=10,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
use_pointclouds=use_pointclouds,
|
||||
estimate_scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
allow_reflection=allow_reflection,
|
||||
random_weights=random_weights,
|
||||
)
|
||||
|
||||
def _test_single_corresponding_points_alignment(
|
||||
self,
|
||||
@@ -260,6 +280,7 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
estimate_scale=False,
|
||||
reflect=False,
|
||||
allow_reflection=False,
|
||||
random_weights=False,
|
||||
):
|
||||
"""
|
||||
Executes a single test for `corresponding_points_alignment` for a
|
||||
@@ -294,6 +315,20 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
)
|
||||
R = torch.bmm(M, R)
|
||||
|
||||
weights = None
|
||||
if random_weights:
|
||||
template = X.points_padded() if use_pointclouds else X
|
||||
weights = torch.rand_like(template[:, :, 0])
|
||||
weights = weights / weights.sum(dim=1, keepdim=True)
|
||||
# zero out some weights as zero weights are a common use case
|
||||
# this guarantees there are no zero weight
|
||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||
if use_pointclouds: # convert to List[Tensor]
|
||||
weights = [
|
||||
w[:npts]
|
||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
]
|
||||
|
||||
# apply the generated transformation to the generated
|
||||
# point cloud X
|
||||
X_t = _apply_pcl_transformation(X, R, T, s=s)
|
||||
@@ -302,6 +337,7 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
|
||||
X,
|
||||
X_t,
|
||||
weights,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
@@ -313,9 +349,40 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
f"use_pointclouds={use_pointclouds}, "
|
||||
f"estimate_scale={estimate_scale}, "
|
||||
f"reflect={reflect}, "
|
||||
f"allow_reflection={allow_reflection}."
|
||||
f"allow_reflection={allow_reflection},"
|
||||
f"random_weights={random_weights}."
|
||||
)
|
||||
|
||||
# if we test the weighted case, check that weights help with noise
|
||||
if random_weights and not use_pointclouds and n_points >= (dim + 10):
|
||||
# add noise to 20% points with smallest weight
|
||||
X_noisy = X_t.clone()
|
||||
_, mink_idx = torch.topk(-weights, int(n_points * 0.2), dim=1)
|
||||
mink_idx = mink_idx[:, :, None].expand(-1, -1, X_t.shape[-1])
|
||||
X_noisy.scatter_add_(
|
||||
1, mink_idx, 0.3 * torch.randn_like(mink_idx, dtype=X_t.dtype)
|
||||
)
|
||||
|
||||
def align_and_get_mse(weights_):
|
||||
R_n, T_n, s_n = points_alignment.corresponding_points_alignment(
|
||||
X_noisy,
|
||||
X_t,
|
||||
weights_,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
|
||||
X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
|
||||
|
||||
return (
|
||||
((X_t_est - X_t) * weights[..., None]) ** 2
|
||||
).sum(dim=(1, 2)) / weights.sum(dim=-1)
|
||||
|
||||
# check that using weights leads to lower weighted_MSE(X_noisy, X_t)
|
||||
self.assertTrue(
|
||||
torch.all(align_and_get_mse(weights) <= align_and_get_mse(None))
|
||||
)
|
||||
|
||||
if reflect and not allow_reflection:
|
||||
# check that all rotations have det=1
|
||||
self._assert_all_close(
|
||||
@@ -325,34 +392,44 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
)
|
||||
|
||||
else:
|
||||
# mask out inputs with too few non-degenerate points for assertions
|
||||
w = (
|
||||
torch.ones_like(R_est[:, 0, 0])
|
||||
if weights is None or n_points >= dim + 10
|
||||
else (weights > 0.0).all(dim=1).to(R_est)
|
||||
)
|
||||
# check that the estimated tranformation is the same
|
||||
# as the ground truth
|
||||
if n_points >= (dim + 1):
|
||||
# the checks on transforms apply only when
|
||||
# the problem setup is unambiguous
|
||||
self._assert_all_close(R_est, R, assert_error_message)
|
||||
self._assert_all_close(T_est, T, assert_error_message)
|
||||
self._assert_all_close(s_est, s, assert_error_message)
|
||||
msg = assert_error_message
|
||||
self._assert_all_close(R_est, R, msg, w[:, None, None], atol=1e-5)
|
||||
self._assert_all_close(T_est, T, msg, w[:, None])
|
||||
self._assert_all_close(s_est, s, msg, w)
|
||||
|
||||
# check that the orthonormal part of the
|
||||
# transformation has a correct determinant (+1/-1)
|
||||
desired_det = R_est.new_ones(batch_size)
|
||||
if reflect:
|
||||
desired_det *= -1.0
|
||||
self._assert_all_close(
|
||||
torch.det(R_est), desired_det, assert_error_message
|
||||
)
|
||||
self._assert_all_close(torch.det(R_est), desired_det, msg, w)
|
||||
|
||||
# check that the transformed point cloud
|
||||
# X matches X_t
|
||||
X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est)
|
||||
self._assert_all_close(
|
||||
X_t, X_t_est, assert_error_message, atol=1e-5
|
||||
X_t, X_t_est, assert_error_message, w[:, None, None], atol=1e-5
|
||||
)
|
||||
|
||||
def _assert_all_close(self, a_, b_, err_message, atol=1e-6):
|
||||
def _assert_all_close(self, a_, b_, err_message, weights=None, atol=1e-6):
|
||||
if isinstance(a_, Pointclouds):
|
||||
a_ = a_.points_packed()
|
||||
if isinstance(b_, Pointclouds):
|
||||
b_ = b_.points_packed()
|
||||
self.assertTrue(torch.allclose(a_, b_, atol=atol), err_message)
|
||||
if weights is None:
|
||||
self.assertClose(a_, b_, atol=atol, msg=err_message)
|
||||
else:
|
||||
self.assertClose(
|
||||
a_ * weights, b_ * weights, atol=atol, msg=err_message
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user