mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
ICP - point-to-point version
Summary: The iterative closest point algorithm - point-to-point version. Output of `bm_iterative_closest_point`: Argument key: `batch_size dim n_points_X n_points_Y use_pointclouds` ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- IterativeClosestPoint_1_3_100_100_False 107569 111323 5 IterativeClosestPoint_1_3_100_1000_False 118972 122306 5 IterativeClosestPoint_1_3_1000_100_False 108576 110978 5 IterativeClosestPoint_1_3_1000_1000_False 331836 333515 2 IterativeClosestPoint_1_20_100_100_False 134387 137842 4 IterativeClosestPoint_1_20_100_1000_False 149218 153405 4 IterativeClosestPoint_1_20_1000_100_False 414248 416595 2 IterativeClosestPoint_1_20_1000_1000_False 374318 374662 2 IterativeClosestPoint_10_3_100_100_False 539852 539852 1 IterativeClosestPoint_10_3_100_1000_False 752784 752784 1 IterativeClosestPoint_10_3_1000_100_False 1070700 1070700 1 IterativeClosestPoint_10_3_1000_1000_False 1164020 1164020 1 IterativeClosestPoint_10_20_100_100_False 374548 377337 2 IterativeClosestPoint_10_20_100_1000_False 472764 476685 2 IterativeClosestPoint_10_20_1000_100_False 1457175 1457175 1 IterativeClosestPoint_10_20_1000_1000_False 2195820 2195820 1 IterativeClosestPoint_1_3_100_100_True 110084 115824 5 IterativeClosestPoint_1_3_100_1000_True 142728 147696 4 IterativeClosestPoint_1_3_1000_100_True 212966 213966 3 IterativeClosestPoint_1_3_1000_1000_True 369130 375114 2 IterativeClosestPoint_10_3_100_100_True 354615 355179 2 IterativeClosestPoint_10_3_100_1000_True 451815 452704 2 IterativeClosestPoint_10_3_1000_100_True 511833 511833 1 IterativeClosestPoint_10_3_1000_1000_True 798453 798453 1 -------------------------------------------------------------------------------- ``` Reviewed By: shapovalov, gkioxari Differential Revision: D19909952 fbshipit-source-id: f77fadc88fb7c53999909d594114b182ee2a3def
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b5eb33b36c
commit
8abbe22ffb
@@ -5,7 +5,38 @@ from copy import deepcopy
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_points_alignment import TestCorrespondingPointsAlignment
|
||||
from test_points_alignment import TestCorrespondingPointsAlignment, TestICP
|
||||
|
||||
|
||||
def bm_iterative_closest_point() -> None:
|
||||
|
||||
case_grid = {
|
||||
"batch_size": [1, 10],
|
||||
"dim": [3, 20],
|
||||
"n_points_X": [100, 1000],
|
||||
"n_points_Y": [100, 1000],
|
||||
"use_pointclouds": [False],
|
||||
}
|
||||
|
||||
test_args = sorted(case_grid.keys())
|
||||
test_cases = product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
|
||||
|
||||
# add the use_pointclouds=True test cases whenever we have dim==3
|
||||
kwargs_to_add = []
|
||||
for entry in kwargs_list:
|
||||
if entry["dim"] == 3:
|
||||
entry_add = deepcopy(entry)
|
||||
entry_add["use_pointclouds"] = True
|
||||
kwargs_to_add.append(entry_add)
|
||||
kwargs_list.extend(kwargs_to_add)
|
||||
|
||||
benchmark(
|
||||
TestICP.iterative_closest_point,
|
||||
"IterativeClosestPoint",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
def bm_corresponding_points_alignment() -> None:
|
||||
@@ -21,7 +52,7 @@ def bm_corresponding_points_alignment() -> None:
|
||||
}
|
||||
|
||||
test_args = sorted(case_grid.keys())
|
||||
test_cases = product(*[case_grid[k] for k in test_args])
|
||||
test_cases = product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(test_args, case)) for case in test_cases]
|
||||
|
||||
# add the use_pointclouds=True test cases whenever we have dim==3
|
||||
|
||||
BIN
tests/icp_data.pth
Normal file
BIN
tests/icp_data.pth
Normal file
Binary file not shown.
@@ -1,8 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -36,6 +35,256 @@ def _apply_pcl_transformation(X, R, T, s=None):
|
||||
return X_t
|
||||
|
||||
|
||||
class TestICP(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
trimesh_results_path = Path(__file__).resolve().parent / "icp_data.pth"
|
||||
self.trimesh_results = torch.load(trimesh_results_path)
|
||||
|
||||
@staticmethod
|
||||
def iterative_closest_point(
|
||||
batch_size=10,
|
||||
n_points_X=100,
|
||||
n_points_Y=100,
|
||||
dim=3,
|
||||
use_pointclouds=False,
|
||||
estimate_scale=False,
|
||||
):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# initialize a ground truth point cloud
|
||||
X, Y = [
|
||||
TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
device=device,
|
||||
use_pointclouds=use_pointclouds,
|
||||
random_pcl_size=True,
|
||||
fix_seed=i,
|
||||
)
|
||||
for i, n_points in enumerate((n_points_X, n_points_Y))
|
||||
]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_iterative_closest_point():
|
||||
points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
relative_rmse_thr=1e-4,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return run_iterative_closest_point
|
||||
|
||||
def test_init_transformation(self, batch_size=10):
|
||||
"""
|
||||
First runs a full ICP on a random problem. Then takes a given point
|
||||
in the history of ICP iteration transformations, initializes
|
||||
a second run of ICP with this transformation and checks whether
|
||||
both runs ended with the same solution.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
for dim in (2, 3, 11):
|
||||
for n_points_X in (30, 100):
|
||||
for n_points_Y in (30, 100):
|
||||
# initialize ground truth point clouds
|
||||
X, Y = [
|
||||
TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
device=device,
|
||||
use_pointclouds=False,
|
||||
random_pcl_size=True,
|
||||
)
|
||||
for n_points in (n_points_X, n_points_Y)
|
||||
]
|
||||
|
||||
# run full icp
|
||||
converged, _, Xt, (
|
||||
R,
|
||||
T,
|
||||
s,
|
||||
), t_hist = points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
|
||||
# start from the solution after the third
|
||||
# iteration of the previous ICP
|
||||
t_init = t_hist[min(2, len(t_hist) - 1)]
|
||||
|
||||
# rerun the ICP
|
||||
converged_init, _, Xt_init, (
|
||||
R_init,
|
||||
T_init,
|
||||
s_init,
|
||||
), t_hist_init = points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
init_transform=t_init,
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
|
||||
# compare transformations and obtained clouds
|
||||
# check that both sets of transforms are the same
|
||||
atol = 3e-5
|
||||
self.assertClose(R_init, R, atol=atol)
|
||||
self.assertClose(T_init, T, atol=atol)
|
||||
self.assertClose(s_init, s, atol=atol)
|
||||
self.assertClose(Xt_init, Xt, atol=atol)
|
||||
|
||||
def test_heterogenous_inputs(self, batch_size=10):
|
||||
"""
|
||||
Tests whether we get the same result when running ICP on
|
||||
a set of randomly-sized Pointclouds and on their padded versions.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
for estimate_scale in (True, False):
|
||||
for max_n_points in (10, 30, 100):
|
||||
# initialize ground truth point clouds
|
||||
X_pcl, Y_pcl = [
|
||||
TestCorrespondingPointsAlignment.init_point_cloud(
|
||||
batch_size=batch_size,
|
||||
n_points=max_n_points,
|
||||
dim=3,
|
||||
device=device,
|
||||
use_pointclouds=True,
|
||||
random_pcl_size=True,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
# get the padded versions and their num of points
|
||||
X_padded = X_pcl.points_padded()
|
||||
Y_padded = Y_pcl.points_padded()
|
||||
n_points_X = X_pcl.num_points_per_cloud()
|
||||
n_points_Y = Y_pcl.num_points_per_cloud()
|
||||
|
||||
# run icp with Pointlouds inputs
|
||||
_, _, Xt_pcl, (
|
||||
R_pcl,
|
||||
T_pcl,
|
||||
s_pcl,
|
||||
), _ = points_alignment.iterative_closest_point(
|
||||
X_pcl,
|
||||
Y_pcl,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
Xt_pcl = Xt_pcl.points_padded()
|
||||
|
||||
# run icp with tensor inputs on each element
|
||||
# of the batch separately
|
||||
icp_results = [
|
||||
points_alignment.iterative_closest_point(
|
||||
X_[None, :n_X, :],
|
||||
Y_[None, :n_Y, :],
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
for X_, Y_, n_X, n_Y in zip(
|
||||
X_padded, Y_padded, n_points_X, n_points_Y
|
||||
)
|
||||
]
|
||||
|
||||
# parse out the transformation results
|
||||
R, T, s = [
|
||||
torch.cat([x.RTs[i] for x in icp_results], dim=0) for i in range(3)
|
||||
]
|
||||
|
||||
# check that both sets of transforms are the same
|
||||
atol = 1e-5
|
||||
self.assertClose(R_pcl, R, atol=atol)
|
||||
self.assertClose(T_pcl, T, atol=atol)
|
||||
self.assertClose(s_pcl, s, atol=atol)
|
||||
|
||||
# compare the transformed point clouds
|
||||
for pcli in range(batch_size):
|
||||
nX = n_points_X[pcli]
|
||||
Xt_ = icp_results[pcli].Xt[0, :nX]
|
||||
Xt_pcl_ = Xt_pcl[pcli][:nX]
|
||||
self.assertClose(Xt_pcl_, Xt_, atol=atol)
|
||||
|
||||
def test_compare_with_trimesh(self):
|
||||
"""
|
||||
Compares the outputs of `iterative_closest_point` with the results
|
||||
of `trimesh.registration.icp` from the `trimesh` python package:
|
||||
https://github.com/mikedh/trimesh
|
||||
|
||||
We have run `trimesh.registration.icp` on several random problems
|
||||
with different point cloud sizes. The results of trimesh, together with
|
||||
the randomly generated input clouds are loaded in the constructor of
|
||||
this class and this test compares the loaded results to our runs.
|
||||
"""
|
||||
for n_points_X in (10, 20, 50, 100):
|
||||
for n_points_Y in (10, 20, 50, 100):
|
||||
self._compare_with_trimesh(n_points_X=n_points_X, n_points_Y=n_points_Y)
|
||||
|
||||
def _compare_with_trimesh(
|
||||
self, n_points_X=100, n_points_Y=100, estimate_scale=False
|
||||
):
|
||||
"""
|
||||
Executes a single test for `iterative_closest_point` for a
|
||||
specific setting of the inputs / outputs. Compares the result with
|
||||
the result of the trimesh package on the same input data.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# load the trimesh results and the initial point clouds for icp
|
||||
key = (int(n_points_X), int(n_points_Y), int(estimate_scale))
|
||||
X, Y, R_trimesh, T_trimesh, s_trimesh = [
|
||||
x.to(device) for x in self.trimesh_results[key]
|
||||
]
|
||||
|
||||
# run the icp algorithm
|
||||
converged, _, _, (
|
||||
R_ours,
|
||||
T_ours,
|
||||
s_ours,
|
||||
), _ = points_alignment.iterative_closest_point(
|
||||
X,
|
||||
Y,
|
||||
estimate_scale=estimate_scale,
|
||||
allow_reflection=False,
|
||||
verbose=False,
|
||||
max_iterations=100,
|
||||
)
|
||||
|
||||
# check that we have the same transformation
|
||||
# and that the icp converged
|
||||
atol = 1e-5
|
||||
self.assertClose(R_ours, R_trimesh, atol=atol)
|
||||
self.assertClose(T_ours, T_trimesh, atol=atol)
|
||||
self.assertClose(s_ours, s_trimesh, atol=atol)
|
||||
self.assertTrue(converged)
|
||||
|
||||
|
||||
class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
@@ -72,10 +321,17 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
device=None,
|
||||
use_pointclouds=False,
|
||||
random_pcl_size=True,
|
||||
fix_seed=None,
|
||||
):
|
||||
"""
|
||||
Generate a batch of normally distributed point clouds.
|
||||
"""
|
||||
|
||||
if fix_seed is not None:
|
||||
# make sure we always generate the same pointcloud
|
||||
seed = torch.random.get_rng_state()
|
||||
torch.manual_seed(fix_seed)
|
||||
|
||||
if use_pointclouds:
|
||||
assert dim == 3, "Pointclouds support only 3-dim points."
|
||||
# generate a `batch_size` point clouds with number of points
|
||||
@@ -102,6 +358,10 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
X = torch.randn(
|
||||
batch_size, n_points, dim, device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if fix_seed:
|
||||
torch.random.set_rng_state(seed)
|
||||
|
||||
return X
|
||||
|
||||
@staticmethod
|
||||
@@ -230,7 +490,6 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
- use_pointclouds ... If True, passes the Pointclouds objects
|
||||
to corresponding_points_alignment.
|
||||
"""
|
||||
|
||||
# run this for several different point cloud sizes
|
||||
for n_points in (100, 3, 2, 1):
|
||||
# run this for several different dimensionalities
|
||||
|
||||
Reference in New Issue
Block a user