mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
ICP - point-to-point version
Summary: The iterative closest point algorithm - point-to-point version. Output of `bm_iterative_closest_point`: Argument key: `batch_size dim n_points_X n_points_Y use_pointclouds` ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- IterativeClosestPoint_1_3_100_100_False 107569 111323 5 IterativeClosestPoint_1_3_100_1000_False 118972 122306 5 IterativeClosestPoint_1_3_1000_100_False 108576 110978 5 IterativeClosestPoint_1_3_1000_1000_False 331836 333515 2 IterativeClosestPoint_1_20_100_100_False 134387 137842 4 IterativeClosestPoint_1_20_100_1000_False 149218 153405 4 IterativeClosestPoint_1_20_1000_100_False 414248 416595 2 IterativeClosestPoint_1_20_1000_1000_False 374318 374662 2 IterativeClosestPoint_10_3_100_100_False 539852 539852 1 IterativeClosestPoint_10_3_100_1000_False 752784 752784 1 IterativeClosestPoint_10_3_1000_100_False 1070700 1070700 1 IterativeClosestPoint_10_3_1000_1000_False 1164020 1164020 1 IterativeClosestPoint_10_20_100_100_False 374548 377337 2 IterativeClosestPoint_10_20_100_1000_False 472764 476685 2 IterativeClosestPoint_10_20_1000_100_False 1457175 1457175 1 IterativeClosestPoint_10_20_1000_1000_False 2195820 2195820 1 IterativeClosestPoint_1_3_100_100_True 110084 115824 5 IterativeClosestPoint_1_3_100_1000_True 142728 147696 4 IterativeClosestPoint_1_3_1000_100_True 212966 213966 3 IterativeClosestPoint_1_3_1000_1000_True 369130 375114 2 IterativeClosestPoint_10_3_100_100_True 354615 355179 2 IterativeClosestPoint_10_3_100_1000_True 451815 452704 2 IterativeClosestPoint_10_3_1000_100_True 511833 511833 1 IterativeClosestPoint_10_3_1000_1000_True 798453 798453 1 -------------------------------------------------------------------------------- ``` Reviewed By: shapovalov, gkioxari Differential Revision: D19909952 fbshipit-source-id: f77fadc88fb7c53999909d594114b182ee2a3def
This commit is contained in:
parent
b5eb33b36c
commit
8abbe22ffb
@ -6,9 +6,10 @@ from .graph_conv import GraphConv
|
||||
from .knn import knn_gather, knn_points
|
||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||
from .points_alignment import corresponding_points_alignment
|
||||
from .points_alignment import corresponding_points_alignment, iterative_closest_point
|
||||
from .sample_points_from_meshes import sample_points_from_meshes
|
||||
from .subdivide_meshes import SubdivideMeshes
|
||||
from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
|
||||
from .vert_align import vert_align
|
||||
|
||||
|
||||
|
@ -1,22 +1,231 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d.ops import utils as oputil
|
||||
from pytorch3d.ops import knn_points
|
||||
from pytorch3d.structures import utils as strutil
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from . import utils as oputil
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
|
||||
# named tuples for inputs/outputs
|
||||
class SimilarityTransform(NamedTuple):
|
||||
R: torch.Tensor
|
||||
T: torch.Tensor
|
||||
s: torch.Tensor
|
||||
|
||||
|
||||
class ICPSolution(NamedTuple):
|
||||
converged: bool
|
||||
rmse: Union[torch.Tensor, None]
|
||||
Xt: torch.Tensor
|
||||
RTs: SimilarityTransform
|
||||
t_history: List[SimilarityTransform]
|
||||
|
||||
|
||||
def iterative_closest_point(
|
||||
X: Union[torch.Tensor, "Pointclouds"],
|
||||
Y: Union[torch.Tensor, "Pointclouds"],
|
||||
init_transform: Optional[SimilarityTransform] = None,
|
||||
max_iterations: int = 100,
|
||||
relative_rmse_thr: float = 1e-6,
|
||||
estimate_scale: bool = False,
|
||||
allow_reflection: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> ICPSolution:
|
||||
"""
|
||||
Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
|
||||
a similarity transformation (rotation `R`, translation `T`, and
|
||||
optionally scale `s`) between two given differently-sized sets of
|
||||
`d`-dimensional points `X` and `Y`, such that:
|
||||
|
||||
`s[i] X[i] R[i] + T[i] = Y[NN[i]]`,
|
||||
|
||||
for all batch indices `i` in the least squares sense. Here, Y[NN[i]] stands
|
||||
for the indices of nearest neighbors from `Y` to each point in `X`.
|
||||
Note, however, that the solution is only a local optimum.
|
||||
|
||||
Args:
|
||||
**X**: Batch of `d`-dimensional points
|
||||
of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
|
||||
**Y**: Batch of `d`-dimensional points
|
||||
of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
|
||||
**init_transform**: A named-tuple `SimilarityTransform` of tensors
|
||||
`R`, `T, `s`, where `R` is a batch of orthonormal matrices of
|
||||
shape `(minibatch, d, d)`, `T` is a batch of translations
|
||||
of shape `(minibatch, d)` and `s` is a batch of scaling factors
|
||||
of shape `(minibatch,)`.
|
||||
**max_iterations**: The maximum number of ICP iterations.
|
||||
**relative_rmse_thr**: A threshold on the relative root mean squared error
|
||||
used to terminate the algorithm.
|
||||
**estimate_scale**: If `True`, also estimates a scaling component `s`
|
||||
of the transformation. Otherwise assumes the identity
|
||||
scale and returns a tensor of ones.
|
||||
**allow_reflection**: If `True`, allows the algorithm to return `R`
|
||||
which is orthonormal but has determinant==-1.
|
||||
**verbose**: If `True`, prints status messages during each ICP iteration.
|
||||
|
||||
Returns:
|
||||
A named tuple `ICPSolution` with the following fields:
|
||||
**converged**: A boolean flag denoting whether the algorithm converged
|
||||
successfully (=`True`) or not (=`False`).
|
||||
**rmse**: Attained root mean squared error after termination of ICP.
|
||||
**Xt**: The point cloud `X` transformed with the final transformation
|
||||
(`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
|
||||
instance of `Pointclouds`, otherwise returns `torch.Tensor`.
|
||||
**RTs**: A named tuple `SimilarityTransform` containing
|
||||
a batch of similarity transforms with fields:
|
||||
**R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
|
||||
**T**: Batch of translations of shape `(minibatch, d)`.
|
||||
**s**: batch of scaling factors of shape `(minibatch, )`.
|
||||
**t_history**: A list of named tuples `SimilarityTransform`
|
||||
the transformation parameters after each ICP iteration.
|
||||
|
||||
References:
|
||||
[1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
|
||||
[2] https://en.wikipedia.org/wiki/Iterative_closest_point
|
||||
"""
|
||||
|
||||
# make sure we convert input Pointclouds structures to
|
||||
# padded tensors of shape (N, P, 3)
|
||||
Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
|
||||
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
|
||||
|
||||
b, size_X, dim = Xt.shape
|
||||
|
||||
if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
|
||||
raise ValueError(
|
||||
"Point sets X and Y have to have the same "
|
||||
+ "number of batches and data dimensions."
|
||||
)
|
||||
|
||||
if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and (
|
||||
num_points_Y != num_points_X
|
||||
).any():
|
||||
# we have a heterogeneous input (e.g. because X/Y is
|
||||
# an instance of Pointclouds)
|
||||
mask_X = (
|
||||
torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
|
||||
< num_points_X[:, None]
|
||||
).type_as(Xt)
|
||||
else:
|
||||
mask_X = Xt.new_ones(b, size_X)
|
||||
|
||||
# clone the initial point cloud
|
||||
Xt_init = Xt.clone()
|
||||
|
||||
if init_transform is not None:
|
||||
# parse the initial transform from the input and apply to Xt
|
||||
try:
|
||||
R, T, s = init_transform
|
||||
assert (
|
||||
R.shape == torch.Size((b, dim, dim))
|
||||
and T.shape == torch.Size((b, dim))
|
||||
and s.shape == torch.Size((b,))
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"The initial transformation init_transform has to be "
|
||||
"a named tuple SimilarityTransform with elements (R, T, s). "
|
||||
"R are dim x dim orthonormal matrices of shape "
|
||||
"(minibatch, dim, dim), T is a batch of dim-dimensional "
|
||||
"translations of shape (minibatch, dim) and s is a batch "
|
||||
"of scalars of shape (minibatch,)."
|
||||
)
|
||||
# apply the init transform to the input point cloud
|
||||
Xt = _apply_similarity_transform(Xt, R, T, s)
|
||||
else:
|
||||
# initialize the transformation with identity
|
||||
R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
|
||||
T = Xt.new_zeros((b, dim))
|
||||
s = Xt.new_ones(b)
|
||||
|
||||
prev_rmse = None
|
||||
rmse = None
|
||||
iteration = -1
|
||||
converged = False
|
||||
|
||||
# initialize the transformation history
|
||||
t_history = []
|
||||
|
||||
# the main loop over ICP iterations
|
||||
for iteration in range(max_iterations):
|
||||
Xt_nn_points = knn_points(
|
||||
Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
|
||||
)[2][:, :, 0, :]
|
||||
|
||||
# get the alignment of the nearest neighbors from Yt with Xt_init
|
||||
R, T, s = corresponding_points_alignment(
|
||||
Xt_init,
|
||||
Xt_nn_points,
|
||||
weights=mask_X,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=allow_reflection,
|
||||
)
|
||||
|
||||
# apply the estimated similarity transform to Xt_init
|
||||
Xt = _apply_similarity_transform(Xt_init, R, T, s)
|
||||
|
||||
# add the current transformation to the history
|
||||
t_history.append(SimilarityTransform(R, T, s))
|
||||
|
||||
# compute the root mean squared error
|
||||
Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
|
||||
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
|
||||
|
||||
# compute the relative rmse
|
||||
if prev_rmse is None:
|
||||
relative_rmse = rmse.new_ones(b)
|
||||
else:
|
||||
relative_rmse = (prev_rmse - rmse) / prev_rmse
|
||||
|
||||
if verbose:
|
||||
rmse_msg = (
|
||||
f"ICP iteration {iteration}: mean/max rmse = "
|
||||
+ f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
|
||||
+ f"; mean relative rmse = {relative_rmse.mean():1.2e}"
|
||||
)
|
||||
print(rmse_msg)
|
||||
|
||||
# check for convergence
|
||||
if (relative_rmse <= relative_rmse_thr).all():
|
||||
converged = True
|
||||
break
|
||||
|
||||
# update the previous rmse
|
||||
prev_rmse = rmse
|
||||
|
||||
if verbose:
|
||||
if converged:
|
||||
print(f"ICP has converged in {iteration + 1} iterations.")
|
||||
else:
|
||||
print(f"ICP has not converged in {max_iterations} iterations.")
|
||||
|
||||
if oputil.is_pointclouds(X):
|
||||
Xt = X.update_padded(Xt) # type: ignore
|
||||
|
||||
return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
|
||||
|
||||
|
||||
# threshold for checking that point crosscorelation
|
||||
# is full rank in corresponding_points_alignment
|
||||
AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
|
||||
|
||||
|
||||
def corresponding_points_alignment(
|
||||
X: Union[torch.Tensor, Pointclouds],
|
||||
Y: Union[torch.Tensor, Pointclouds],
|
||||
X: Union[torch.Tensor, "Pointclouds"],
|
||||
Y: Union[torch.Tensor, "Pointclouds"],
|
||||
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
|
||||
estimate_scale: bool = False,
|
||||
allow_reflection: bool = False,
|
||||
eps: float = 1e-8,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
eps: float = 1e-9,
|
||||
) -> SimilarityTransform:
|
||||
"""
|
||||
Finds a similarity transformation (rotation `R`, translation `T`
|
||||
and optionally scale `s`) between two given sets of corresponding
|
||||
@ -29,25 +238,25 @@ def corresponding_points_alignment(
|
||||
The algorithm is also known as Umeyama [1].
|
||||
|
||||
Args:
|
||||
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
**X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
**Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
weights: Batch of non-negative weights of
|
||||
**weights**: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
|
||||
tensors that may have different shapes; in that case, the length of
|
||||
i-th tensor should be equal to the number of points in X_i and Y_i.
|
||||
Passing `None` means uniform weights.
|
||||
estimate_scale: If `True`, also estimates a scaling component `s`
|
||||
**estimate_scale**: If `True`, also estimates a scaling component `s`
|
||||
of the transformation. Otherwise assumes an identity
|
||||
scale and returns a tensor of ones.
|
||||
allow_reflection: If `True`, allows the algorithm to return `R`
|
||||
**allow_reflection**: If `True`, allows the algorithm to return `R`
|
||||
which is orthonormal but has determinant==-1.
|
||||
eps: A scalar for clamping to avoid dividing by zero. Active for the
|
||||
**eps**: A scalar for clamping to avoid dividing by zero. Active for the
|
||||
code that estimates the output scale `s`.
|
||||
|
||||
Returns:
|
||||
3-element tuple containing
|
||||
3-element named tuple `SimilarityTransform` containing
|
||||
- **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
|
||||
- **T**: Batch of translations of shape `(minibatch, d)`.
|
||||
- **s**: batch of scaling factors of shape `(minibatch, )`.
|
||||
@ -58,8 +267,8 @@ def corresponding_points_alignment(
|
||||
"""
|
||||
|
||||
# make sure we convert input Pointclouds structures to tensors
|
||||
Xt, num_points = _convert_point_cloud_to_tensor(X)
|
||||
Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)
|
||||
Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
|
||||
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
|
||||
|
||||
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
|
||||
raise ValueError(
|
||||
@ -90,8 +299,8 @@ def corresponding_points_alignment(
|
||||
weights = mask if weights is None else mask * weights.type_as(Xt)
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = oputil.wmean(Xt, weights, eps=eps)
|
||||
Ymu = oputil.wmean(Yt, weights, eps=eps)
|
||||
Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
|
||||
Ymu = oputil.wmean(Yt, weight=weights, eps=eps)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu
|
||||
@ -107,7 +316,7 @@ def corresponding_points_alignment(
|
||||
if (num_points < (dim + 1)).any():
|
||||
warnings.warn(
|
||||
"The size of one of the point clouds is <= dim+1. "
|
||||
+ "corresponding_points_alignment can't return a unique solution."
|
||||
+ "corresponding_points_alignment cannot return a unique rotation."
|
||||
)
|
||||
|
||||
# compute the covariance XYcov between the point sets Xc, Yc
|
||||
@ -117,6 +326,16 @@ def corresponding_points_alignment(
|
||||
# decompose the covariance matrix XYcov
|
||||
U, S, V = torch.svd(XYcov)
|
||||
|
||||
# catch ambiguous rotation by checking the magnitude of singular values
|
||||
if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
|
||||
num_points < (dim + 1)
|
||||
).any():
|
||||
warnings.warn(
|
||||
"Excessively low rank of "
|
||||
+ "cross-correlation between aligned point clouds. "
|
||||
+ "corresponding_points_alignment cannot return a unique rotation."
|
||||
)
|
||||
|
||||
# identity matrix used for fixing reflections
|
||||
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
|
||||
|
||||
@ -148,26 +367,18 @@ def corresponding_points_alignment(
|
||||
# unit scaling since we do not estimate scale
|
||||
s = T.new_ones(b)
|
||||
|
||||
return R, T, s
|
||||
return SimilarityTransform(R, T, s)
|
||||
|
||||
|
||||
def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
|
||||
def _apply_similarity_transform(
|
||||
X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
|
||||
padded representation and returns it together with the number of points
|
||||
per batch. Otherwise, returns the input itself with the number of points
|
||||
set to the size of the second dimension of `pcl`.
|
||||
Applies a similarity transformation parametrized with a batch of orthonormal
|
||||
matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
|
||||
of shape `(minibatch, d)` and a batch of scaling factors `s`
|
||||
of shape `(minibatch,)` to a given `d`-dimensional cloud `X`
|
||||
of shape `(minibatch, num_points, d)`
|
||||
"""
|
||||
if isinstance(pcl, Pointclouds):
|
||||
X = pcl.points_padded()
|
||||
num_points = pcl.num_points_per_cloud()
|
||||
elif torch.is_tensor(pcl):
|
||||
X = pcl
|
||||
num_points = X.shape[1] * torch.ones(
|
||||
X.shape[0], device=X.device, dtype=torch.int64
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The inputs X, Y should be either Pointclouds objects or tensors."
|
||||
)
|
||||
return X, num_points
|
||||
X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
|
||||
return X
|
||||
|
@ -1,9 +1,13 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pytorch3d.structures import Pointclouds
|
||||
|
||||
|
||||
def wmean(
|
||||
x: torch.Tensor,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
@ -41,3 +45,55 @@ def wmean(
|
||||
return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
|
||||
eps
|
||||
)
|
||||
|
||||
|
||||
def eyes(
|
||||
dim: int,
|
||||
N: int,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generates a batch of `N` identity matrices of shape `(N, dim, dim)`.
|
||||
|
||||
Args:
|
||||
**dim**: The dimensionality of the identity matrices.
|
||||
**N**: The number of identity matrices.
|
||||
**device**: The device to be used for allocating the matrices.
|
||||
**dtype**: The datatype of the matrices.
|
||||
|
||||
Returns:
|
||||
**identities**: A batch of identity matrices of shape `(N, dim, dim)`.
|
||||
"""
|
||||
identities = torch.eye(dim, device=device, dtype=dtype)
|
||||
return identities[None].repeat(N, 1, 1)
|
||||
|
||||
|
||||
def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
|
||||
"""
|
||||
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
|
||||
padded representation and returns it together with the number of points
|
||||
per batch. Otherwise, returns the input itself with the number of points
|
||||
set to the size of the second dimension of `pcl`.
|
||||
"""
|
||||
if is_pointclouds(pcl):
|
||||
X = pcl.points_padded() # type: ignore
|
||||
num_points = pcl.num_points_per_cloud() # type: ignore
|
||||
elif torch.is_tensor(pcl):
|
||||
X = pcl
|
||||
num_points = X.shape[1] * torch.ones( # type: ignore
|
||||
X.shape[0], device=X.device, dtype=torch.int64
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The inputs X, Y should be either Pointclouds objects or tensors."
|
||||
)
|
||||
return X, num_points
|
||||
|
||||
|
||||
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
|
||||
""" Checks whether the input `pcl` is an instance `Pointclouds` of
|
||||
by checking the existence of `points_padded` and `num_points_per_cloud`
|
||||
functions.
|
||||
"""
|
||||
return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")
|
||||
|
@ -5,7 +5,38 @@ from copy import deepcopy
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_points_alignment import TestCorrespondingPointsAlignment
|
||||
from test_points_alignment import TestCorrespondingPointsAlignment, TestICP
|
||||
|
||||
|
||||
def bm_iterative_closest_point() -> None:
|
||||
|
||||
case_grid = {
|
||||
"batch_size": [1, 10],
|
||||
"dim": [3, 20],
|
||||
"n_points_X": [100, 1000],
|
||||
"n_points_Y": [100, 1000],
|
||||
"use_pointclouds": [False],
|
||||
}
|
||||
|
||||
test_args = sorted(case_grid.keys())
|
||||
test_cases = product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
|
||||
|
||||
# add the use_pointclouds=True test cases whenever we have dim==3
|
||||
kwargs_to_add = []
|
||||
for entry in kwargs_list:
|
||||
if entry["dim"] == 3:
|
||||
entry_add = deepcopy(entry)
|
||||
entry_add["use_pointclouds"] = True
|
||||
kwargs_to_add.append(entry_add)
|
||||
kwargs_list.extend(kwargs_to_add)
|
||||
|
||||
benchmark(
|
||||
TestICP.iterative_closest_point,
|
||||
"IterativeClosestPoint",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
def bm_corresponding_points_alignment() -> None:
|
||||
@ -21,7 +52,7 @@ def bm_corresponding_points_alignment() -> None:
|
||||
}
|
||||
|
||||
test_args = sorted(case_grid.keys())
|
||||
test_cases = product(*[case_grid[k] for k in test_args])
|
||||
test_cases = product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
|
||||
|
||||
# add the use_pointclouds=True test cases whenever we have dim==3
|
||||
|
BIN
tests/icp_data.pth
Normal file
BIN
tests/icp_data.pth
Normal file
Binary file not shown.
@ -1,8 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -36,6 +35,256 @@ def _apply_pcl_transformation(X, R, T, s=None):
|
||||
return X_t
|
||||
|
||||
|
||||
class TestICP(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth"
|
||||
self.trimesh_results = torch.load(trimesh_results_path)
|
||||
|
||||
@staticmethod
|
||||
def iterative_closest_point(
|
||||
batch_size=10,
|
||||
n_points_X=100,
|
||||
n_points_Y=100,
|
||||
dim=3,
|
||||
use_pointclouds=False,
|
||||
estimate_scale=False,
|
||||
):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# initialize a ground truth point cloud
|
||||
X, Y = [
|
||||
TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
device=device,
|
||||
use_pointclouds=use_pointclouds,
|
||||
random_pcl_size=True,
|
||||
fix_seed=i,
|
||||
)
|
||||
for i, n_points in enumerate((n_points_X, n_points_Y))
|
||||
]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_iterative_closest_point():
|
||||
points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
relative_rmse_thr=1e-4,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return run_iterative_closest_point
|
||||
|
||||
def test_init_transformation(self, batch_size=10):
|
||||
"""
|
||||
First runs a full ICP on a random problem. Then takes a given point
|
||||
in the history of ICP iteration transformations, initializes
|
||||
a second run of ICP with this transformation and checks whether
|
||||
both runs ended with the same solution.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
for dim in (2, 3, 11):
|
||||
for n_points_X in (30, 100):
|
||||
for n_points_Y in (30, 100):
|
||||
# initialize ground truth point clouds
|
||||
X, Y = [
|
||||
TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
device=device,
|
||||
use_pointclouds=False,
|
||||
random_pcl_size=True,
|
||||
)
|
||||
for n_points in (n_points_X, n_points_Y)
|
||||
]
|
||||
|
||||
# run full icp
|
||||
converged, _, Xt, (
|
||||
R,
|
||||
T,
|
||||
s,
|
||||
), t_hist = points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
|
||||
# start from the solution after the third
|
||||
# iteration of the previous ICP
|
||||
t_init = t_hist[min(2, len(t_hist) - 1)]
|
||||
|
||||
# rerun the ICP
|
||||
converged_init, _, Xt_init, (
|
||||
R_init,
|
||||
T_init,
|
||||
s_init,
|
||||
), t_hist_init = points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
init_transform=t_init,
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
|
||||
# compare transformations and obtained clouds
|
||||
# check that both sets of transforms are the same
|
||||
atol = 3e-5
|
||||
self.assertClose(R_init, R, atol=atol)
|
||||
self.assertClose(T_init, T, atol=atol)
|
||||
self.assertClose(s_init, s, atol=atol)
|
||||
self.assertClose(Xt_init, Xt, atol=atol)
|
||||
|
||||
def test_heterogenous_inputs(self, batch_size=10):
|
||||
"""
|
||||
Tests whether we get the same result when running ICP on
|
||||
a set of randomly-sized Pointclouds and on their padded versions.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
for estimate_scale in (True, False):
|
||||
for max_n_points in (10, 30, 100):
|
||||
# initialize ground truth point clouds
|
||||
X_pcl, Y_pcl = [
|
||||
TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=max_n_points,
|
||||
dim=3,
|
||||
device=device,
|
||||
use_pointclouds=True,
|
||||
random_pcl_size=True,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
# get the padded versions and their num of points
|
||||
X_padded = X_pcl.points_padded()
|
||||
Y_padded = Y_pcl.points_padded()
|
||||
n_points_X = X_pcl.num_points_per_cloud()
|
||||
n_points_Y = Y_pcl.num_points_per_cloud()
|
||||
|
||||
# run icp with Pointlouds inputs
|
||||
_, _, Xt_pcl, (
|
||||
R_pcl,
|
||||
T_pcl,
|
||||
s_pcl,
|
||||
), _ = points_alignment.iterative_closest_point(
|
||||
X_pcl,
|
||||
Y_pcl,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
Xt_pcl = Xt_pcl.points_padded()
|
||||
|
||||
# run icp with tensor inputs on each element
|
||||
# of the batch separately
|
||||
icp_results = [
|
||||
points_alignment.iterative_closest_point(
|
||||
X_[None, :n_X, :],
|
||||
Y_[None, :n_Y, :],
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
for X_, Y_, n_X, n_Y in zip(
|
||||
X_padded, Y_padded, n_points_X, n_points_Y
|
||||
)
|
||||
]
|
||||
|
||||
# parse out the transformation results
|
||||
R, T, s = [
|
||||
torch.cat([x.RTs[i] for x in icp_results], dim=0) for i in range(3)
|
||||
]
|
||||
|
||||
# check that both sets of transforms are the same
|
||||
atol = 1e-5
|
||||
self.assertClose(R_pcl, R, atol=atol)
|
||||
self.assertClose(T_pcl, T, atol=atol)
|
||||
self.assertClose(s_pcl, s, atol=atol)
|
||||
|
||||
# compare the transformed point clouds
|
||||
for pcli in range(batch_size):
|
||||
nX = n_points_X[pcli]
|
||||
Xt_ = icp_results[pcli].Xt[0, :nX]
|
||||
Xt_pcl_ = Xt_pcl[pcli][:nX]
|
||||
self.assertClose(Xt_pcl_, Xt_, atol=atol)
|
||||
|
||||
def test_compare_with_trimesh(self):
|
||||
"""
|
||||
Compares the outputs of `iterative_closest_point` with the results
|
||||
of `trimesh.registration.icp` from the `trimesh` python package:
|
||||
https://github.com/mikedh/trimesh
|
||||
|
||||
We have run `trimesh.registration.icp` on several random problems
|
||||
with different point cloud sizes. The results of trimesh, together with
|
||||
the randomly generated input clouds are loaded in the constructor of
|
||||
this class and this test compares the loaded results to our runs.
|
||||
"""
|
||||
for n_points_X in (10, 20, 50, 100):
|
||||
for n_points_Y in (10, 20, 50, 100):
|
||||
self._compare_with_trimesh(n_points_X=n_points_X, n_points_Y=n_points_Y)
|
||||
|
||||
def _compare_with_trimesh(
|
||||
self, n_points_X=100, n_points_Y=100, estimate_scale=False
|
||||
):
|
||||
"""
|
||||
Executes a single test for `iterative_closest_point` for a
|
||||
specific setting of the inputs / outputs. Compares the result with
|
||||
the result of the trimesh package on the same input data.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# load the trimesh results and the initial point clouds for icp
|
||||
key = (int(n_points_X), int(n_points_Y), int(estimate_scale))
|
||||
X, Y, R_trimesh, T_trimesh, s_trimesh = [
|
||||
x.to(device) for x in self.trimesh_results[key]
|
||||
]
|
||||
|
||||
# run the icp algorithm
|
||||
converged, _, _, (
|
||||
R_ours,
|
||||
T_ours,
|
||||
s_ours,
|
||||
), _ = points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
|
||||
# check that we have the same transformation
|
||||
# and that the icp converged
|
||||
atol = 1e-5
|
||||
self.assertClose(R_ours, R_trimesh, atol=atol)
|
||||
self.assertClose(T_ours, T_trimesh, atol=atol)
|
||||
self.assertClose(s_ours, s_trimesh, atol=atol)
|
||||
self.assertTrue(converged)
|
||||
|
||||
|
||||
class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
@ -72,10 +321,17 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
device=None,
|
||||
use_pointclouds=False,
|
||||
random_pcl_size=True,
|
||||
fix_seed=None,
|
||||
):
|
||||
"""
|
||||
Generate a batch of normally distributed point clouds.
|
||||
"""
|
||||
|
||||
if fix_seed is not None:
|
||||
# make sure we always generate the same pointcloud
|
||||
seed = torch.random.get_rng_state()
|
||||
torch.manual_seed(fix_seed)
|
||||
|
||||
if use_pointclouds:
|
||||
assert dim == 3, "Pointclouds support only 3-dim points."
|
||||
# generate a `batch_size` point clouds with number of points
|
||||
@ -102,6 +358,10 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
X = torch.randn(
|
||||
batch_size, n_points, dim, device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if fix_seed:
|
||||
torch.random.set_rng_state(seed)
|
||||
|
||||
return X
|
||||
|
||||
@staticmethod
|
||||
@ -230,7 +490,6 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
- use_pointclouds ... If True, passes the Pointclouds objects
|
||||
to corresponding_points_alignment.
|
||||
"""
|
||||
|
||||
# run this for several different point cloud sizes
|
||||
for n_points in (100, 3, 2, 1):
|
||||
# run this for several different dimensionalities
|
||||
|
Loading…
x
Reference in New Issue
Block a user