diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py
index 48703e2a..fe522d3d 100644
--- a/pytorch3d/ops/__init__.py
+++ b/pytorch3d/ops/__init__.py
@@ -6,9 +6,10 @@ from .graph_conv import GraphConv
 from .knn import knn_gather, knn_points
 from .mesh_face_areas_normals import mesh_face_areas_normals
 from .packed_to_padded import packed_to_padded, padded_to_packed
-from .points_alignment import corresponding_points_alignment
+from .points_alignment import corresponding_points_alignment, iterative_closest_point
 from .sample_points_from_meshes import sample_points_from_meshes
 from .subdivide_meshes import SubdivideMeshes
+from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
 from .vert_align import vert_align
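
A minimal sketch of the API surface these new exports provide (illustrative only, not part of the diff; the shapes are arbitrary):

import torch
from pytorch3d.ops import corresponding_points_alignment, iterative_closest_point

X = torch.randn(8, 100, 3)  # batch of 8 source clouds, 100 points each
Y = torch.randn(8, 120, 3)  # batch of 8 target clouds, 120 points each

solution = iterative_closest_point(X, Y)  # returns an ICPSolution named tuple
R, T, s = solution.RTs                    # the final SimilarityTransform
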
diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py
index 80100f5b..7ac3f182 100644
--- a/pytorch3d/ops/points_alignment.py
+++ b/pytorch3d/ops/points_alignment.py
@@ -1,22 +1,231 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 import warnings
-from typing import List, Tuple, Union
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Union
 
 import torch
 
-from pytorch3d.ops import utils as oputil
+from pytorch3d.ops import knn_points
 from pytorch3d.structures import utils as strutil
-from pytorch3d.structures.pointclouds import Pointclouds
+
+from . import utils as oputil
+
+
+if TYPE_CHECKING:
+    from pytorch3d.structures.pointclouds import Pointclouds
+
+
+# named tuples for inputs/outputs
+class SimilarityTransform(NamedTuple):
+    R: torch.Tensor
+    T: torch.Tensor
+    s: torch.Tensor
+
+
+class ICPSolution(NamedTuple):
+    converged: bool
+    rmse: Union[torch.Tensor, None]
+    Xt: torch.Tensor
+    RTs: SimilarityTransform
+    t_history: List[SimilarityTransform]
+
+
+def iterative_closest_point(
+    X: Union[torch.Tensor, "Pointclouds"],
+    Y: Union[torch.Tensor, "Pointclouds"],
+    init_transform: Optional[SimilarityTransform] = None,
+    max_iterations: int = 100,
+    relative_rmse_thr: float = 1e-6,
+    estimate_scale: bool = False,
+    allow_reflection: bool = False,
+    verbose: bool = False,
+) -> ICPSolution:
+    """
+    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
+    a similarity transformation (rotation `R`, translation `T`, and
+    optionally scale `s`) between two given differently-sized sets of
+    `d`-dimensional points `X` and `Y`, such that:
+
+    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,
+
+    for all batch indices `i` in the least squares sense. Here, `NN[i]` stands
+    for the indices of nearest neighbors from `Y` to each point in `X`.
+    Note, however, that the solution is only a local optimum.
+
+    Args:
+        **X**: Batch of `d`-dimensional points
+            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
+        **Y**: Batch of `d`-dimensional points
+            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
+        **init_transform**: A named-tuple `SimilarityTransform` of tensors
+            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
+            shape `(minibatch, d, d)`, `T` is a batch of translations
+            of shape `(minibatch, d)` and `s` is a batch of scaling factors
+            of shape `(minibatch,)`.
+        **max_iterations**: The maximum number of ICP iterations.
+        **relative_rmse_thr**: A threshold on the relative root mean squared error
+            used to terminate the algorithm.
+        **estimate_scale**: If `True`, also estimates a scaling component `s`
+            of the transformation. Otherwise assumes the identity
+            scale and returns a tensor of ones.
+        **allow_reflection**: If `True`, allows the algorithm to return `R`
+            which is orthonormal but has determinant==-1.
+        **verbose**: If `True`, prints status messages during each ICP iteration.
+
+    Returns:
+        A named tuple `ICPSolution` with the following fields:
+        **converged**: A boolean flag denoting whether the algorithm converged
+            successfully (=`True`) or not (=`False`).
+        **rmse**: Attained root mean squared error after termination of ICP.
+        **Xt**: The point cloud `X` transformed with the final transformation
+            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
+            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
+        **RTs**: A named tuple `SimilarityTransform` containing
+            a batch of similarity transforms with fields:
+            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
+            **T**: Batch of translations of shape `(minibatch, d)`.
+            **s**: batch of scaling factors of shape `(minibatch, )`.
+        **t_history**: A list of named tuples `SimilarityTransform` containing
+            the transformation parameters after each ICP iteration.
+
+    References:
+        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
+        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
+    """
+
+    # make sure we convert input Pointclouds structures to
+    # padded tensors of shape (N, P, 3)
+    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
+    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
+
+    b, size_X, dim = Xt.shape
+
+    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
+        raise ValueError(
+            "Point sets X and Y have to have the same "
+            + "number of batches and data dimensions."
+        )
+
+    if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and (
+        num_points_Y != num_points_X
+    ).any():
+        # we have a heterogeneous input (e.g. because X/Y is
+        # an instance of Pointclouds)
+        mask_X = (
+            torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
+            < num_points_X[:, None]
+        ).type_as(Xt)
+    else:
+        mask_X = Xt.new_ones(b, size_X)
+
+    # clone the initial point cloud
+    Xt_init = Xt.clone()
+
+    if init_transform is not None:
+        # parse the initial transform from the input and apply to Xt
+        try:
+            R, T, s = init_transform
+            assert (
+                R.shape == torch.Size((b, dim, dim))
+                and T.shape == torch.Size((b, dim))
+                and s.shape == torch.Size((b,))
+            )
+        except Exception:
+            raise ValueError(
+                "The initial transformation init_transform has to be "
+                "a named tuple SimilarityTransform with elements (R, T, s). "
+                "R are dim x dim orthonormal matrices of shape "
+                "(minibatch, dim, dim), T is a batch of dim-dimensional "
+                "translations of shape (minibatch, dim) and s is a batch "
+                "of scalars of shape (minibatch,)."
+            )
+        # apply the init transform to the input point cloud
+        Xt = _apply_similarity_transform(Xt, R, T, s)
+    else:
+        # initialize the transformation with identity
+        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
+        T = Xt.new_zeros((b, dim))
+        s = Xt.new_ones(b)
+
+    prev_rmse = None
+    rmse = None
+    iteration = -1
+    converged = False
+
+    # initialize the transformation history
+    t_history = []
+
+    # the main loop over ICP iterations
+    for iteration in range(max_iterations):
+        Xt_nn_points = knn_points(
+            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
+        )[2][:, :, 0, :]
+
+        # get the alignment of the nearest neighbors from Yt with Xt_init
+        R, T, s = corresponding_points_alignment(
+            Xt_init,
+            Xt_nn_points,
+            weights=mask_X,
+            estimate_scale=estimate_scale,
+            allow_reflection=allow_reflection,
+        )
+
+        # apply the estimated similarity transform to Xt_init
+        Xt = _apply_similarity_transform(Xt_init, R, T, s)
+
+        # add the current transformation to the history
+        t_history.append(SimilarityTransform(R, T, s))
+
+        # compute the root mean squared error
+        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
+        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
+
+        # compute the relative rmse
+        if prev_rmse is None:
+            relative_rmse = rmse.new_ones(b)
+        else:
+            relative_rmse = (prev_rmse - rmse) / prev_rmse
+
+        if verbose:
+            rmse_msg = (
+                f"ICP iteration {iteration}: mean/max rmse = "
+                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
+                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
+            )
+            print(rmse_msg)
+
+        # check for convergence
+        if (relative_rmse <= relative_rmse_thr).all():
+            converged = True
+            break
+
+        # update the previous rmse
+        prev_rmse = rmse
+
+    if verbose:
+        if converged:
+            print(f"ICP has converged in {iteration + 1} iterations.")
+        else:
+            print(f"ICP has not converged in {max_iterations} iterations.")
+
+    if oputil.is_pointclouds(X):
+        Xt = X.update_padded(Xt)  # type: ignore
+
+    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
+
+
+# threshold for checking that the point cross-correlation
+# is full rank in corresponding_points_alignment
+AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
 
 
 def corresponding_points_alignment(
-    X: Union[torch.Tensor, Pointclouds],
-    Y: Union[torch.Tensor, Pointclouds],
+    X: Union[torch.Tensor, "Pointclouds"],
+    Y: Union[torch.Tensor, "Pointclouds"],
     weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
     estimate_scale: bool = False,
     allow_reflection: bool = False,
-    eps: float = 1e-8,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    eps: float = 1e-9,
+) -> SimilarityTransform:
     """
     Finds a similarity transformation (rotation `R`, translation `T`
     and optionally scale `s`) between two given sets of corresponding
@@ -29,25 +238,25 @@ def corresponding_points_alignment(
     The algorithm is also known as Umeyama [1].
 
     Args:
-        X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
+        **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
             or a `Pointclouds` object.
-        Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
+        **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
             or a `Pointclouds` object.
-        weights: Batch of non-negative weights of
+        **weights**: Batch of non-negative weights of
             shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
             tensors that may have different shapes; in that case, the length of
             i-th tensor should be equal to the number of points in X_i and Y_i.
             Passing `None` means uniform weights.
-        estimate_scale: If `True`, also estimates a scaling component `s`
+        **estimate_scale**: If `True`, also estimates a scaling component `s`
             of the transformation. Otherwise assumes an identity
             scale and returns a tensor of ones.
-        allow_reflection: If `True`, allows the algorithm to return `R`
+        **allow_reflection**: If `True`, allows the algorithm to return `R`
             which is orthonormal but has determinant==-1.
-        eps: A scalar for clamping to avoid dividing by zero. Active for the
+        **eps**: A scalar for clamping to avoid dividing by zero. Active for the
             code that estimates the output scale `s`.
 
     Returns:
-        3-element tuple containing
+        3-element named tuple `SimilarityTransform` containing
         - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
         - **T**: Batch of translations of shape `(minibatch, d)`.
         - **s**: batch of scaling factors of shape `(minibatch, )`.
@@ -58,8 +267,8 @@ def corresponding_points_alignment(
     """
 
     # make sure we convert input Pointclouds structures to tensors
-    Xt, num_points = _convert_point_cloud_to_tensor(X)
-    Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)
+    Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
+    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
 
     if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
         raise ValueError(
@@ -90,8 +299,8 @@ def corresponding_points_alignment(
     weights = mask if weights is None else mask * weights.type_as(Xt)
 
     # compute the centroids of the point sets
-    Xmu = oputil.wmean(Xt, weights, eps=eps)
-    Ymu = oputil.wmean(Yt, weights, eps=eps)
+    Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
+    Ymu = oputil.wmean(Yt, weight=weights, eps=eps)
 
     # mean-center the point sets
     Xc = Xt - Xmu
@@ -107,7 +316,7 @@ def corresponding_points_alignment(
     if (num_points < (dim + 1)).any():
         warnings.warn(
             "The size of one of the point clouds is <= dim+1. "
-            + "corresponding_points_alignment can't return a unique solution."
+            + "corresponding_points_alignment cannot return a unique rotation."
        )
 
     # compute the covariance XYcov between the point sets Xc, Yc
@@ -117,6 +326,16 @@ def corresponding_points_alignment(
     # decompose the covariance matrix XYcov
     U, S, V = torch.svd(XYcov)
 
+    # catch ambiguous rotation by checking the magnitude of singular values
+    if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
+        num_points < (dim + 1)
+    ).any():
+        warnings.warn(
+            "Excessively low rank of "
+            + "cross-correlation between aligned point clouds. "
+            + "corresponding_points_alignment cannot return a unique rotation."
+        )
+
     # identity matrix used for fixing reflections
     E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
 
@@ -148,26 +367,18 @@ def corresponding_points_alignment(
         # unit scaling since we do not estimate scale
         s = T.new_ones(b)
 
-    return R, T, s
+    return SimilarityTransform(R, T, s)
 
 
-def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
+def _apply_similarity_transform(
+    X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
+) -> torch.Tensor:
     """
-    If `type(pcl)==Pointclouds`, converts a `pcl` object to a
-    padded representation and returns it together with the number of points
-    per batch. Otherwise, returns the input itself with the number of points
-    set to the size of the second dimension of `pcl`.
+    Applies a similarity transformation parametrized with a batch of orthonormal
+    matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
+    of shape `(minibatch, d)` and a batch of scaling factors `s`
+    of shape `(minibatch,)` to a given `d`-dimensional cloud `X`
+    of shape `(minibatch, num_points, d)`.
     """
-    if isinstance(pcl, Pointclouds):
-        X = pcl.points_padded()
-        num_points = pcl.num_points_per_cloud()
-    elif torch.is_tensor(pcl):
-        X = pcl
-        num_points = X.shape[1] * torch.ones(
-            X.shape[0], device=X.device, dtype=torch.int64
-        )
-    else:
-        raise ValueError(
-            "The inputs X, Y should be either Pointclouds objects or tensors."
-        )
-    return X, num_points
+    X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
+    return X
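
A hedged usage sketch of the function added above (illustrative, not part of the diff; it assumes a PyTorch version that provides `torch.linalg.qr`). Since ICP only finds a local optimum, the sketch makes no assertion that the ground-truth motion is recovered:

import torch
from pytorch3d.ops import iterative_closest_point

torch.manual_seed(0)
X = torch.randn(4, 100, 3)

# build a batch of random rotations; flip the sign so that det(R) == +1
R_gt, _ = torch.linalg.qr(torch.randn(4, 3, 3))
R_gt = R_gt * torch.det(R_gt).sign()[:, None, None]
T_gt = torch.randn(4, 3)

# row-vector convention used throughout this file: Y = X @ R + T
Y = torch.bmm(X, R_gt) + T_gt[:, None, :]

solution = iterative_closest_point(X, Y, relative_rmse_thr=1e-8, verbose=True)

# warm-start a second run from an intermediate transform, mirroring
# test_init_transformation in the tests further below
t_init = solution.t_history[min(2, len(solution.t_history) - 1)]
solution2 = iterative_closest_point(X, Y, init_transform=t_init)
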
diff --git a/pytorch3d/ops/utils.py b/pytorch3d/ops/utils.py
index fa690ee1..134172b0 100644
--- a/pytorch3d/ops/utils.py
+++ b/pytorch3d/ops/utils.py
@@ -1,9 +1,13 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
 
+if TYPE_CHECKING:
+    from pytorch3d.structures import Pointclouds
+
+
 def wmean(
     x: torch.Tensor,
     weight: Optional[torch.Tensor] = None,
@@ -41,3 +45,55 @@ def wmean(
     return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
         eps
     )
+
+
+def eyes(
+    dim: int,
+    N: int,
+    device: Optional[torch.device] = None,
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Generates a batch of `N` identity matrices of shape `(N, dim, dim)`.
+
+    Args:
+        **dim**: The dimensionality of the identity matrices.
+        **N**: The number of identity matrices.
+        **device**: The device to be used for allocating the matrices.
+        **dtype**: The datatype of the matrices.
+
+    Returns:
+        **identities**: A batch of identity matrices of shape `(N, dim, dim)`.
+    """
+    identities = torch.eye(dim, device=device, dtype=dtype)
+    return identities[None].repeat(N, 1, 1)
+
+
+def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
+    """
+    If `type(pcl)==Pointclouds`, converts a `pcl` object to a
+    padded representation and returns it together with the number of points
+    per batch. Otherwise, returns the input itself with the number of points
+    set to the size of the second dimension of `pcl`.
+    """
+    if is_pointclouds(pcl):
+        X = pcl.points_padded()  # type: ignore
+        num_points = pcl.num_points_per_cloud()  # type: ignore
+    elif torch.is_tensor(pcl):
+        X = pcl
+        num_points = X.shape[1] * torch.ones(  # type: ignore
+            X.shape[0], device=X.device, dtype=torch.int64
+        )
+    else:
+        raise ValueError(
+            "The inputs X, Y should be either Pointclouds objects or tensors."
+        )
+    return X, num_points
+
+
+def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
+    """Checks whether the input `pcl` is an instance of `Pointclouds`
+    by checking the existence of the `points_padded` and `num_points_per_cloud`
+    functions.
+    """
+    return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")
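
The contracts of the helpers above fit in a few lines; an illustrative sketch of their semantics (not part of the diff; the printed `wmean` value assumes its default reduction over the points dimension with `keepdim=True`):

import torch
from pytorch3d.ops.utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean

x = torch.tensor([[[1.0], [3.0]]])  # shape (N=1, P=2, d=1)
w = torch.tensor([[1.0, 3.0]])      # per-point weights
print(wmean(x, weight=w))           # weighted mean: (1*1 + 3*3) / 4 = 2.5
print(eyes(3, 2).shape)             # torch.Size([2, 3, 3])

pts = torch.randn(2, 10, 3)
Xp, num_points = convert_pointclouds_to_tensor(pts)
print(is_pointclouds(pts), num_points)  # False tensor([10, 10])

Note that `is_pointclouds` is deliberately duck-typed: any object exposing `points_padded` and `num_points_per_cloud` passes, which avoids importing `Pointclouds` at runtime.
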
+ """ + return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud") diff --git a/tests/bm_points_alignment.py b/tests/bm_points_alignment.py index 24f8d0d2..942e76aa 100644 --- a/tests/bm_points_alignment.py +++ b/tests/bm_points_alignment.py @@ -5,7 +5,38 @@ from copy import deepcopy from itertools import product from fvcore.common.benchmark import benchmark -from test_points_alignment import TestCorrespondingPointsAlignment +from test_points_alignment import TestCorrespondingPointsAlignment, TestICP + + +def bm_iterative_closest_point() -> None: + + case_grid = { + "batch_size": [1, 10], + "dim": [3, 20], + "n_points_X": [100, 1000], + "n_points_Y": [100, 1000], + "use_pointclouds": [False], + } + + test_args = sorted(case_grid.keys()) + test_cases = product(*case_grid.values()) + kwargs_list = [dict(zip(test_args, case)) for case in test_cases] + + # add the use_pointclouds=True test cases whenever we have dim==3 + kwargs_to_add = [] + for entry in kwargs_list: + if entry["dim"] == 3: + entry_add = deepcopy(entry) + entry_add["use_pointclouds"] = True + kwargs_to_add.append(entry_add) + kwargs_list.extend(kwargs_to_add) + + benchmark( + TestICP.iterative_closest_point, + "IterativeClosestPoint", + kwargs_list, + warmup_iters=1, + ) def bm_corresponding_points_alignment() -> None: @@ -21,7 +52,7 @@ def bm_corresponding_points_alignment() -> None: } test_args = sorted(case_grid.keys()) - test_cases = product(*[case_grid[k] for k in test_args]) + test_cases = product(*case_grid.values()) kwargs_list = [dict(zip(test_args, case)) for case in test_cases] # add the use_pointclouds=True test cases whenever we have dim==3 diff --git a/tests/icp_data.pth b/tests/icp_data.pth new file mode 100644 index 00000000..e99a61b3 Binary files /dev/null and b/tests/icp_data.pth differ diff --git a/tests/test_points_alignment.py b/tests/test_points_alignment.py index 35a00b8e..aa8770d0 100644 --- a/tests/test_points_alignment.py +++ b/tests/test_points_alignment.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - import unittest +from pathlib import Path import numpy as np import torch @@ -36,6 +35,256 @@ def _apply_pcl_transformation(X, R, T, s=None): return X_t +class TestICP(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth" + self.trimesh_results = torch.load(trimesh_results_path) + + @staticmethod + def iterative_closest_point( + batch_size=10, + n_points_X=100, + n_points_Y=100, + dim=3, + use_pointclouds=False, + estimate_scale=False, + ): + + device = torch.device("cuda:0") + + # initialize a ground truth point cloud + X, Y = [ + TestCorrespondingPointsAlignment.init_point_cloud( + batch_size=batch_size, + n_points=n_points, + dim=dim, + device=device, + use_pointclouds=use_pointclouds, + random_pcl_size=True, + fix_seed=i, + ) + for i, n_points in enumerate((n_points_X, n_points_Y)) + ] + + torch.cuda.synchronize() + + def run_iterative_closest_point(): + points_alignment.iterative_closest_point( + X, + Y, + estimate_scale=estimate_scale, + allow_reflection=False, + verbose=False, + max_iterations=100, + relative_rmse_thr=1e-4, + ) + torch.cuda.synchronize() + + return run_iterative_closest_point + + def test_init_transformation(self, batch_size=10): + """ + First runs a full ICP on a random problem. 
diff --git a/tests/test_points_alignment.py b/tests/test_points_alignment.py
index 35a00b8e..aa8770d0 100644
--- a/tests/test_points_alignment.py
+++ b/tests/test_points_alignment.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
 
 import unittest
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -36,6 +35,256 @@ def _apply_pcl_transformation(X, R, T, s=None):
     return X_t
 
 
+class TestICP(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(42)
+        np.random.seed(42)
+        trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth"
+        self.trimesh_results = torch.load(trimesh_results_path)
+
+    @staticmethod
+    def iterative_closest_point(
+        batch_size=10,
+        n_points_X=100,
+        n_points_Y=100,
+        dim=3,
+        use_pointclouds=False,
+        estimate_scale=False,
+    ):
+
+        device = torch.device("cuda:0")
+
+        # initialize a ground truth point cloud
+        X, Y = [
+            TestCorrespondingPointsAlignment.init_point_cloud(
+                batch_size=batch_size,
+                n_points=n_points,
+                dim=dim,
+                device=device,
+                use_pointclouds=use_pointclouds,
+                random_pcl_size=True,
+                fix_seed=i,
+            )
+            for i, n_points in enumerate((n_points_X, n_points_Y))
+        ]
+
+        torch.cuda.synchronize()
+
+        def run_iterative_closest_point():
+            points_alignment.iterative_closest_point(
+                X,
+                Y,
+                estimate_scale=estimate_scale,
+                allow_reflection=False,
+                verbose=False,
+                max_iterations=100,
+                relative_rmse_thr=1e-4,
+            )
+            torch.cuda.synchronize()
+
+        return run_iterative_closest_point
+
+    def test_init_transformation(self, batch_size=10):
+        """
+        First runs a full ICP on a random problem. Then takes a given point
+        in the history of ICP iteration transformations, initializes
+        a second run of ICP with this transformation and checks whether
+        both runs ended with the same solution.
+        """
+
+        device = torch.device("cuda:0")
+
+        for dim in (2, 3, 11):
+            for n_points_X in (30, 100):
+                for n_points_Y in (30, 100):
+                    # initialize ground truth point clouds
+                    X, Y = [
+                        TestCorrespondingPointsAlignment.init_point_cloud(
+                            batch_size=batch_size,
+                            n_points=n_points,
+                            dim=dim,
+                            device=device,
+                            use_pointclouds=False,
+                            random_pcl_size=True,
+                        )
+                        for n_points in (n_points_X, n_points_Y)
+                    ]
+
+                    # run full icp
+                    converged, _, Xt, (
+                        R,
+                        T,
+                        s,
+                    ), t_hist = points_alignment.iterative_closest_point(
+                        X,
+                        Y,
+                        estimate_scale=False,
+                        allow_reflection=False,
+                        verbose=False,
+                        max_iterations=100,
+                    )
+
+                    # start from the solution after the third
+                    # iteration of the previous ICP
+                    t_init = t_hist[min(2, len(t_hist) - 1)]
+
+                    # rerun the ICP
+                    converged_init, _, Xt_init, (
+                        R_init,
+                        T_init,
+                        s_init,
+                    ), t_hist_init = points_alignment.iterative_closest_point(
+                        X,
+                        Y,
+                        init_transform=t_init,
+                        estimate_scale=False,
+                        allow_reflection=False,
+                        verbose=False,
+                        max_iterations=100,
+                    )
+
+                    # compare transformations and obtained clouds
+                    # check that both sets of transforms are the same
+                    atol = 3e-5
+                    self.assertClose(R_init, R, atol=atol)
+                    self.assertClose(T_init, T, atol=atol)
+                    self.assertClose(s_init, s, atol=atol)
+                    self.assertClose(Xt_init, Xt, atol=atol)
+
+    def test_heterogenous_inputs(self, batch_size=10):
+        """
+        Tests whether we get the same result when running ICP on
+        a set of randomly-sized Pointclouds and on their padded versions.
+        """
+
+        device = torch.device("cuda:0")
+
+        for estimate_scale in (True, False):
+            for max_n_points in (10, 30, 100):
+                # initialize ground truth point clouds
+                X_pcl, Y_pcl = [
+                    TestCorrespondingPointsAlignment.init_point_cloud(
+                        batch_size=batch_size,
+                        n_points=max_n_points,
+                        dim=3,
+                        device=device,
+                        use_pointclouds=True,
+                        random_pcl_size=True,
+                    )
+                    for _ in range(2)
+                ]
+
+                # get the padded versions and their num of points
+                X_padded = X_pcl.points_padded()
+                Y_padded = Y_pcl.points_padded()
+                n_points_X = X_pcl.num_points_per_cloud()
+                n_points_Y = Y_pcl.num_points_per_cloud()
+
+                # run icp with Pointclouds inputs
+                _, _, Xt_pcl, (
+                    R_pcl,
+                    T_pcl,
+                    s_pcl,
+                ), _ = points_alignment.iterative_closest_point(
+                    X_pcl,
+                    Y_pcl,
+                    estimate_scale=estimate_scale,
+                    allow_reflection=False,
+                    verbose=False,
+                    max_iterations=100,
+                )
+                Xt_pcl = Xt_pcl.points_padded()
+
+                # run icp with tensor inputs on each element
+                # of the batch separately
+                icp_results = [
+                    points_alignment.iterative_closest_point(
+                        X_[None, :n_X, :],
+                        Y_[None, :n_Y, :],
+                        estimate_scale=estimate_scale,
+                        allow_reflection=False,
+                        verbose=False,
+                        max_iterations=100,
+                    )
+                    for X_, Y_, n_X, n_Y in zip(
+                        X_padded, Y_padded, n_points_X, n_points_Y
+                    )
+                ]
+
+                # parse out the transformation results
+                R, T, s = [
+                    torch.cat([x.RTs[i] for x in icp_results], dim=0) for i in range(3)
+                ]
+
+                # check that both sets of transforms are the same
+                atol = 1e-5
+                self.assertClose(R_pcl, R, atol=atol)
+                self.assertClose(T_pcl, T, atol=atol)
+                self.assertClose(s_pcl, s, atol=atol)
+
+                # compare the transformed point clouds
+                for pcli in range(batch_size):
+                    nX = n_points_X[pcli]
+                    Xt_ = icp_results[pcli].Xt[0, :nX]
+                    Xt_pcl_ = Xt_pcl[pcli][:nX]
+                    self.assertClose(Xt_pcl_, Xt_, atol=atol)
+
+    def test_compare_with_trimesh(self):
+        """
+        Compares the outputs of `iterative_closest_point` with the results
+        of `trimesh.registration.icp` from the `trimesh` python package:
+        https://github.com/mikedh/trimesh
+
+        We have run `trimesh.registration.icp` on several random problems
+        with different point cloud sizes. The results of trimesh, together with
+        the randomly generated input clouds, are loaded in the constructor of
+        this class and this test compares the loaded results to our runs.
+        """
+        for n_points_X in (10, 20, 50, 100):
+            for n_points_Y in (10, 20, 50, 100):
+                self._compare_with_trimesh(n_points_X=n_points_X, n_points_Y=n_points_Y)
+
+    def _compare_with_trimesh(
+        self, n_points_X=100, n_points_Y=100, estimate_scale=False
+    ):
+        """
+        Executes a single test for `iterative_closest_point` for a
+        specific setting of the inputs / outputs. Compares the result with
+        the result of the trimesh package on the same input data.
+        """
+
+        device = torch.device("cuda:0")
+
+        # load the trimesh results and the initial point clouds for icp
+        key = (int(n_points_X), int(n_points_Y), int(estimate_scale))
+        X, Y, R_trimesh, T_trimesh, s_trimesh = [
+            x.to(device) for x in self.trimesh_results[key]
+        ]
+
+        # run the icp algorithm
+        converged, _, _, (
+            R_ours,
+            T_ours,
+            s_ours,
+        ), _ = points_alignment.iterative_closest_point(
+            X,
+            Y,
+            estimate_scale=estimate_scale,
+            allow_reflection=False,
+            verbose=False,
+            max_iterations=100,
+        )
+
+        # check that we have the same transformation
+        # and that the icp converged
+        atol = 1e-5
+        self.assertClose(R_ours, R_trimesh, atol=atol)
+        self.assertClose(T_ours, T_trimesh, atol=atol)
+        self.assertClose(s_ours, s_trimesh, atol=atol)
+        self.assertTrue(converged)
+
+
 class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
@@ -72,10 +321,17 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
         device=None,
         use_pointclouds=False,
         random_pcl_size=True,
+        fix_seed=None,
     ):
         """
         Generate a batch of normally distributed point clouds.
         """
+
+        if fix_seed is not None:
+            # make sure we always generate the same point cloud
+            seed = torch.random.get_rng_state()
+            torch.manual_seed(fix_seed)
+
         if use_pointclouds:
             assert dim == 3, "Pointclouds support only 3-dim points."
             # generate a `batch_size` point clouds with number of points
@@ -102,6 +358,10 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
             X = torch.randn(
                 batch_size, n_points, dim, device=device, dtype=torch.float32
             )
+
+        if fix_seed is not None:
+            torch.random.set_rng_state(seed)
+
         return X
 
     @staticmethod
@@ -230,7 +490,6 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
             - use_pointclouds ... If True, passes the Pointclouds objects
                                   to corresponding_points_alignment.
         """
-
        # run this for several different point cloud sizes
        for n_points in (100, 3, 2, 1):
            # run this for several different dimensionalities
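
The script that generated the `icp_data.pth` fixtures is not part of this diff. A hypothetical sketch of how such reference solutions could be produced with trimesh (the flag names and conversion below are assumptions based on `trimesh.registration.icp` returning a (4, 4) homogeneous matrix; the key layout mirrors `_compare_with_trimesh` above):

import torch
import trimesh.registration

n_points_X, n_points_Y, estimate_scale = 100, 100, False
X = torch.randn(1, n_points_X, 3)
Y = torch.randn(1, n_points_Y, 3)

# trimesh works in the column-vector convention; transpose the rotation
# block to obtain the row-vector convention (Y ~ s * X @ R + T) used here
matrix, _, _ = trimesh.registration.icp(
    X[0].numpy(), Y[0].numpy(), max_iterations=100, scale=estimate_scale
)
R = torch.tensor(matrix[:3, :3].T, dtype=torch.float32)[None]
T = torch.tensor(matrix[:3, 3], dtype=torch.float32)[None]
s = torch.ones(1)

trimesh_results = {(n_points_X, n_points_Y, int(estimate_scale)): (X, Y, R, T, s)}
torch.save(trimesh_results, "icp_data.pth")
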