Mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Umeyama estimates a rigid motion between two sets of corresponding points.

Benchmark output for `bm_points_alignment`:

```
Arguments key: [<allow_reflection>_<batch_size>_<dim>_<estimate_scale>_<n_points>_<use_pointclouds>]

Benchmark    Avg Time(μs)    Peak Time(μs)    Iterations
--------------------------------------------------------------------------------
CorrespodingPointsAlignment_True_1_3_True_100_False    7382    9833    68
CorrespodingPointsAlignment_True_1_3_True_10000_False    8183    10500    62
CorrespodingPointsAlignment_True_1_3_False_100_False    7301    9263    69
CorrespodingPointsAlignment_True_1_3_False_10000_False    7945    9746    64
CorrespodingPointsAlignment_True_1_20_True_100_False    13706    41623    37
CorrespodingPointsAlignment_True_1_20_True_10000_False    11044    33766    46
CorrespodingPointsAlignment_True_1_20_False_100_False    9908    28791    51
CorrespodingPointsAlignment_True_1_20_False_10000_False    9523    18680    53
CorrespodingPointsAlignment_True_10_3_True_100_False    29585    32026    17
CorrespodingPointsAlignment_True_10_3_True_10000_False    29626    36324    18
CorrespodingPointsAlignment_True_10_3_False_100_False    26013    29253    20
CorrespodingPointsAlignment_True_10_3_False_10000_False    25000    33820    20
CorrespodingPointsAlignment_True_10_20_True_100_False    40955    41592    13
CorrespodingPointsAlignment_True_10_20_True_10000_False    42087    42393    12
CorrespodingPointsAlignment_True_10_20_False_100_False    39863    40381    13
CorrespodingPointsAlignment_True_10_20_False_10000_False    40813    41699    13
CorrespodingPointsAlignment_True_100_3_True_100_False    183146    194745    3
CorrespodingPointsAlignment_True_100_3_True_10000_False    213789    231466    3
CorrespodingPointsAlignment_True_100_3_False_100_False    177805    180796    3
CorrespodingPointsAlignment_True_100_3_False_10000_False    184963    185695    3
CorrespodingPointsAlignment_True_100_20_True_100_False    347181    347325    2
CorrespodingPointsAlignment_True_100_20_True_10000_False    363259    363613    2
CorrespodingPointsAlignment_True_100_20_False_100_False    351769    352496    2
CorrespodingPointsAlignment_True_100_20_False_10000_False    375629    379818    2
CorrespodingPointsAlignment_False_1_3_True_100_False    11155    13770    45
CorrespodingPointsAlignment_False_1_3_True_10000_False    10743    13938    47
CorrespodingPointsAlignment_False_1_3_False_100_False    9578    11511    53
CorrespodingPointsAlignment_False_1_3_False_10000_False    9549    11984    53
CorrespodingPointsAlignment_False_1_20_True_100_False    13809    14183    37
CorrespodingPointsAlignment_False_1_20_True_10000_False    14084    15082    36
CorrespodingPointsAlignment_False_1_20_False_100_False    12765    14177    40
CorrespodingPointsAlignment_False_1_20_False_10000_False    12811    13096    40
CorrespodingPointsAlignment_False_10_3_True_100_False    28823    39384    18
CorrespodingPointsAlignment_False_10_3_True_10000_False    27135    27525    19
CorrespodingPointsAlignment_False_10_3_False_100_False    26236    28980    20
CorrespodingPointsAlignment_False_10_3_False_10000_False    42324    45123    12
CorrespodingPointsAlignment_False_10_20_True_100_False    723902    723902    1
CorrespodingPointsAlignment_False_10_20_True_10000_False    220007    252886    3
CorrespodingPointsAlignment_False_10_20_False_100_False    55593    71636    9
CorrespodingPointsAlignment_False_10_20_False_10000_False    44419    71861    12
CorrespodingPointsAlignment_False_100_3_True_100_False    184768    185199    3
CorrespodingPointsAlignment_False_100_3_True_10000_False    198657    213868    3
CorrespodingPointsAlignment_False_100_3_False_100_False    224598    309645    3
CorrespodingPointsAlignment_False_100_3_False_10000_False    197863    202002    3
CorrespodingPointsAlignment_False_100_20_True_100_False    293484    309459    2
CorrespodingPointsAlignment_False_100_20_True_10000_False    327253    366644    2
CorrespodingPointsAlignment_False_100_20_False_100_False    420793    422194    2
CorrespodingPointsAlignment_False_100_20_False_10000_False    462634    485542    2
CorrespodingPointsAlignment_True_1_3_True_100_True    7664    9909    66
CorrespodingPointsAlignment_True_1_3_True_10000_True    7190    8366    70
CorrespodingPointsAlignment_True_1_3_False_100_True    6549    8316    77
CorrespodingPointsAlignment_True_1_3_False_10000_True    6534    7710    77
CorrespodingPointsAlignment_True_10_3_True_100_True    29052    32940    18
CorrespodingPointsAlignment_True_10_3_True_10000_True    30526    33453    17
CorrespodingPointsAlignment_True_10_3_False_100_True    28708    32993    18
CorrespodingPointsAlignment_True_10_3_False_10000_True    30630    35973    17
CorrespodingPointsAlignment_True_100_3_True_100_True    264909    320820    3
CorrespodingPointsAlignment_True_100_3_True_10000_True    310902    322604    2
CorrespodingPointsAlignment_True_100_3_False_100_True    246832    250634    3
CorrespodingPointsAlignment_True_100_3_False_10000_True    276006    289061    2
CorrespodingPointsAlignment_False_1_3_True_100_True    11421    13757    44
CorrespodingPointsAlignment_False_1_3_True_10000_True    11199    12532    45
CorrespodingPointsAlignment_False_1_3_False_100_True    11474    15841    44
CorrespodingPointsAlignment_False_1_3_False_10000_True    10384    13188    49
CorrespodingPointsAlignment_False_10_3_True_100_True    36599    47340    14
CorrespodingPointsAlignment_False_10_3_True_10000_True    40702    50754    13
CorrespodingPointsAlignment_False_10_3_False_100_True    41277    52149    13
CorrespodingPointsAlignment_False_10_3_False_10000_True    34286    37091    15
CorrespodingPointsAlignment_False_100_3_True_100_True    254991    258578    2
CorrespodingPointsAlignment_False_100_3_True_10000_True    257999    261285    2
CorrespodingPointsAlignment_False_100_3_False_100_True    247511    248693    3
CorrespodingPointsAlignment_False_100_3_False_10000_True    251807    263865    3
```

Reviewed By: gkioxari

Differential Revision: D19808389

fbshipit-source-id: 83305a58627d2fc5dcaf3c3015132d8148f28c29
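For orientation, here is a minimal usage sketch of the estimator being benchmarked; the tensor shapes and the row-vector/right-multiplication convention are taken from the test code below, and the tolerance is an assumption:

```python
import torch

from pytorch3d.ops import points_alignment
from pytorch3d.transforms import rotation_conversions

# a batch of 2 point clouds, each with 100 3D points
X = torch.randn(2, 100, 3)

# a random ground-truth similarity transform; points are row
# vectors, so the rotation matrix right-multiplies
R = rotation_conversions.random_rotations(2)
T = torch.randn(2, 3)
s = torch.rand(2) + 0.1
X_t = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]

# recover (R, T, s) from the exact correspondences X <-> X_t
R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
    X, X_t, estimate_scale=True
)
assert torch.allclose(R_est, R, atol=1e-4)
```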
359 lines
12 KiB
Python
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest

import numpy as np
import torch

from pytorch3d.ops import points_alignment
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.transforms import rotation_conversions
def _apply_pcl_transformation(X, R, T, s=None):
    """
    Apply a batch of similarity/rigid transformations, parametrized with
    rotation `R`, translation `T` and scale `s`, to an input batch of
    point clouds `X`.
    """
    if isinstance(X, Pointclouds):
        num_points = X.num_points_per_cloud()
        X_t = X.points_padded()
    else:
        X_t = X

    if s is not None:
        X_t = s[:, None, None] * X_t

    # points are row vectors, so the rotation matrix right-multiplies
    X_t = torch.bmm(X_t, R) + T[:, None, :]

    if isinstance(X, Pointclouds):
        # re-crop the padded points to the original cloud sizes
        X_list = [x[:n_p] for x, n_p in zip(X_t, num_points)]
        X_t = Pointclouds(X_list)

    return X_t
class TestCorrespondingPointsAlignment(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)
    @staticmethod
    def random_rotation(batch_size, dim, device=None):
        """
        Generates a batch of random `dim`-dimensional rotation matrices.
        """
        if dim == 3:
            R = rotation_conversions.random_rotations(batch_size, device=device)
        else:
            # generate random rotation matrices by orthogonalizing random
            # normal square matrices, followed by a transformation that
            # ensures determinant(R) == 1
            H = torch.randn(
                batch_size, dim, dim, dtype=torch.float32, device=device
            )
            U, _, V = torch.svd(H)
            E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat(
                batch_size, 1, 1
            )
            E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1)))
            R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
            assert torch.allclose(
                torch.det(R), R.new_ones(batch_size), atol=1e-4
            )

        return R
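    # Note on the SVD branch above: for H = U S V^T, the matrix
    # R = U E V^T with E = diag(1, ..., 1, det(U V^T)) is the rotation
    # (det == +1) closest to H in the Frobenius norm, i.e. the standard
    # orthogonal Procrustes projection onto SO(dim).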
    @staticmethod
    def init_point_cloud(
        batch_size=10,
        n_points=1000,
        dim=3,
        device=None,
        use_pointclouds=False,
        random_pcl_size=True,
    ):
        """
        Generate a batch of normally distributed point clouds.
        """
        if use_pointclouds:
            assert dim == 3, "Pointclouds support only 3-dim points."
            # generate `batch_size` point clouds, each with between
            # 4 and `n_points` points
            if random_pcl_size:
                n_points_per_batch = torch.randint(
                    low=4,
                    high=n_points,
                    size=(batch_size,),
                    device=device,
                    dtype=torch.int64,
                )
                X_list = [
                    torch.randn(
                        int(n_pt), dim, device=device, dtype=torch.float32
                    )
                    for n_pt in n_points_per_batch
                ]
                X = Pointclouds(X_list)
            else:
                X = torch.randn(
                    batch_size,
                    n_points,
                    dim,
                    device=device,
                    dtype=torch.float32,
                )
                X = Pointclouds(list(X))
        else:
            X = torch.randn(
                batch_size, n_points, dim, device=device, dtype=torch.float32
            )
        return X
    @staticmethod
    def generate_pcl_transformation(
        batch_size=10, scale=False, reflect=False, dim=3, device=None
    ):
        """
        Generate a batch of random rigid/similarity transformations.
        Note that `reflect` is handled by the caller, which composes a
        reflection into `R` separately.
        """
        R = TestCorrespondingPointsAlignment.random_rotation(
            batch_size, dim, device=device
        )
        T = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
        if scale:
            s = torch.rand(batch_size, dtype=torch.float32, device=device) + 0.1
        else:
            s = torch.ones(batch_size, dtype=torch.float32, device=device)

        return R, T, s
    @staticmethod
    def generate_random_reflection(batch_size=10, dim=3, device=None):
        """
        Generate a batch of reflection matrices of shape (batch_size, dim, dim),
        where each M_i is an identity matrix with one random entry on the
        diagonal equal to -1.
        """
        # randomly select one of the dimensions to reflect for each
        # element in the batch
        dim_to_reflect = torch.randint(
            low=0,
            high=dim,
            size=(batch_size,),
            device=device,
            dtype=torch.int64,
        )

        # convert dim_to_reflect to a batch of reflection matrices M
        M = torch.diag_embed(
            (
                dim_to_reflect[:, None]
                != torch.arange(dim, device=device, dtype=torch.float32)
            ).float()
            * 2
            - 1,
            dim1=1,
            dim2=2,
        )

        return M
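    # Example of the diag_embed construction above: with dim == 3 and
    # dim_to_reflect[i] == 1, the mask (1 != arange(3)) is
    # (True, False, True), so M_i = diag(1, -1, 1), which reflects the
    # second coordinate axis.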
    @staticmethod
    def corresponding_points_alignment(
        batch_size=10,
        n_points=100,
        dim=3,
        use_pointclouds=False,
        estimate_scale=False,
        allow_reflection=False,
        reflect=False,
    ):
        """
        Build a no-argument closure that runs a single call to
        `corresponding_points_alignment`; serves as the workload for
        the `bm_points_alignment` benchmark.
        """
        device = torch.device("cuda:0")

        # initialize a ground truth point cloud
        X = TestCorrespondingPointsAlignment.init_point_cloud(
            batch_size=batch_size,
            n_points=n_points,
            dim=dim,
            device=device,
            use_pointclouds=use_pointclouds,
            random_pcl_size=True,
        )

        # generate the true transformation
        R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation(
            batch_size=batch_size,
            scale=estimate_scale,
            reflect=reflect,
            dim=dim,
            device=device,
        )

        # apply the generated transformation to the point cloud X
        X_t = _apply_pcl_transformation(X, R, T, s=s)

        torch.cuda.synchronize()

        def run_corresponding_points_alignment():
            points_alignment.corresponding_points_alignment(
                X,
                X_t,
                allow_reflection=allow_reflection,
                estimate_scale=estimate_scale,
            )
            torch.cuda.synchronize()

        return run_corresponding_points_alignment
    def test_corresponding_points_alignment(self, batch_size=10):
        """
        Tests whether we can estimate a rigid/similarity motion between
        a randomly initialized point cloud and its randomly transformed
        version.

        The tests are run for all possible combinations
        of the following boolean flags:
            - estimate_scale ... Also estimate a scaling component of
                                 the transformation.
            - reflect ... The ground truth orthonormal part of the generated
                          transformation is a reflection (det==-1).
            - allow_reflection ... If True, the orthonormal matrix of the
                                   estimated transformation is allowed to be
                                   a reflection (det==-1).
            - use_pointclouds ... If True, passes Pointclouds objects
                                  to corresponding_points_alignment.
        """
        # run this for several different point cloud sizes
        for n_points in (100, 3, 2, 1, 0):
            # run this for several different dimensionalities
            for dim in torch.arange(2, 10):
                # switch whether we should use the Pointclouds inputs
                use_point_clouds_cases = (
                    (True, False) if dim == 3 and n_points > 3 else (False,)
                )
                for use_pointclouds in use_point_clouds_cases:
                    for estimate_scale in (False, True):
                        for reflect in (False, True):
                            for allow_reflection in (False, True):
                                self._test_single_corresponding_points_alignment(
                                    batch_size=batch_size,
                                    n_points=n_points,
                                    dim=int(dim),
                                    use_pointclouds=use_pointclouds,
                                    estimate_scale=estimate_scale,
                                    reflect=reflect,
                                    allow_reflection=allow_reflection,
                                )
    def _test_single_corresponding_points_alignment(
        self,
        batch_size=10,
        n_points=100,
        dim=3,
        use_pointclouds=False,
        estimate_scale=False,
        reflect=False,
        allow_reflection=False,
    ):
        """
        Executes a single test for `corresponding_points_alignment` for a
        specific setting of the inputs / outputs.
        """
        device = torch.device("cuda:0")

        # initialize a ground truth point cloud
        X = TestCorrespondingPointsAlignment.init_point_cloud(
            batch_size=batch_size,
            n_points=n_points,
            dim=dim,
            device=device,
            use_pointclouds=use_pointclouds,
            random_pcl_size=True,
        )

        # generate the true transformation
        R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation(
            batch_size=batch_size,
            scale=estimate_scale,
            reflect=reflect,
            dim=dim,
            device=device,
        )

        if reflect:
            # generate a random reflection M and apply it to the rotations
            M = TestCorrespondingPointsAlignment.generate_random_reflection(
                batch_size=batch_size, dim=dim, device=device
            )
            R = torch.bmm(M, R)

        # apply the generated transformation to the point cloud X
        X_t = _apply_pcl_transformation(X, R, T, s=s)

        # run the CorrespondingPointsAlignment algorithm
        R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
            X,
            X_t,
            allow_reflection=allow_reflection,
            estimate_scale=estimate_scale,
        )

        assert_error_message = (
            f"Corresponding_points_alignment assertion failure for "
            f"n_points={n_points}, "
            f"dim={dim}, "
            f"use_pointclouds={use_pointclouds}, "
            f"estimate_scale={estimate_scale}, "
            f"reflect={reflect}, "
            f"allow_reflection={allow_reflection}."
        )

        if reflect and not allow_reflection:
            # check that all estimated rotations have det==1
            self._assert_all_close(
                torch.det(R_est),
                R_est.new_ones(batch_size),
                assert_error_message,
            )
        else:
            # check that the estimated transformation is the same
            # as the ground truth
            if n_points >= (dim + 1):
                # the checks on transforms apply only when
                # the problem setup is unambiguous
                self._assert_all_close(R_est, R, assert_error_message)
                self._assert_all_close(T_est, T, assert_error_message)
                self._assert_all_close(s_est, s, assert_error_message)

                # check that the orthonormal part of the
                # transformation has the correct determinant (+1/-1)
                desired_det = R_est.new_ones(batch_size)
                if reflect:
                    desired_det *= -1.0
                self._assert_all_close(
                    torch.det(R_est), desired_det, assert_error_message
                )

            # check that the transformed point cloud
            # X matches X_t
            X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est)
            self._assert_all_close(
                X_t, X_t_est, assert_error_message, atol=1e-5
            )

    def _assert_all_close(self, a_, b_, err_message, atol=1e-6):
        if isinstance(a_, Pointclouds):
            a_ = a_.points_packed()
        if isinstance(b_, Pointclouds):
            b_ = b_.points_packed()
        self.assertTrue(torch.allclose(a_, b_, atol=atol), err_message)
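
# Convenience entry point so the module can be run directly
# (assumes the standard unittest runner; the upstream harness may
# invoke these tests differently).
if __name__ == "__main__":
    unittest.main()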