mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00

Differential Revision: D37172764 fbshipit-source-id: a2ec367e56de2781a17f5e708eb5832ec9d7e6b4
390 lines
15 KiB
Python
390 lines
15 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import warnings
|
|
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union
|
|
|
|
import torch
|
|
from pytorch3d.ops import knn_points
|
|
from pytorch3d.structures import utils as strutil
|
|
|
|
from . import utils as oputil
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from pytorch3d.structures.pointclouds import Pointclouds
|
|
|
|
|
|
# named tuples for inputs/outputs
class SimilarityTransform(NamedTuple):
    """
    A batch of similarity transforms mapping points `X` (row vectors) to
    `s[:, None, None] * X @ R + T[:, None, :]`.
    """

    # batch of orthonormal matrices, shape (minibatch, d, d)
    R: torch.Tensor
    # batch of translations, shape (minibatch, d)
    T: torch.Tensor
    # batch of scaling factors, shape (minibatch,)
    s: torch.Tensor
|
|
|
|
|
|
class ICPSolution(NamedTuple):
    """
    Result bundle returned by `iterative_closest_point`.
    """

    # whether the relative-rmse stopping criterion was met
    converged: bool
    # final root mean squared error per batch element (None if no iteration ran)
    rmse: Optional[torch.Tensor]
    # X after applying the final transformation
    Xt: torch.Tensor
    # the final similarity transformation (R, T, s)
    RTs: SimilarityTransform
    # the transformation estimated at every ICP iteration, in order
    t_history: List[SimilarityTransform]
|
|
|
|
|
|
def iterative_closest_point(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    init_transform: Optional[SimilarityTransform] = None,
    max_iterations: int = 100,
    relative_rmse_thr: float = 1e-6,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    verbose: bool = False,
) -> ICPSolution:
    """
    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
    a similarity transformation (rotation `R`, translation `T`, and
    optionally scale `s`) between two given differently-sized sets of
    `d`-dimensional points `X` and `Y`, such that:

    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,

    for all batch indices `i` in the least squares sense. Here, Y[NN[i]] stands
    for the nearest neighbors from `Y` of each point in `X`.
    Note, however, that the solution is only a local optimum.

    Args:
        **X**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
            If the clouds in `X` have different numbers of points, the padded
            entries are excluded from the alignment and from the rmse.
        **Y**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
        **init_transform**: A named-tuple `SimilarityTransform` of tensors
            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
            shape `(minibatch, d, d)`, `T` is a batch of translations
            of shape `(minibatch, d)` and `s` is a batch of scaling factors
            of shape `(minibatch,)`.
        **max_iterations**: The maximum number of ICP iterations.
        **relative_rmse_thr**: A threshold on the relative root mean squared error
            used to terminate the algorithm. ICP stops once the relative rmse
            decrease of every batch element is `<= relative_rmse_thr`; a batch
            element whose previous rmse is exactly zero counts as converged.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes the identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **verbose**: If `True`, prints status messages during each ICP iteration.

    Returns:
        A named tuple `ICPSolution` with the following fields:
        **converged**: A boolean flag denoting whether the algorithm converged
            successfully (=`True`) or not (=`False`).
        **rmse**: Attained root mean squared error after termination of ICP,
            or `None` if no iteration was run (`max_iterations <= 0`).
        **Xt**: The point cloud `X` transformed with the final transformation
            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
        **RTs**: A named tuple `SimilarityTransform` containing
        a batch of similarity transforms with fields:
            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
            **T**: Batch of translations of shape `(minibatch, d)`.
            **s**: batch of scaling factors of shape `(minibatch, )`.
        **t_history**: A list of named tuples `SimilarityTransform` containing
            the transformation parameters after each ICP iteration.

    Raises:
        ValueError: if `X` and `Y` differ in batch size or point dimension, or
            if `init_transform` is not a 3-tuple `(R, T, s)` of tensors with
            the shapes listed above.

    References:
        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
    """

    # make sure we convert input Pointclouds structures to
    # padded tensors of shape (N, P, 3)
    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    b, size_X, dim = Xt.shape

    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
        raise ValueError(
            "Point sets X and Y have to have the same "
            + "number of batches and data dimensions."
        )

    if (num_points_X < size_X).any():
        # X is heterogeneous (e.g. a Pointclouds object with clouds of
        # different sizes): zero-weight its padded entries so that whatever
        # knn_points returns for them never enters the alignment or the rmse.
        # This depends only on X; it must not be skipped just because the
        # per-cloud sizes of Y happen to match those of X.
        mask_X = (
            torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
            < num_points_X[:, None]
        ).type_as(Xt)
    else:
        mask_X = Xt.new_ones(b, size_X)

    # clone the initial point cloud
    Xt_init = Xt.clone()

    if init_transform is not None:
        # parse the initial transform from the input and validate its shapes
        # explicitly (an `assert` would be stripped under `python -O`)
        try:
            R, T, s = init_transform
            init_is_valid = (
                R.shape == torch.Size((b, dim, dim))
                and T.shape == torch.Size((b, dim))
                and s.shape == torch.Size((b,))
            )
        except (TypeError, ValueError, AttributeError):
            # not unpackable into exactly three items, or items without `.shape`
            init_is_valid = False
        if not init_is_valid:
            raise ValueError(
                "The initial transformation init_transform has to be "
                "a named tuple SimilarityTransform with elements (R, T, s). "
                "R are dim x dim orthonormal matrices of shape "
                "(minibatch, dim, dim), T is a batch of dim-dimensional "
                "translations of shape (minibatch, dim) and s is a batch "
                "of scalars of shape (minibatch,)."
            )
        # apply the init transform to the input point cloud
        Xt = _apply_similarity_transform(Xt, R, T, s)
    else:
        # initialize the transformation with identity
        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
        T = Xt.new_zeros((b, dim))
        s = Xt.new_ones(b)

    prev_rmse = None
    rmse = None
    iteration = -1
    converged = False

    # initialize the transformation history
    t_history = []

    # the main loop over ICP iterations
    for iteration in range(max_iterations):
        Xt_nn_points = knn_points(
            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
        ).knn[:, :, 0, :]

        # get the alignment of the nearest neighbors from Yt with Xt_init;
        # the transform is always estimated from the untransformed X
        R, T, s = corresponding_points_alignment(
            Xt_init,
            Xt_nn_points,
            weights=mask_X,
            estimate_scale=estimate_scale,
            allow_reflection=allow_reflection,
        )

        # apply the estimated similarity transform to Xt_init
        Xt = _apply_similarity_transform(Xt_init, R, T, s)

        # add the current transformation to the history
        t_history.append(SimilarityTransform(R, T, s))

        # compute the root mean squared error
        # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

        # compute the relative rmse
        if prev_rmse is None:
            relative_rmse = rmse.new_ones(b)
        else:
            # An element whose previous rmse is exactly zero cannot improve;
            # report 0 for it instead of the 0/0 = NaN that would otherwise
            # fail the threshold test forever and prevent convergence.
            relative_rmse = torch.where(
                prev_rmse > 0,
                (prev_rmse - rmse) / prev_rmse,
                torch.zeros_like(rmse),
            )

        if verbose:
            rmse_msg = (
                f"ICP iteration {iteration}: mean/max rmse = "
                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
            )
            print(rmse_msg)

        # check for convergence
        if (relative_rmse <= relative_rmse_thr).all():
            converged = True
            break

        # update the previous rmse
        prev_rmse = rmse

    if verbose:
        if converged:
            print(f"ICP has converged in {iteration + 1} iterations.")
        else:
            print(f"ICP has not converged in {max_iterations} iterations.")

    if oputil.is_pointclouds(X):
        Xt = X.update_padded(Xt)  # type: ignore

    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
|
|
|
|
|
|
# Singular values of the point cross-correlation matrix whose magnitude is at
# or below this threshold make corresponding_points_alignment warn that the
# matrix is rank deficient, i.e. the estimated rotation is not unique.
AMBIGUOUS_ROT_SINGULAR_THR = 1.0e-15
|
|
|
|
|
|
def corresponding_points_alignment(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-9,
) -> SimilarityTransform:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`) between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the (weighted) least squares sense, i.e. the
    squared residual of each point is multiplied by that point's weight.

    The algorithm is also known as Umeyama [1].

    Args:
        **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **weights**: Batch of non-negative weights of
            shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
            tensors that may have different shapes; in that case, the length of
            i-th tensor should be equal to the number of points in X_i and Y_i.
            Passing `None` means uniform weights.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **eps**: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Returns:
        3-element named tuple `SimilarityTransform` containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: batch of scaling factors of shape `(minibatch, )`.

    Raises:
        ValueError: if `X` and `Y` differ in shape or per-cloud point counts,
            or if `weights` does not match the points.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        raise ValueError(
            "Point sets X and Y have to have the same "
            "number of batches, points and dimensions."
        )
    if weights is not None:
        if isinstance(weights, list):
            if any(n_pts != w.shape[0] for n_pts, w in zip(num_points, weights)):
                raise ValueError(
                    "number of weights should equal to the "
                    + "number of points in the point cloud."
                )
            weights = [w[..., None] for w in weights]
            weights = strutil.list_to_padded(weights)[..., 0]

        if Xt.shape[:2] != weights.shape:
            raise ValueError("weights should have the same first two dimensions as X.")

    b, n, dim = Xt.shape

    # Xt and Yt have identical shapes (checked above), so testing Xt suffices
    if (num_points < n).any():
        # in case we got Pointclouds as input, give the padded entries zero weight
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
            < num_points[:, None]
        ).type_as(Xt)
        weights = mask if weights is None else mask * weights.type_as(Xt)

    # compute the centroids of the point sets
    Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
    Ymu = oputil.wmean(Yt, weight=weights, eps=eps)

    # mean-center the point sets
    Xc = Xt - Xmu
    Yc = Yt - Ymu

    # Weight only ONE side of every second-moment product so that each point
    # contributes w_i (not w_i ** 2) to the cross-covariance and to the
    # variance of X, consistent with the w_i-weighted centroids above.
    # For 0/1 weights (e.g. padding masks) both forms coincide.
    if weights is None:
        Xc_weighted = Xc
        total_weight = torch.clamp(num_points, 1)
    else:
        Xc_weighted = Xc * weights[:, :, None]
        total_weight = torch.clamp(weights.sum(1), eps)

    too_few_points = (num_points < (dim + 1)).any()
    if too_few_points:
        warnings.warn(
            "The size of one of the point clouds is <= dim. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc_weighted.transpose(2, 1), Yc)
    XYcov = XYcov / total_weight[:, None, None]

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # catch ambiguous rotation by checking the magnitude of singular values
    # (skipped when the too-few-points warning above has already fired)
    if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not too_few_points:
        warnings.warn(
            "Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc_weighted * Xc).sum((1, 2)) / total_weight

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
    else:
        # translation component
        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return SimilarityTransform(R, T, s)
|
|
|
|
|
|
def _apply_similarity_transform(
    X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
    """
    Map a batch of point clouds through per-batch similarity transforms.

    Computes `s[i] * X[i] @ R[i] + T[i]` for every batch index `i`; the points
    are row vectors that are right-multiplied by `R`.

    Args:
        X: points of shape `(minibatch, num_points, d)`.
        R: orthonormal matrices of shape `(minibatch, d, d)`.
        T: translations of shape `(minibatch, d)`.
        s: scaling factors of shape `(minibatch,)`.

    Returns:
        The transformed points, shape `(minibatch, num_points, d)`.
    """
    rotated = X.bmm(R)
    per_batch_scale = s.unsqueeze(1).unsqueeze(2)  # (minibatch, 1, 1)
    per_batch_shift = T.unsqueeze(1)  # (minibatch, 1, d)
    return per_batch_scale * rotated + per_batch_shift
|