ICP - point-to-point version

Summary:
The iterative closest point algorithm - point-to-point version.
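
For reference, point-to-point ICP alternates two steps until the relative RMSE stops improving: (1) pair every source point with its nearest neighbour in the target cloud, and (2) solve for the rigid transform that best aligns the paired points (a Kabsch / Umeyama fit). Below is a minimal single-cloud sketch of that loop, written independently of the PyTorch3D implementation (the function name and the `R @ x + T` transform convention are illustrative only):

```
import torch

def icp_point_to_point(X, Y, max_iterations=100, relative_rmse_thr=1e-4):
    """Rigidly align X (n, d) onto Y (m, d); returns R, T with R @ x + T ~ y."""
    d = X.shape[1]
    R, T = torch.eye(d, dtype=X.dtype), torch.zeros(d, dtype=X.dtype)
    Xt, prev_rmse = X.clone(), None
    for _ in range(max_iterations):
        # 1. point-to-point correspondences: nearest neighbour of each Xt point in Y
        nn_idx = torch.cdist(Xt, Y).argmin(dim=1)
        Y_nn = Y[nn_idx]
        # 2. best-fit rotation / translation mapping the original X onto Y_nn (Kabsch)
        cx, cy = X.mean(dim=0), Y_nn.mean(dim=0)
        U, _, Vh = torch.linalg.svd((X - cx).t() @ (Y_nn - cy))
        V = Vh.t().clone()
        if torch.linalg.det(V @ U.t()) < 0:  # disallow reflections
            V[:, -1] = -V[:, -1]
        R = V @ U.t()
        T = cy - R @ cx
        # 3. apply the transform and stop once the RMSE change becomes negligible
        Xt = X @ R.t() + T
        rmse = ((Xt - Y_nn) ** 2).sum(dim=1).mean().sqrt()
        if prev_rmse is not None and (prev_rmse - rmse).abs() / prev_rmse < relative_rmse_thr:
            break
        prev_rmse = rmse
    return R, T, Xt
```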

Output of `bm_iterative_closest_point`:
Argument key: `batch_size dim n_points_X n_points_Y use_pointclouds`

```
Benchmark                                         Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
IterativeClosestPoint_1_3_100_100_False              107569          111323              5
IterativeClosestPoint_1_3_100_1000_False             118972          122306              5
IterativeClosestPoint_1_3_1000_100_False             108576          110978              5
IterativeClosestPoint_1_3_1000_1000_False            331836          333515              2
IterativeClosestPoint_1_20_100_100_False             134387          137842              4
IterativeClosestPoint_1_20_100_1000_False            149218          153405              4
IterativeClosestPoint_1_20_1000_100_False            414248          416595              2
IterativeClosestPoint_1_20_1000_1000_False           374318          374662              2
IterativeClosestPoint_10_3_100_100_False             539852          539852              1
IterativeClosestPoint_10_3_100_1000_False            752784          752784              1
IterativeClosestPoint_10_3_1000_100_False           1070700         1070700              1
IterativeClosestPoint_10_3_1000_1000_False          1164020         1164020              1
IterativeClosestPoint_10_20_100_100_False            374548          377337              2
IterativeClosestPoint_10_20_100_1000_False           472764          476685              2
IterativeClosestPoint_10_20_1000_100_False          1457175         1457175              1
IterativeClosestPoint_10_20_1000_1000_False         2195820         2195820              1
IterativeClosestPoint_1_3_100_100_True               110084          115824              5
IterativeClosestPoint_1_3_100_1000_True              142728          147696              4
IterativeClosestPoint_1_3_1000_100_True              212966          213966              3
IterativeClosestPoint_1_3_1000_1000_True             369130          375114              2
IterativeClosestPoint_10_3_100_100_True              354615          355179              2
IterativeClosestPoint_10_3_100_1000_True             451815          452704              2
IterativeClosestPoint_10_3_1000_100_True             511833          511833              1
IterativeClosestPoint_10_3_1000_1000_True            798453          798453              1
--------------------------------------------------------------------------------
```
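
Each row name above encodes its arguments in the order given by the argument key, so a row can be read back into keyword arguments, e.g. (a trivial illustration):

```
name = "IterativeClosestPoint_1_3_1000_100_False"
keys = ["batch_size", "dim", "n_points_X", "n_points_Y", "use_pointclouds"]
kwargs = dict(zip(keys, name.split("_")[1:]))
# {'batch_size': '1', 'dim': '3', 'n_points_X': '1000',
#  'n_points_Y': '100', 'use_pointclouds': 'False'}
```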

Reviewed By: shapovalov, gkioxari

Differential Revision: D19909952

fbshipit-source-id: f77fadc88fb7c53999909d594114b182ee2a3def
Author: David Novotny
Date: 2020-04-16 13:59:34 -07:00
Committed by: Facebook GitHub Bot
Parent: b5eb33b36c
Commit: 8abbe22ffb

6 changed files with 603 additions and 45 deletions

@@ -5,7 +5,38 @@ from copy import deepcopy
from itertools import product

from fvcore.common.benchmark import benchmark

-from test_points_alignment import TestCorrespondingPointsAlignment
+from test_points_alignment import TestCorrespondingPointsAlignment, TestICP


def bm_iterative_closest_point() -> None:
    case_grid = {
        "batch_size": [1, 10],
        "dim": [3, 20],
        "n_points_X": [100, 1000],
        "n_points_Y": [100, 1000],
        "use_pointclouds": [False],
    }
    test_args = sorted(case_grid.keys())
    test_cases = product(*case_grid.values())
    kwargs_list = [dict(zip(test_args, case)) for case in test_cases]

    # add the use_pointclouds=True test cases whenever we have dim==3
    kwargs_to_add = []
    for entry in kwargs_list:
        if entry["dim"] == 3:
            entry_add = deepcopy(entry)
            entry_add["use_pointclouds"] = True
            kwargs_to_add.append(entry_add)
    kwargs_list.extend(kwargs_to_add)

    benchmark(
        TestICP.iterative_closest_point,
        "IterativeClosestPoint",
        kwargs_list,
        warmup_iters=1,
    )


def bm_corresponding_points_alignment() -> None:
@@ -21,7 +52,7 @@ def bm_corresponding_points_alignment() -> None:
    }
    test_args = sorted(case_grid.keys())
-    test_cases = product(*[case_grid[k] for k in test_args])
+    test_cases = product(*case_grid.values())
    kwargs_list = [dict(zip(test_args, case)) for case in test_cases]

    # add the use_pointclouds=True test cases whenever we have dim==3
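
For context, the `sorted(case_grid.keys())` / `product(*case_grid.values())` pattern above builds one kwargs dict per benchmark case; it relies on the dict being declared with its keys already in sorted order, so that `case_grid.values()` lines up with `test_args`. A small standalone illustration of the expansion (reduced grid, values hypothetical):

```
from itertools import product

case_grid = {"batch_size": [1, 10], "dim": [3, 20]}
test_args = sorted(case_grid.keys())       # ['batch_size', 'dim']
test_cases = product(*case_grid.values())  # (1, 3), (1, 20), (10, 3), (10, 20)
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
# -> [{'batch_size': 1, 'dim': 3}, {'batch_size': 1, 'dim': 20},
#     {'batch_size': 10, 'dim': 3}, {'batch_size': 10, 'dim': 20}]
```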

BIN tests/icp_data.pth (new binary file, not shown)


@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from pathlib import Path
import numpy as np
import torch
@@ -36,6 +35,256 @@ def _apply_pcl_transformation(X, R, T, s=None):
    return X_t


class TestICP(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)
        trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth"
        self.trimesh_results = torch.load(trimesh_results_path)

    @staticmethod
    def iterative_closest_point(
        batch_size=10,
        n_points_X=100,
        n_points_Y=100,
        dim=3,
        use_pointclouds=False,
        estimate_scale=False,
    ):
        device = torch.device("cuda:0")

        # initialize a ground truth point cloud
        X, Y = [
            TestCorrespondingPointsAlignment.init_point_cloud(
                batch_size=batch_size,
                n_points=n_points,
                dim=dim,
                device=device,
                use_pointclouds=use_pointclouds,
                random_pcl_size=True,
                fix_seed=i,
            )
            for i, n_points in enumerate((n_points_X, n_points_Y))
        ]

        torch.cuda.synchronize()

        def run_iterative_closest_point():
            points_alignment.iterative_closest_point(
                X,
                Y,
                estimate_scale=estimate_scale,
                allow_reflection=False,
                verbose=False,
                max_iterations=100,
                relative_rmse_thr=1e-4,
            )
            torch.cuda.synchronize()

        return run_iterative_closest_point

    def test_init_transformation(self, batch_size=10):
        """
        First runs a full ICP on a random problem. Then takes a given point
        in the history of ICP iteration transformations, initializes
        a second run of ICP with this transformation and checks whether
        both runs ended with the same solution.
        """
        device = torch.device("cuda:0")

        for dim in (2, 3, 11):
            for n_points_X in (30, 100):
                for n_points_Y in (30, 100):
                    # initialize ground truth point clouds
                    X, Y = [
                        TestCorrespondingPointsAlignment.init_point_cloud(
                            batch_size=batch_size,
                            n_points=n_points,
                            dim=dim,
                            device=device,
                            use_pointclouds=False,
                            random_pcl_size=True,
                        )
                        for n_points in (n_points_X, n_points_Y)
                    ]

                    # run full icp
                    converged, _, Xt, (
                        R,
                        T,
                        s,
                    ), t_hist = points_alignment.iterative_closest_point(
                        X,
                        Y,
                        estimate_scale=False,
                        allow_reflection=False,
                        verbose=False,
                        max_iterations=100,
                    )

                    # start from the solution after the third
                    # iteration of the previous ICP
                    t_init = t_hist[min(2, len(t_hist) - 1)]

                    # rerun the ICP
                    converged_init, _, Xt_init, (
                        R_init,
                        T_init,
                        s_init,
                    ), t_hist_init = points_alignment.iterative_closest_point(
                        X,
                        Y,
                        init_transform=t_init,
                        estimate_scale=False,
                        allow_reflection=False,
                        verbose=False,
                        max_iterations=100,
                    )

                    # compare transformations and obtained clouds
                    # check that both sets of transforms are the same
                    atol = 3e-5
                    self.assertClose(R_init, R, atol=atol)
                    self.assertClose(T_init, T, atol=atol)
                    self.assertClose(s_init, s, atol=atol)
                    self.assertClose(Xt_init, Xt, atol=atol)
    def test_heterogenous_inputs(self, batch_size=10):
        """
        Tests whether we get the same result when running ICP on
        a set of randomly-sized Pointclouds and on their padded versions.
        """
        device = torch.device("cuda:0")

        for estimate_scale in (True, False):
            for max_n_points in (10, 30, 100):
                # initialize ground truth point clouds
                X_pcl, Y_pcl = [
                    TestCorrespondingPointsAlignment.init_point_cloud(
                        batch_size=batch_size,
                        n_points=max_n_points,
                        dim=3,
                        device=device,
                        use_pointclouds=True,
                        random_pcl_size=True,
                    )
                    for _ in range(2)
                ]

                # get the padded versions and their num of points
                X_padded = X_pcl.points_padded()
                Y_padded = Y_pcl.points_padded()
                n_points_X = X_pcl.num_points_per_cloud()
                n_points_Y = Y_pcl.num_points_per_cloud()

                # run icp with Pointclouds inputs
                _, _, Xt_pcl, (
                    R_pcl,
                    T_pcl,
                    s_pcl,
                ), _ = points_alignment.iterative_closest_point(
                    X_pcl,
                    Y_pcl,
                    estimate_scale=estimate_scale,
                    allow_reflection=False,
                    verbose=False,
                    max_iterations=100,
                )
                Xt_pcl = Xt_pcl.points_padded()

                # run icp with tensor inputs on each element
                # of the batch separately
                icp_results = [
                    points_alignment.iterative_closest_point(
                        X_[None, :n_X, :],
                        Y_[None, :n_Y, :],
                        estimate_scale=estimate_scale,
                        allow_reflection=False,
                        verbose=False,
                        max_iterations=100,
                    )
                    for X_, Y_, n_X, n_Y in zip(
                        X_padded, Y_padded, n_points_X, n_points_Y
                    )
                ]

                # parse out the transformation results
                R, T, s = [
                    torch.cat([x.RTs[i] for x in icp_results], dim=0) for i in range(3)
                ]

                # check that both sets of transforms are the same
                atol = 1e-5
                self.assertClose(R_pcl, R, atol=atol)
                self.assertClose(T_pcl, T, atol=atol)
                self.assertClose(s_pcl, s, atol=atol)

                # compare the transformed point clouds
                for pcli in range(batch_size):
                    nX = n_points_X[pcli]
                    Xt_ = icp_results[pcli].Xt[0, :nX]
                    Xt_pcl_ = Xt_pcl[pcli][:nX]
                    self.assertClose(Xt_pcl_, Xt_, atol=atol)
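
As background for the padded-versus-heterogeneous comparison above, this is roughly what the two views of a variable-sized batch look like (a small sketch using the `pytorch3d.structures.Pointclouds` container; the sizes are illustrative):

```
import torch
from pytorch3d.structures import Pointclouds

# Two clouds of different sizes form one heterogeneous batch.
clouds = Pointclouds(points=[torch.randn(30, 3), torch.randn(100, 3)])

padded = clouds.points_padded()        # shape (2, 100, 3); the shorter cloud is zero-padded
sizes = clouds.num_points_per_cloud()  # tensor([ 30, 100])

# The per-element tensor runs in the test slice the padding back off:
X0 = padded[0, : sizes[0]]             # first cloud without padding, shape (30, 3)
```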
    def test_compare_with_trimesh(self):
        """
        Compares the outputs of `iterative_closest_point` with the results
        of `trimesh.registration.icp` from the `trimesh` python package:
        https://github.com/mikedh/trimesh
        We have run `trimesh.registration.icp` on several random problems
        with different point cloud sizes. The results of trimesh, together with
        the randomly generated input clouds, are loaded in the constructor of
        this class and this test compares the loaded results to our runs.
        """
        for n_points_X in (10, 20, 50, 100):
            for n_points_Y in (10, 20, 50, 100):
                self._compare_with_trimesh(n_points_X=n_points_X, n_points_Y=n_points_Y)

    def _compare_with_trimesh(
        self, n_points_X=100, n_points_Y=100, estimate_scale=False
    ):
        """
        Executes a single test for `iterative_closest_point` for a
        specific setting of the inputs / outputs. Compares the result with
        the result of the trimesh package on the same input data.
        """
        device = torch.device("cuda:0")

        # load the trimesh results and the initial point clouds for icp
        key = (int(n_points_X), int(n_points_Y), int(estimate_scale))
        X, Y, R_trimesh, T_trimesh, s_trimesh = [
            x.to(device) for x in self.trimesh_results[key]
        ]

        # run the icp algorithm
        converged, _, _, (
            R_ours,
            T_ours,
            s_ours,
        ), _ = points_alignment.iterative_closest_point(
            X,
            Y,
            estimate_scale=estimate_scale,
            allow_reflection=False,
            verbose=False,
            max_iterations=100,
        )

        # check that we have the same transformation
        # and that the icp converged
        atol = 1e-5
        self.assertClose(R_ours, R_trimesh, atol=atol)
        self.assertClose(T_ours, T_trimesh, atol=atol)
        self.assertClose(s_ours, s_trimesh, atol=atol)
        self.assertTrue(converged)


class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
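
The trimesh comparison above checks `R`, `T`, `s` against reference values stored in `tests/icp_data.pth`. Since `trimesh.registration.icp` reports a single homogeneous matrix rather than separate factors, a conversion of the following flavour is presumably involved when such reference data is prepared (a hedged sketch only; the actual script that produced `icp_data.pth` is not part of this diff, and PyTorch3D applies its transforms to row vectors as `s * X @ R + T`, so a transpose may also be needed):

```
import numpy as np

def split_similarity(M):
    """Split a 4x4 similarity transform acting on column vectors,
    y = M @ [x, 1], into rotation R, translation t and isotropic scale s."""
    A, t = M[:3, :3], M[:3, 3]
    s = np.cbrt(np.linalg.det(A))  # isotropic scale, assuming no shear
    R = A / s
    return R, t, s
```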
@@ -72,10 +321,17 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
        device=None,
        use_pointclouds=False,
        random_pcl_size=True,
        fix_seed=None,
    ):
        """
        Generate a batch of normally distributed point clouds.
        """

        if fix_seed is not None:
            # make sure we always generate the same pointcloud
            seed = torch.random.get_rng_state()
            torch.manual_seed(fix_seed)

        if use_pointclouds:
            assert dim == 3, "Pointclouds support only 3-dim points."
            # generate a `batch_size` point clouds with number of points
@@ -102,6 +358,10 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
            X = torch.randn(
                batch_size, n_points, dim, device=device, dtype=torch.float32
            )

        if fix_seed:
            torch.random.set_rng_state(seed)

        return X

    @staticmethod
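
The `fix_seed` branches added above save the global RNG state, seed it deterministically, and then restore it, so the benchmark always draws the same clouds without perturbing other random draws in the process. The general pattern, shown standalone (the helper name is illustrative):

```
import torch

def with_fixed_seed(seed, fn):
    # Save the global RNG state, run fn under a deterministic seed,
    # then restore the state so later random draws are unaffected.
    state = torch.random.get_rng_state()
    try:
        torch.manual_seed(seed)
        return fn()
    finally:
        torch.random.set_rng_state(state)

points = with_fixed_seed(0, lambda: torch.randn(4, 3))
```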
@@ -230,7 +490,6 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
            - use_pointclouds ... If True, passes the Pointclouds objects
              to corresponding_points_alignment.
        """
        # run this for several different point cloud sizes
        for n_points in (100, 3, 2, 1):
            # run this for several different dimensionalities