mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Umeyama
Summary: Umeyama estimates a rigid motion between two sets of corresponding points. Benchmark output for `bm_points_alignment` ``` Arguments key: [<allow_reflection>_<batch_size>_<dim>_<estimate_scale>_<n_points>_<use_pointclouds>] Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- CorrespodingPointsAlignment_True_1_3_True_100_False 7382 9833 68 CorrespodingPointsAlignment_True_1_3_True_10000_False 8183 10500 62 CorrespodingPointsAlignment_True_1_3_False_100_False 7301 9263 69 CorrespodingPointsAlignment_True_1_3_False_10000_False 7945 9746 64 CorrespodingPointsAlignment_True_1_20_True_100_False 13706 41623 37 CorrespodingPointsAlignment_True_1_20_True_10000_False 11044 33766 46 CorrespodingPointsAlignment_True_1_20_False_100_False 9908 28791 51 CorrespodingPointsAlignment_True_1_20_False_10000_False 9523 18680 53 CorrespodingPointsAlignment_True_10_3_True_100_False 29585 32026 17 CorrespodingPointsAlignment_True_10_3_True_10000_False 29626 36324 18 CorrespodingPointsAlignment_True_10_3_False_100_False 26013 29253 20 CorrespodingPointsAlignment_True_10_3_False_10000_False 25000 33820 20 CorrespodingPointsAlignment_True_10_20_True_100_False 40955 41592 13 CorrespodingPointsAlignment_True_10_20_True_10000_False 42087 42393 12 CorrespodingPointsAlignment_True_10_20_False_100_False 39863 40381 13 CorrespodingPointsAlignment_True_10_20_False_10000_False 40813 41699 13 CorrespodingPointsAlignment_True_100_3_True_100_False 183146 194745 3 CorrespodingPointsAlignment_True_100_3_True_10000_False 213789 231466 3 CorrespodingPointsAlignment_True_100_3_False_100_False 177805 180796 3 CorrespodingPointsAlignment_True_100_3_False_10000_False 184963 185695 3 CorrespodingPointsAlignment_True_100_20_True_100_False 347181 347325 2 CorrespodingPointsAlignment_True_100_20_True_10000_False 363259 363613 2 CorrespodingPointsAlignment_True_100_20_False_100_False 351769 352496 2 CorrespodingPointsAlignment_True_100_20_False_10000_False 375629 379818 2 CorrespodingPointsAlignment_False_1_3_True_100_False 11155 13770 45 CorrespodingPointsAlignment_False_1_3_True_10000_False 10743 13938 47 CorrespodingPointsAlignment_False_1_3_False_100_False 9578 11511 53 CorrespodingPointsAlignment_False_1_3_False_10000_False 9549 11984 53 CorrespodingPointsAlignment_False_1_20_True_100_False 13809 14183 37 CorrespodingPointsAlignment_False_1_20_True_10000_False 14084 15082 36 CorrespodingPointsAlignment_False_1_20_False_100_False 12765 14177 40 CorrespodingPointsAlignment_False_1_20_False_10000_False 12811 13096 40 CorrespodingPointsAlignment_False_10_3_True_100_False 28823 39384 18 CorrespodingPointsAlignment_False_10_3_True_10000_False 27135 27525 19 CorrespodingPointsAlignment_False_10_3_False_100_False 26236 28980 20 CorrespodingPointsAlignment_False_10_3_False_10000_False 42324 45123 12 CorrespodingPointsAlignment_False_10_20_True_100_False 723902 723902 1 CorrespodingPointsAlignment_False_10_20_True_10000_False 220007 252886 3 CorrespodingPointsAlignment_False_10_20_False_100_False 55593 71636 9 CorrespodingPointsAlignment_False_10_20_False_10000_False 44419 71861 12 CorrespodingPointsAlignment_False_100_3_True_100_False 184768 185199 3 CorrespodingPointsAlignment_False_100_3_True_10000_False 198657 213868 3 CorrespodingPointsAlignment_False_100_3_False_100_False 224598 309645 3 CorrespodingPointsAlignment_False_100_3_False_10000_False 197863 202002 3 CorrespodingPointsAlignment_False_100_20_True_100_False 293484 309459 2 CorrespodingPointsAlignment_False_100_20_True_10000_False 327253 366644 2 CorrespodingPointsAlignment_False_100_20_False_100_False 420793 422194 2 CorrespodingPointsAlignment_False_100_20_False_10000_False 462634 485542 2 CorrespodingPointsAlignment_True_1_3_True_100_True 7664 9909 66 CorrespodingPointsAlignment_True_1_3_True_10000_True 7190 8366 70 CorrespodingPointsAlignment_True_1_3_False_100_True 6549 8316 77 CorrespodingPointsAlignment_True_1_3_False_10000_True 6534 7710 77 CorrespodingPointsAlignment_True_10_3_True_100_True 29052 32940 18 CorrespodingPointsAlignment_True_10_3_True_10000_True 30526 33453 17 CorrespodingPointsAlignment_True_10_3_False_100_True 28708 32993 18 CorrespodingPointsAlignment_True_10_3_False_10000_True 30630 35973 17 CorrespodingPointsAlignment_True_100_3_True_100_True 264909 320820 3 CorrespodingPointsAlignment_True_100_3_True_10000_True 310902 322604 2 CorrespodingPointsAlignment_True_100_3_False_100_True 246832 250634 3 CorrespodingPointsAlignment_True_100_3_False_10000_True 276006 289061 2 CorrespodingPointsAlignment_False_1_3_True_100_True 11421 13757 44 CorrespodingPointsAlignment_False_1_3_True_10000_True 11199 12532 45 CorrespodingPointsAlignment_False_1_3_False_100_True 11474 15841 44 CorrespodingPointsAlignment_False_1_3_False_10000_True 10384 13188 49 CorrespodingPointsAlignment_False_10_3_True_100_True 36599 47340 14 CorrespodingPointsAlignment_False_10_3_True_10000_True 40702 50754 13 CorrespodingPointsAlignment_False_10_3_False_100_True 41277 52149 13 CorrespodingPointsAlignment_False_10_3_False_10000_True 34286 37091 15 CorrespodingPointsAlignment_False_100_3_True_100_True 254991 258578 2 CorrespodingPointsAlignment_False_100_3_True_10000_True 257999 261285 2 CorrespodingPointsAlignment_False_100_3_False_100_True 247511 248693 3 CorrespodingPointsAlignment_False_100_3_False_10000_True 251807 263865 3 ``` Reviewed By: gkioxari Differential Revision: D19808389 fbshipit-source-id: 83305a58627d2fc5dcaf3c3015132d8148f28c29
This commit is contained in:
parent
745aaf3915
commit
e5b1d6d3a3
@ -6,6 +6,7 @@ from .graph_conv import GraphConv
|
||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||
from .nearest_neighbor_points import nn_points_idx
|
||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||
from .points_alignment import corresponding_points_alignment
|
||||
from .sample_points_from_meshes import sample_points_from_meshes
|
||||
from .subdivide_meshes import SubdivideMeshes
|
||||
from .vert_align import vert_align
|
||||
|
151
pytorch3d/ops/points_alignment.py
Normal file
151
pytorch3d/ops/points_alignment.py
Normal file
@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
|
||||
def corresponding_points_alignment(
|
||||
X: Union[torch.Tensor, Pointclouds],
|
||||
Y: Union[torch.Tensor, Pointclouds],
|
||||
estimate_scale: bool = False,
|
||||
allow_reflection: bool = False,
|
||||
eps: float = 1e-8,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Finds a similarity transformation (rotation `R`, translation `T`
|
||||
and optionally scale `s`) between two given sets of corresponding
|
||||
`d`-dimensional points `X` and `Y` such that:
|
||||
|
||||
`s[i] X[i] R[i] + T[i] = Y[i]`,
|
||||
|
||||
for all batch indexes `i` in the least squares sense.
|
||||
|
||||
The algorithm is also known as Umeyama [1].
|
||||
|
||||
Args:
|
||||
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
estimate_scale: If `True`, also estimates a scaling component `s`
|
||||
of the transformation. Otherwise assumes an identity
|
||||
scale and returns a tensor of ones.
|
||||
allow_reflection: If `True`, allows the algorithm to return `R`
|
||||
which is orthonormal but has determinant==-1.
|
||||
eps: A scalar for clamping to avoid dividing by zero. Active for the
|
||||
code that estimates the output scale `s`.
|
||||
|
||||
Returns:
|
||||
3-element tuple containing
|
||||
- **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
|
||||
- **T**: Batch of translations of shape `(minibatch, d)`.
|
||||
- **s**: batch of scaling factors of shape `(minibatch, )`.
|
||||
|
||||
References:
|
||||
[1] Shinji Umeyama: Least-Suqares Estimation of
|
||||
Transformation Parameters Between Two Point Patterns
|
||||
"""
|
||||
|
||||
# make sure we convert input Pointclouds structures to tensors
|
||||
Xt, num_points = _convert_point_cloud_to_tensor(X)
|
||||
Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)
|
||||
|
||||
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
|
||||
raise ValueError(
|
||||
"Point sets X and Y have to have the same \
|
||||
number of batches, points and dimensions."
|
||||
)
|
||||
|
||||
b, n, dim = Xt.shape
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu[:, None]
|
||||
Yc = Yt - Ymu[:, None]
|
||||
|
||||
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
|
||||
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
|
||||
mask = (
|
||||
torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
|
||||
< num_points[:, None]
|
||||
).type_as(Xc)
|
||||
Xc *= mask[:, :, None]
|
||||
Yc *= mask[:, :, None]
|
||||
|
||||
if (num_points < (dim + 1)).any():
|
||||
warnings.warn(
|
||||
"The size of one of the point clouds is <= dim+1. "
|
||||
+ "corresponding_points_alignment can't return a unique solution."
|
||||
)
|
||||
|
||||
# compute the covariance XYcov between the point sets Xc, Yc
|
||||
XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
|
||||
XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
|
||||
|
||||
# decompose the covariance matrix XYcov
|
||||
U, S, V = torch.svd(XYcov)
|
||||
|
||||
# identity matrix used for fixing reflections
|
||||
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(
|
||||
b, 1, 1
|
||||
)
|
||||
|
||||
if not allow_reflection:
|
||||
# reflection test:
|
||||
# checks whether the estimated rotation has det==1,
|
||||
# if not, finds the nearest rotation s.t. det==1 by
|
||||
# flipping the sign of the last singular vector U
|
||||
R_test = torch.bmm(U, V.transpose(2, 1))
|
||||
E[:, -1, -1] = torch.det(R_test)
|
||||
|
||||
# find the rotation matrix by composing U and V again
|
||||
R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
|
||||
|
||||
if estimate_scale:
|
||||
# estimate the scaling component of the transformation
|
||||
trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
|
||||
Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
|
||||
|
||||
# the scaling component
|
||||
s = trace_ES / torch.clamp(Xcov, eps)
|
||||
|
||||
# translation component
|
||||
T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
|
||||
|
||||
else:
|
||||
# translation component
|
||||
T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
|
||||
|
||||
# unit scaling since we do not estimate scale
|
||||
s = T.new_ones(b)
|
||||
|
||||
return R, T, s
|
||||
|
||||
|
||||
def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
|
||||
"""
|
||||
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
|
||||
padded representation and returns it together with the number of points
|
||||
per batch. Otherwise, returns the input itself with the number of points
|
||||
set to the size of the second dimension of `pcl`.
|
||||
"""
|
||||
if isinstance(pcl, Pointclouds):
|
||||
X = pcl.points_padded()
|
||||
num_points = pcl.num_points_per_cloud()
|
||||
elif torch.is_tensor(pcl):
|
||||
X = pcl
|
||||
num_points = X.shape[1] * torch.ones(
|
||||
X.shape[0], device=X.device, dtype=torch.int64
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The inputs X, Y should be either Pointclouds objects or tensors."
|
||||
)
|
||||
return X, num_points
|
40
tests/bm_points_alignment.py
Normal file
40
tests/bm_points_alignment.py
Normal file
@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from fvcore.common.benchmark import benchmark
|
||||
|
||||
from test_points_alignment import TestCorrespondingPointsAlignment
|
||||
|
||||
|
||||
def bm_corresponding_points_alignment() -> None:
|
||||
|
||||
case_grid = {
|
||||
"allow_reflection": [True, False],
|
||||
"batch_size": [1, 10, 100],
|
||||
"dim": [3, 20],
|
||||
"estimate_scale": [True, False],
|
||||
"n_points": [100, 10000],
|
||||
"use_pointclouds": [False],
|
||||
}
|
||||
|
||||
test_args = sorted(case_grid.keys())
|
||||
test_cases = product(*[case_grid[k] for k in test_args])
|
||||
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
|
||||
|
||||
# add the use_pointclouds=True test cases whenever we have dim==3
|
||||
kwargs_to_add = []
|
||||
for entry in kwargs_list:
|
||||
if entry["dim"] == 3:
|
||||
entry_add = deepcopy(entry)
|
||||
entry_add["use_pointclouds"] = True
|
||||
kwargs_to_add.append(entry_add)
|
||||
kwargs_list.extend(kwargs_to_add)
|
||||
|
||||
benchmark(
|
||||
TestCorrespondingPointsAlignment.corresponding_points_alignment,
|
||||
"CorrespodingPointsAlignment",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
358
tests/test_points_alignment.py
Normal file
358
tests/test_points_alignment.py
Normal file
@ -0,0 +1,358 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.ops import points_alignment
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
|
||||
|
||||
def _apply_pcl_transformation(X, R, T, s=None):
|
||||
"""
|
||||
Apply a batch of similarity/rigid transformations, parametrized with
|
||||
rotation `R`, translation `T` and scale `s`, to an input batch of
|
||||
point clouds `X`.
|
||||
"""
|
||||
if isinstance(X, Pointclouds):
|
||||
num_points = X.num_points_per_cloud()
|
||||
X_t = X.points_padded()
|
||||
else:
|
||||
X_t = X
|
||||
|
||||
if s is not None:
|
||||
X_t = s[:, None, None] * X_t
|
||||
|
||||
X_t = torch.bmm(X_t, R) + T[:, None, :]
|
||||
|
||||
if isinstance(X, Pointclouds):
|
||||
X_list = [x[:n_p] for x, n_p in zip(X_t, num_points)]
|
||||
X_t = Pointclouds(X_list)
|
||||
|
||||
return X_t
|
||||
|
||||
|
||||
class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
@staticmethod
|
||||
def random_rotation(batch_size, dim, device=None):
|
||||
"""
|
||||
Generates a batch of random `dim`-dimensional rotation matrices.
|
||||
"""
|
||||
if dim == 3:
|
||||
R = rotation_conversions.random_rotations(batch_size, device=device)
|
||||
else:
|
||||
# generate random rotation matrices with orthogonalization of
|
||||
# random normal square matrices, followed by a transformation
|
||||
# that ensures determinant(R)==1
|
||||
H = torch.randn(
|
||||
batch_size, dim, dim, dtype=torch.float32, device=device
|
||||
)
|
||||
U, _, V = torch.svd(H)
|
||||
E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat(
|
||||
batch_size, 1, 1
|
||||
)
|
||||
E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1)))
|
||||
R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
|
||||
assert torch.allclose(
|
||||
torch.det(R), R.new_ones(batch_size), atol=1e-4
|
||||
)
|
||||
|
||||
return R
|
||||
|
||||
@staticmethod
|
||||
def init_point_cloud(
|
||||
batch_size=10,
|
||||
n_points=1000,
|
||||
dim=3,
|
||||
device=None,
|
||||
use_pointclouds=False,
|
||||
random_pcl_size=True,
|
||||
):
|
||||
"""
|
||||
Generate a batch of normally distributed point clouds.
|
||||
"""
|
||||
if use_pointclouds:
|
||||
assert dim == 3, "Pointclouds support only 3-dim points."
|
||||
# generate a `batch_size` point clouds with number of points
|
||||
# between 4 and `n_points`
|
||||
if random_pcl_size:
|
||||
n_points_per_batch = torch.randint(
|
||||
low=4,
|
||||
high=n_points,
|
||||
size=(batch_size,),
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
X_list = [
|
||||
torch.randn(
|
||||
int(n_pt), dim, device=device, dtype=torch.float32
|
||||
)
|
||||
for n_pt in n_points_per_batch
|
||||
]
|
||||
X = Pointclouds(X_list)
|
||||
else:
|
||||
X = torch.randn(
|
||||
batch_size,
|
||||
n_points,
|
||||
dim,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
X = Pointclouds(list(X))
|
||||
else:
|
||||
X = torch.randn(
|
||||
batch_size, n_points, dim, device=device, dtype=torch.float32
|
||||
)
|
||||
return X
|
||||
|
||||
@staticmethod
|
||||
def generate_pcl_transformation(
|
||||
batch_size=10, scale=False, reflect=False, dim=3, device=None
|
||||
):
|
||||
"""
|
||||
Generate a batch of random rigid/similarity transformations.
|
||||
"""
|
||||
R = TestCorrespondingPointsAlignment.random_rotation(
|
||||
batch_size, dim, device=device
|
||||
)
|
||||
T = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
|
||||
if scale:
|
||||
s = torch.rand(batch_size, dtype=torch.float32, device=device) + 0.1
|
||||
else:
|
||||
s = torch.ones(batch_size, dtype=torch.float32, device=device)
|
||||
|
||||
return R, T, s
|
||||
|
||||
@staticmethod
|
||||
def generate_random_reflection(batch_size=10, dim=3, device=None):
|
||||
"""
|
||||
Generate a batch of reflection matrices of shape (batch_size, dim, dim),
|
||||
where M_i is an identity matrix with one random entry on the
|
||||
diagonal equal to -1.
|
||||
"""
|
||||
# randomly select one of the dimensions to reflect for each
|
||||
# element in the batch
|
||||
dim_to_reflect = torch.randint(
|
||||
low=0,
|
||||
high=dim,
|
||||
size=(batch_size,),
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
# convert dim_to_reflect to a batch of reflection matrices M
|
||||
M = torch.diag_embed(
|
||||
(
|
||||
dim_to_reflect[:, None]
|
||||
!= torch.arange(dim, device=device, dtype=torch.float32)
|
||||
).float()
|
||||
* 2
|
||||
- 1,
|
||||
dim1=1,
|
||||
dim2=2,
|
||||
)
|
||||
|
||||
return M
|
||||
|
||||
@staticmethod
|
||||
def corresponding_points_alignment(
|
||||
batch_size=10,
|
||||
n_points=100,
|
||||
dim=3,
|
||||
use_pointclouds=False,
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
reflect=False,
|
||||
):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# initialize a ground truth point cloud
|
||||
X = TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
device=device,
|
||||
use_pointclouds=use_pointclouds,
|
||||
random_pcl_size=True,
|
||||
)
|
||||
|
||||
# generate the true transformation
|
||||
R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation(
|
||||
batch_size=batch_size,
|
||||
scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
dim=dim,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# apply the generated transformation to the generated
|
||||
# point cloud X
|
||||
X_t = _apply_pcl_transformation(X, R, T, s=s)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_corresponding_points_alignment():
|
||||
points_alignment.corresponding_points_alignment(
|
||||
X,
|
||||
X_t,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return run_corresponding_points_alignment
|
||||
|
||||
def test_corresponding_points_alignment(self, batch_size=10):
|
||||
"""
|
||||
Tests whether we can estimate a rigid/similarity motion between
|
||||
a randomly initialized point cloud and its randomly transformed version.
|
||||
|
||||
The tests are done for all possible combinations
|
||||
of the following boolean flags:
|
||||
- estimate_scale ... Estimate also a scaling component of
|
||||
the transformation.
|
||||
- reflect ... The ground truth orthonormal part of the generated
|
||||
transformation is a reflection (det==-1).
|
||||
- allow_reflection ... If True, the orthonormal matrix of the
|
||||
estimated transformation is allowed to be
|
||||
a reflection (det==-1).
|
||||
- use_pointclouds ... If True, passes the Pointclouds objects
|
||||
to corresponding_points_alignment.
|
||||
"""
|
||||
|
||||
# run this for several different point cloud sizes
|
||||
for n_points in (100, 3, 2, 1, 0):
|
||||
# run this for several different dimensionalities
|
||||
for dim in torch.arange(2, 10):
|
||||
# switches whether we should use the Pointclouds inputs
|
||||
use_point_clouds_cases = (
|
||||
(True, False) if dim == 3 and n_points > 3 else (False,)
|
||||
)
|
||||
for use_pointclouds in use_point_clouds_cases:
|
||||
for estimate_scale in (False, True):
|
||||
for reflect in (False, True):
|
||||
for allow_reflection in (False, True):
|
||||
self._test_single_corresponding_points_alignment(
|
||||
batch_size=10,
|
||||
n_points=n_points,
|
||||
dim=int(dim),
|
||||
use_pointclouds=use_pointclouds,
|
||||
estimate_scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
allow_reflection=allow_reflection,
|
||||
)
|
||||
|
||||
def _test_single_corresponding_points_alignment(
|
||||
self,
|
||||
batch_size=10,
|
||||
n_points=100,
|
||||
dim=3,
|
||||
use_pointclouds=False,
|
||||
estimate_scale=False,
|
||||
reflect=False,
|
||||
allow_reflection=False,
|
||||
):
|
||||
"""
|
||||
Executes a single test for `corresponding_points_alignment` for a
|
||||
specific setting of the inputs / outputs.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# initialize the a ground truth point cloud
|
||||
X = TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
device=device,
|
||||
use_pointclouds=use_pointclouds,
|
||||
random_pcl_size=True,
|
||||
)
|
||||
|
||||
# generate the true transformation
|
||||
R, T, s = TestCorrespondingPointsAlignment.generate_pcl_transformation(
|
||||
batch_size=batch_size,
|
||||
scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
dim=dim,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if reflect:
|
||||
# generate random reflection M and apply to the rotations
|
||||
M = TestCorrespondingPointsAlignment.generate_random_reflection(
|
||||
batch_size=batch_size, dim=dim, device=device
|
||||
)
|
||||
R = torch.bmm(M, R)
|
||||
|
||||
# apply the generated transformation to the generated
|
||||
# point cloud X
|
||||
X_t = _apply_pcl_transformation(X, R, T, s=s)
|
||||
|
||||
# run the CorrespondingPointsAlignment algorithm
|
||||
R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
|
||||
X,
|
||||
X_t,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
|
||||
assert_error_message = (
|
||||
f"Corresponding_points_alignment assertion failure for "
|
||||
f"n_points={n_points}, "
|
||||
f"dim={dim}, "
|
||||
f"use_pointclouds={use_pointclouds}, "
|
||||
f"estimate_scale={estimate_scale}, "
|
||||
f"reflect={reflect}, "
|
||||
f"allow_reflection={allow_reflection}."
|
||||
)
|
||||
|
||||
if reflect and not allow_reflection:
|
||||
# check that all rotations have det=1
|
||||
self._assert_all_close(
|
||||
torch.det(R_est),
|
||||
R_est.new_ones(batch_size),
|
||||
assert_error_message,
|
||||
)
|
||||
|
||||
else:
|
||||
# check that the estimated tranformation is the same
|
||||
# as the ground truth
|
||||
if n_points >= (dim + 1):
|
||||
# the checks on transforms apply only when
|
||||
# the problem setup is unambiguous
|
||||
self._assert_all_close(R_est, R, assert_error_message)
|
||||
self._assert_all_close(T_est, T, assert_error_message)
|
||||
self._assert_all_close(s_est, s, assert_error_message)
|
||||
|
||||
# check that the orthonormal part of the
|
||||
# transformation has a correct determinant (+1/-1)
|
||||
desired_det = R_est.new_ones(batch_size)
|
||||
if reflect:
|
||||
desired_det *= -1.0
|
||||
self._assert_all_close(
|
||||
torch.det(R_est), desired_det, assert_error_message
|
||||
)
|
||||
|
||||
# check that the transformed point cloud
|
||||
# X matches X_t
|
||||
X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est)
|
||||
self._assert_all_close(
|
||||
X_t, X_t_est, assert_error_message, atol=1e-5
|
||||
)
|
||||
|
||||
def _assert_all_close(self, a_, b_, err_message, atol=1e-6):
|
||||
if isinstance(a_, Pointclouds):
|
||||
a_ = a_.points_packed()
|
||||
if isinstance(b_, Pointclouds):
|
||||
b_ = b_.points_packed()
|
||||
self.assertTrue(torch.allclose(a_, b_, atol=atol), err_message)
|
Loading…
x
Reference in New Issue
Block a user