ICP - point-to-point version

Summary:
The iterative closest point algorithm - point-to-point version.

Output of `bm_iterative_closest_point`:
Argument key: `batch_size dim n_points_X n_points_Y use_pointclouds`

```
Benchmark                                         Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
IterativeClosestPoint_1_3_100_100_False              107569          111323              5
IterativeClosestPoint_1_3_100_1000_False             118972          122306              5
IterativeClosestPoint_1_3_1000_100_False             108576          110978              5
IterativeClosestPoint_1_3_1000_1000_False            331836          333515              2
IterativeClosestPoint_1_20_100_100_False             134387          137842              4
IterativeClosestPoint_1_20_100_1000_False            149218          153405              4
IterativeClosestPoint_1_20_1000_100_False            414248          416595              2
IterativeClosestPoint_1_20_1000_1000_False           374318          374662              2
IterativeClosestPoint_10_3_100_100_False             539852          539852              1
IterativeClosestPoint_10_3_100_1000_False            752784          752784              1
IterativeClosestPoint_10_3_1000_100_False           1070700         1070700              1
IterativeClosestPoint_10_3_1000_1000_False          1164020         1164020              1
IterativeClosestPoint_10_20_100_100_False            374548          377337              2
IterativeClosestPoint_10_20_100_1000_False           472764          476685              2
IterativeClosestPoint_10_20_1000_100_False          1457175         1457175              1
IterativeClosestPoint_10_20_1000_1000_False         2195820         2195820              1
IterativeClosestPoint_1_3_100_100_True               110084          115824              5
IterativeClosestPoint_1_3_100_1000_True              142728          147696              4
IterativeClosestPoint_1_3_1000_100_True              212966          213966              3
IterativeClosestPoint_1_3_1000_1000_True             369130          375114              2
IterativeClosestPoint_10_3_100_100_True              354615          355179              2
IterativeClosestPoint_10_3_100_1000_True             451815          452704              2
IterativeClosestPoint_10_3_1000_100_True             511833          511833              1
IterativeClosestPoint_10_3_1000_1000_True            798453          798453              1
--------------------------------------------------------------------------------
```

Reviewed By: shapovalov, gkioxari

Differential Revision: D19909952

fbshipit-source-id: f77fadc88fb7c53999909d594114b182ee2a3def
This commit is contained in:
David Novotny 2020-04-16 13:59:34 -07:00 committed by Facebook GitHub Bot
parent b5eb33b36c
commit 8abbe22ffb
6 changed files with 603 additions and 45 deletions

View File

@ -6,9 +6,10 @@ from .graph_conv import GraphConv
from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals
from .packed_to_padded import packed_to_padded, padded_to_packed
from .points_alignment import corresponding_points_alignment
from .points_alignment import corresponding_points_alignment, iterative_closest_point
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
from .vert_align import vert_align

View File

@ -1,22 +1,231 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import List, Tuple, Union
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Union
import torch
from pytorch3d.ops import utils as oputil
from pytorch3d.ops import knn_points
from pytorch3d.structures import utils as strutil
from pytorch3d.structures.pointclouds import Pointclouds
from . import utils as oputil
if TYPE_CHECKING:
from pytorch3d.structures.pointclouds import Pointclouds
# named tuples for inputs/outputs
class SimilarityTransform(NamedTuple):
R: torch.Tensor
T: torch.Tensor
s: torch.Tensor
class ICPSolution(NamedTuple):
converged: bool
rmse: Union[torch.Tensor, None]
Xt: torch.Tensor
RTs: SimilarityTransform
t_history: List[SimilarityTransform]
def iterative_closest_point(
X: Union[torch.Tensor, "Pointclouds"],
Y: Union[torch.Tensor, "Pointclouds"],
init_transform: Optional[SimilarityTransform] = None,
max_iterations: int = 100,
relative_rmse_thr: float = 1e-6,
estimate_scale: bool = False,
allow_reflection: bool = False,
verbose: bool = False,
) -> ICPSolution:
"""
Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
a similarity transformation (rotation `R`, translation `T`, and
optionally scale `s`) between two given differently-sized sets of
`d`-dimensional points `X` and `Y`, such that:
`s[i] X[i] R[i] + T[i] = Y[NN[i]]`,
for all batch indices `i` in the least squares sense. Here, Y[NN[i]] stands
for the indices of nearest neighbors from `Y` to each point in `X`.
Note, however, that the solution is only a local optimum.
Args:
**X**: Batch of `d`-dimensional points
of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
**Y**: Batch of `d`-dimensional points
of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
**init_transform**: A named-tuple `SimilarityTransform` of tensors
`R`, `T, `s`, where `R` is a batch of orthonormal matrices of
shape `(minibatch, d, d)`, `T` is a batch of translations
of shape `(minibatch, d)` and `s` is a batch of scaling factors
of shape `(minibatch,)`.
**max_iterations**: The maximum number of ICP iterations.
**relative_rmse_thr**: A threshold on the relative root mean squared error
used to terminate the algorithm.
**estimate_scale**: If `True`, also estimates a scaling component `s`
of the transformation. Otherwise assumes the identity
scale and returns a tensor of ones.
**allow_reflection**: If `True`, allows the algorithm to return `R`
which is orthonormal but has determinant==-1.
**verbose**: If `True`, prints status messages during each ICP iteration.
Returns:
A named tuple `ICPSolution` with the following fields:
**converged**: A boolean flag denoting whether the algorithm converged
successfully (=`True`) or not (=`False`).
**rmse**: Attained root mean squared error after termination of ICP.
**Xt**: The point cloud `X` transformed with the final transformation
(`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
instance of `Pointclouds`, otherwise returns `torch.Tensor`.
**RTs**: A named tuple `SimilarityTransform` containing
a batch of similarity transforms with fields:
**R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
**T**: Batch of translations of shape `(minibatch, d)`.
**s**: batch of scaling factors of shape `(minibatch, )`.
**t_history**: A list of named tuples `SimilarityTransform`
the transformation parameters after each ICP iteration.
References:
[1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
[2] https://en.wikipedia.org/wiki/Iterative_closest_point
"""
# make sure we convert input Pointclouds structures to
# padded tensors of shape (N, P, 3)
Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
b, size_X, dim = Xt.shape
if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
raise ValueError(
"Point sets X and Y have to have the same "
+ "number of batches and data dimensions."
)
if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and (
num_points_Y != num_points_X
).any():
# we have a heterogeneous input (e.g. because X/Y is
# an instance of Pointclouds)
mask_X = (
torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
< num_points_X[:, None]
).type_as(Xt)
else:
mask_X = Xt.new_ones(b, size_X)
# clone the initial point cloud
Xt_init = Xt.clone()
if init_transform is not None:
# parse the initial transform from the input and apply to Xt
try:
R, T, s = init_transform
assert (
R.shape == torch.Size((b, dim, dim))
and T.shape == torch.Size((b, dim))
and s.shape == torch.Size((b,))
)
except Exception:
raise ValueError(
"The initial transformation init_transform has to be "
"a named tuple SimilarityTransform with elements (R, T, s). "
"R are dim x dim orthonormal matrices of shape "
"(minibatch, dim, dim), T is a batch of dim-dimensional "
"translations of shape (minibatch, dim) and s is a batch "
"of scalars of shape (minibatch,)."
)
# apply the init transform to the input point cloud
Xt = _apply_similarity_transform(Xt, R, T, s)
else:
# initialize the transformation with identity
R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
T = Xt.new_zeros((b, dim))
s = Xt.new_ones(b)
prev_rmse = None
rmse = None
iteration = -1
converged = False
# initialize the transformation history
t_history = []
# the main loop over ICP iterations
for iteration in range(max_iterations):
Xt_nn_points = knn_points(
Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
)[2][:, :, 0, :]
# get the alignment of the nearest neighbors from Yt with Xt_init
R, T, s = corresponding_points_alignment(
Xt_init,
Xt_nn_points,
weights=mask_X,
estimate_scale=estimate_scale,
allow_reflection=allow_reflection,
)
# apply the estimated similarity transform to Xt_init
Xt = _apply_similarity_transform(Xt_init, R, T, s)
# add the current transformation to the history
t_history.append(SimilarityTransform(R, T, s))
# compute the root mean squared error
Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
# compute the relative rmse
if prev_rmse is None:
relative_rmse = rmse.new_ones(b)
else:
relative_rmse = (prev_rmse - rmse) / prev_rmse
if verbose:
rmse_msg = (
f"ICP iteration {iteration}: mean/max rmse = "
+ f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
+ f"; mean relative rmse = {relative_rmse.mean():1.2e}"
)
print(rmse_msg)
# check for convergence
if (relative_rmse <= relative_rmse_thr).all():
converged = True
break
# update the previous rmse
prev_rmse = rmse
if verbose:
if converged:
print(f"ICP has converged in {iteration + 1} iterations.")
else:
print(f"ICP has not converged in {max_iterations} iterations.")
if oputil.is_pointclouds(X):
Xt = X.update_padded(Xt) # type: ignore
return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
# threshold for checking that point crosscorelation
# is full rank in corresponding_points_alignment
AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
def corresponding_points_alignment(
X: Union[torch.Tensor, Pointclouds],
Y: Union[torch.Tensor, Pointclouds],
X: Union[torch.Tensor, "Pointclouds"],
Y: Union[torch.Tensor, "Pointclouds"],
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
estimate_scale: bool = False,
allow_reflection: bool = False,
eps: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
eps: float = 1e-9,
) -> SimilarityTransform:
"""
Finds a similarity transformation (rotation `R`, translation `T`
and optionally scale `s`) between two given sets of corresponding
@ -29,25 +238,25 @@ def corresponding_points_alignment(
The algorithm is also known as Umeyama [1].
Args:
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
**X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
or a `Pointclouds` object.
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
**Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
or a `Pointclouds` object.
weights: Batch of non-negative weights of
**weights**: Batch of non-negative weights of
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
tensors that may have different shapes; in that case, the length of
i-th tensor should be equal to the number of points in X_i and Y_i.
Passing `None` means uniform weights.
estimate_scale: If `True`, also estimates a scaling component `s`
**estimate_scale**: If `True`, also estimates a scaling component `s`
of the transformation. Otherwise assumes an identity
scale and returns a tensor of ones.
allow_reflection: If `True`, allows the algorithm to return `R`
**allow_reflection**: If `True`, allows the algorithm to return `R`
which is orthonormal but has determinant==-1.
eps: A scalar for clamping to avoid dividing by zero. Active for the
**eps**: A scalar for clamping to avoid dividing by zero. Active for the
code that estimates the output scale `s`.
Returns:
3-element tuple containing
3-element named tuple `SimilarityTransform` containing
- **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
- **T**: Batch of translations of shape `(minibatch, d)`.
- **s**: batch of scaling factors of shape `(minibatch, )`.
@ -58,8 +267,8 @@ def corresponding_points_alignment(
"""
# make sure we convert input Pointclouds structures to tensors
Xt, num_points = _convert_point_cloud_to_tensor(X)
Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)
Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
raise ValueError(
@ -90,8 +299,8 @@ def corresponding_points_alignment(
weights = mask if weights is None else mask * weights.type_as(Xt)
# compute the centroids of the point sets
Xmu = oputil.wmean(Xt, weights, eps=eps)
Ymu = oputil.wmean(Yt, weights, eps=eps)
Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
Ymu = oputil.wmean(Yt, weight=weights, eps=eps)
# mean-center the point sets
Xc = Xt - Xmu
@ -107,7 +316,7 @@ def corresponding_points_alignment(
if (num_points < (dim + 1)).any():
warnings.warn(
"The size of one of the point clouds is <= dim+1. "
+ "corresponding_points_alignment can't return a unique solution."
+ "corresponding_points_alignment cannot return a unique rotation."
)
# compute the covariance XYcov between the point sets Xc, Yc
@ -117,6 +326,16 @@ def corresponding_points_alignment(
# decompose the covariance matrix XYcov
U, S, V = torch.svd(XYcov)
# catch ambiguous rotation by checking the magnitude of singular values
if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
num_points < (dim + 1)
).any():
warnings.warn(
"Excessively low rank of "
+ "cross-correlation between aligned point clouds. "
+ "corresponding_points_alignment cannot return a unique rotation."
)
# identity matrix used for fixing reflections
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
@ -148,26 +367,18 @@ def corresponding_points_alignment(
# unit scaling since we do not estimate scale
s = T.new_ones(b)
return R, T, s
return SimilarityTransform(R, T, s)
def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
def _apply_similarity_transform(
X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
"""
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
padded representation and returns it together with the number of points
per batch. Otherwise, returns the input itself with the number of points
set to the size of the second dimension of `pcl`.
Applies a similarity transformation parametrized with a batch of orthonormal
matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
of shape `(minibatch, d)` and a batch of scaling factors `s`
of shape `(minibatch,)` to a given `d`-dimensional cloud `X`
of shape `(minibatch, num_points, d)`
"""
if isinstance(pcl, Pointclouds):
X = pcl.points_padded()
num_points = pcl.num_points_per_cloud()
elif torch.is_tensor(pcl):
X = pcl
num_points = X.shape[1] * torch.ones(
X.shape[0], device=X.device, dtype=torch.int64
)
else:
raise ValueError(
"The inputs X, Y should be either Pointclouds objects or tensors."
)
return X, num_points
X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
return X

View File

@ -1,9 +1,13 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
if TYPE_CHECKING:
from pytorch3d.structures import Pointclouds
def wmean(
x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
@ -41,3 +45,55 @@ def wmean(
return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
eps
)
def eyes(
dim: int,
N: int,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Generates a batch of `N` identity matrices of shape `(N, dim, dim)`.
Args:
**dim**: The dimensionality of the identity matrices.
**N**: The number of identity matrices.
**device**: The device to be used for allocating the matrices.
**dtype**: The datatype of the matrices.
Returns:
**identities**: A batch of identity matrices of shape `(N, dim, dim)`.
"""
identities = torch.eye(dim, device=device, dtype=dtype)
return identities[None].repeat(N, 1, 1)
def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
"""
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
padded representation and returns it together with the number of points
per batch. Otherwise, returns the input itself with the number of points
set to the size of the second dimension of `pcl`.
"""
if is_pointclouds(pcl):
X = pcl.points_padded() # type: ignore
num_points = pcl.num_points_per_cloud() # type: ignore
elif torch.is_tensor(pcl):
X = pcl
num_points = X.shape[1] * torch.ones( # type: ignore
X.shape[0], device=X.device, dtype=torch.int64
)
else:
raise ValueError(
"The inputs X, Y should be either Pointclouds objects or tensors."
)
return X, num_points
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
""" Checks whether the input `pcl` is an instance `Pointclouds` of
by checking the existence of `points_padded` and `num_points_per_cloud`
functions.
"""
return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")

View File

@ -5,7 +5,38 @@ from copy import deepcopy
from itertools import product
from fvcore.common.benchmark import benchmark
from test_points_alignment import TestCorrespondingPointsAlignment
from test_points_alignment import TestCorrespondingPointsAlignment, TestICP
def bm_iterative_closest_point() -> None:
case_grid = {
"batch_size": [1, 10],
"dim": [3, 20],
"n_points_X": [100, 1000],
"n_points_Y": [100, 1000],
"use_pointclouds": [False],
}
test_args = sorted(case_grid.keys())
test_cases = product(*case_grid.values())
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
# add the use_pointclouds=True test cases whenever we have dim==3
kwargs_to_add = []
for entry in kwargs_list:
if entry["dim"] == 3:
entry_add = deepcopy(entry)
entry_add["use_pointclouds"] = True
kwargs_to_add.append(entry_add)
kwargs_list.extend(kwargs_to_add)
benchmark(
TestICP.iterative_closest_point,
"IterativeClosestPoint",
kwargs_list,
warmup_iters=1,
)
def bm_corresponding_points_alignment() -> None:
@ -21,7 +52,7 @@ def bm_corresponding_points_alignment() -> None:
}
test_args = sorted(case_grid.keys())
test_cases = product(*[case_grid[k] for k in test_args])
test_cases = product(*case_grid.values())
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
# add the use_pointclouds=True test cases whenever we have dim==3

BIN
tests/icp_data.pth Normal file

Binary file not shown.

View File

@ -1,8 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from pathlib import Path
import numpy as np
import torch
@ -36,6 +35,256 @@ def _apply_pcl_transformation(X, R, T, s=None):
return X_t
class TestICP(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth"
self.trimesh_results = torch.load(trimesh_results_path)
@staticmethod
def iterative_closest_point(
batch_size=10,
n_points_X=100,
n_points_Y=100,
dim=3,
use_pointclouds=False,
estimate_scale=False,
):
device = torch.device("cuda:0")
# initialize a ground truth point cloud
X, Y = [
TestCorrespondingPointsAlignment.init_point_cloud(
batch_size=batch_size,
n_points=n_points,
dim=dim,
device=device,
use_pointclouds=use_pointclouds,
random_pcl_size=True,
fix_seed=i,
)
for i, n_points in enumerate((n_points_X, n_points_Y))
]
torch.cuda.synchronize()
def run_iterative_closest_point():
points_alignment.iterative_closest_point(
X,
Y,
estimate_scale=estimate_scale,
allow_reflection=False,
verbose=False,
max_iterations=100,
relative_rmse_thr=1e-4,
)
torch.cuda.synchronize()
return run_iterative_closest_point
def test_init_transformation(self, batch_size=10):
"""
First runs a full ICP on a random problem. Then takes a given point
in the history of ICP iteration transformations, initializes
a second run of ICP with this transformation and checks whether
both runs ended with the same solution.
"""
device = torch.device("cuda:0")
for dim in (2, 3, 11):
for n_points_X in (30, 100):
for n_points_Y in (30, 100):
# initialize ground truth point clouds
X, Y = [
TestCorrespondingPointsAlignment.init_point_cloud(
batch_size=batch_size,
n_points=n_points,
dim=dim,
device=device,
use_pointclouds=False,
random_pcl_size=True,
)
for n_points in (n_points_X, n_points_Y)
]
# run full icp
converged, _, Xt, (
R,
T,
s,
), t_hist = points_alignment.iterative_closest_point(
X,
Y,
estimate_scale=False,
allow_reflection=False,
verbose=False,
max_iterations=100,
)
# start from the solution after the third
# iteration of the previous ICP
t_init = t_hist[min(2, len(t_hist) - 1)]
# rerun the ICP
converged_init, _, Xt_init, (
R_init,
T_init,
s_init,
), t_hist_init = points_alignment.iterative_closest_point(
X,
Y,
init_transform=t_init,
estimate_scale=False,
allow_reflection=False,
verbose=False,
max_iterations=100,
)
# compare transformations and obtained clouds
# check that both sets of transforms are the same
atol = 3e-5
self.assertClose(R_init, R, atol=atol)
self.assertClose(T_init, T, atol=atol)
self.assertClose(s_init, s, atol=atol)
self.assertClose(Xt_init, Xt, atol=atol)
def test_heterogenous_inputs(self, batch_size=10):
"""
Tests whether we get the same result when running ICP on
a set of randomly-sized Pointclouds and on their padded versions.
"""
device = torch.device("cuda:0")
for estimate_scale in (True, False):
for max_n_points in (10, 30, 100):
# initialize ground truth point clouds
X_pcl, Y_pcl = [
TestCorrespondingPointsAlignment.init_point_cloud(
batch_size=batch_size,
n_points=max_n_points,
dim=3,
device=device,
use_pointclouds=True,
random_pcl_size=True,
)
for _ in range(2)
]
# get the padded versions and their num of points
X_padded = X_pcl.points_padded()
Y_padded = Y_pcl.points_padded()
n_points_X = X_pcl.num_points_per_cloud()
n_points_Y = Y_pcl.num_points_per_cloud()
# run icp with Pointlouds inputs
_, _, Xt_pcl, (
R_pcl,
T_pcl,
s_pcl,
), _ = points_alignment.iterative_closest_point(
X_pcl,
Y_pcl,
estimate_scale=estimate_scale,
allow_reflection=False,
verbose=False,
max_iterations=100,
)
Xt_pcl = Xt_pcl.points_padded()
# run icp with tensor inputs on each element
# of the batch separately
icp_results = [
points_alignment.iterative_closest_point(
X_[None, :n_X, :],
Y_[None, :n_Y, :],
estimate_scale=estimate_scale,
allow_reflection=False,
verbose=False,
max_iterations=100,
)
for X_, Y_, n_X, n_Y in zip(
X_padded, Y_padded, n_points_X, n_points_Y
)
]
# parse out the transformation results
R, T, s = [
torch.cat([x.RTs[i] for x in icp_results], dim=0) for i in range(3)
]
# check that both sets of transforms are the same
atol = 1e-5
self.assertClose(R_pcl, R, atol=atol)
self.assertClose(T_pcl, T, atol=atol)
self.assertClose(s_pcl, s, atol=atol)
# compare the transformed point clouds
for pcli in range(batch_size):
nX = n_points_X[pcli]
Xt_ = icp_results[pcli].Xt[0, :nX]
Xt_pcl_ = Xt_pcl[pcli][:nX]
self.assertClose(Xt_pcl_, Xt_, atol=atol)
def test_compare_with_trimesh(self):
"""
Compares the outputs of `iterative_closest_point` with the results
of `trimesh.registration.icp` from the `trimesh` python package:
https://github.com/mikedh/trimesh
We have run `trimesh.registration.icp` on several random problems
with different point cloud sizes. The results of trimesh, together with
the randomly generated input clouds are loaded in the constructor of
this class and this test compares the loaded results to our runs.
"""
for n_points_X in (10, 20, 50, 100):
for n_points_Y in (10, 20, 50, 100):
self._compare_with_trimesh(n_points_X=n_points_X, n_points_Y=n_points_Y)
def _compare_with_trimesh(
self, n_points_X=100, n_points_Y=100, estimate_scale=False
):
"""
Executes a single test for `iterative_closest_point` for a
specific setting of the inputs / outputs. Compares the result with
the result of the trimesh package on the same input data.
"""
device = torch.device("cuda:0")
# load the trimesh results and the initial point clouds for icp
key = (int(n_points_X), int(n_points_Y), int(estimate_scale))
X, Y, R_trimesh, T_trimesh, s_trimesh = [
x.to(device) for x in self.trimesh_results[key]
]
# run the icp algorithm
converged, _, _, (
R_ours,
T_ours,
s_ours,
), _ = points_alignment.iterative_closest_point(
X,
Y,
estimate_scale=estimate_scale,
allow_reflection=False,
verbose=False,
max_iterations=100,
)
# check that we have the same transformation
# and that the icp converged
atol = 1e-5
self.assertClose(R_ours, R_trimesh, atol=atol)
self.assertClose(T_ours, T_trimesh, atol=atol)
self.assertClose(s_ours, s_trimesh, atol=atol)
self.assertTrue(converged)
class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
@ -72,10 +321,17 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
device=None,
use_pointclouds=False,
random_pcl_size=True,
fix_seed=None,
):
"""
Generate a batch of normally distributed point clouds.
"""
if fix_seed is not None:
# make sure we always generate the same pointcloud
seed = torch.random.get_rng_state()
torch.manual_seed(fix_seed)
if use_pointclouds:
assert dim == 3, "Pointclouds support only 3-dim points."
# generate a `batch_size` point clouds with number of points
@ -102,6 +358,10 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
X = torch.randn(
batch_size, n_points, dim, device=device, dtype=torch.float32
)
if fix_seed:
torch.random.set_rng_state(seed)
return X
@staticmethod
@ -230,7 +490,6 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
- use_pointclouds ... If True, passes the Pointclouds objects
to corresponding_points_alignment.
"""
# run this for several different point cloud sizes
for n_points in (100, 3, 2, 1):
# run this for several different dimensionalities