mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
Umeyama
Summary: Umeyama estimates a rigid motion between two sets of corresponding points. Benchmark output for `bm_points_alignment` ``` Arguments key: [<allow_reflection>_<batch_size>_<dim>_<estimate_scale>_<n_points>_<use_pointclouds>] Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- CorrespodingPointsAlignment_True_1_3_True_100_False 7382 9833 68 CorrespodingPointsAlignment_True_1_3_True_10000_False 8183 10500 62 CorrespodingPointsAlignment_True_1_3_False_100_False 7301 9263 69 CorrespodingPointsAlignment_True_1_3_False_10000_False 7945 9746 64 CorrespodingPointsAlignment_True_1_20_True_100_False 13706 41623 37 CorrespodingPointsAlignment_True_1_20_True_10000_False 11044 33766 46 CorrespodingPointsAlignment_True_1_20_False_100_False 9908 28791 51 CorrespodingPointsAlignment_True_1_20_False_10000_False 9523 18680 53 CorrespodingPointsAlignment_True_10_3_True_100_False 29585 32026 17 CorrespodingPointsAlignment_True_10_3_True_10000_False 29626 36324 18 CorrespodingPointsAlignment_True_10_3_False_100_False 26013 29253 20 CorrespodingPointsAlignment_True_10_3_False_10000_False 25000 33820 20 CorrespodingPointsAlignment_True_10_20_True_100_False 40955 41592 13 CorrespodingPointsAlignment_True_10_20_True_10000_False 42087 42393 12 CorrespodingPointsAlignment_True_10_20_False_100_False 39863 40381 13 CorrespodingPointsAlignment_True_10_20_False_10000_False 40813 41699 13 CorrespodingPointsAlignment_True_100_3_True_100_False 183146 194745 3 CorrespodingPointsAlignment_True_100_3_True_10000_False 213789 231466 3 CorrespodingPointsAlignment_True_100_3_False_100_False 177805 180796 3 CorrespodingPointsAlignment_True_100_3_False_10000_False 184963 185695 3 CorrespodingPointsAlignment_True_100_20_True_100_False 347181 347325 2 CorrespodingPointsAlignment_True_100_20_True_10000_False 363259 363613 2 CorrespodingPointsAlignment_True_100_20_False_100_False 351769 352496 2 CorrespodingPointsAlignment_True_100_20_False_10000_False 375629 379818 2 CorrespodingPointsAlignment_False_1_3_True_100_False 11155 13770 45 CorrespodingPointsAlignment_False_1_3_True_10000_False 10743 13938 47 CorrespodingPointsAlignment_False_1_3_False_100_False 9578 11511 53 CorrespodingPointsAlignment_False_1_3_False_10000_False 9549 11984 53 CorrespodingPointsAlignment_False_1_20_True_100_False 13809 14183 37 CorrespodingPointsAlignment_False_1_20_True_10000_False 14084 15082 36 CorrespodingPointsAlignment_False_1_20_False_100_False 12765 14177 40 CorrespodingPointsAlignment_False_1_20_False_10000_False 12811 13096 40 CorrespodingPointsAlignment_False_10_3_True_100_False 28823 39384 18 CorrespodingPointsAlignment_False_10_3_True_10000_False 27135 27525 19 CorrespodingPointsAlignment_False_10_3_False_100_False 26236 28980 20 CorrespodingPointsAlignment_False_10_3_False_10000_False 42324 45123 12 CorrespodingPointsAlignment_False_10_20_True_100_False 723902 723902 1 CorrespodingPointsAlignment_False_10_20_True_10000_False 220007 252886 3 CorrespodingPointsAlignment_False_10_20_False_100_False 55593 71636 9 CorrespodingPointsAlignment_False_10_20_False_10000_False 44419 71861 12 CorrespodingPointsAlignment_False_100_3_True_100_False 184768 185199 3 CorrespodingPointsAlignment_False_100_3_True_10000_False 198657 213868 3 CorrespodingPointsAlignment_False_100_3_False_100_False 224598 309645 3 CorrespodingPointsAlignment_False_100_3_False_10000_False 197863 202002 3 CorrespodingPointsAlignment_False_100_20_True_100_False 293484 309459 2 CorrespodingPointsAlignment_False_100_20_True_10000_False 327253 366644 2 CorrespodingPointsAlignment_False_100_20_False_100_False 420793 422194 2 CorrespodingPointsAlignment_False_100_20_False_10000_False 462634 485542 2 CorrespodingPointsAlignment_True_1_3_True_100_True 7664 9909 66 CorrespodingPointsAlignment_True_1_3_True_10000_True 7190 8366 70 CorrespodingPointsAlignment_True_1_3_False_100_True 6549 8316 77 CorrespodingPointsAlignment_True_1_3_False_10000_True 6534 7710 77 CorrespodingPointsAlignment_True_10_3_True_100_True 29052 32940 18 CorrespodingPointsAlignment_True_10_3_True_10000_True 30526 33453 17 CorrespodingPointsAlignment_True_10_3_False_100_True 28708 32993 18 CorrespodingPointsAlignment_True_10_3_False_10000_True 30630 35973 17 CorrespodingPointsAlignment_True_100_3_True_100_True 264909 320820 3 CorrespodingPointsAlignment_True_100_3_True_10000_True 310902 322604 2 CorrespodingPointsAlignment_True_100_3_False_100_True 246832 250634 3 CorrespodingPointsAlignment_True_100_3_False_10000_True 276006 289061 2 CorrespodingPointsAlignment_False_1_3_True_100_True 11421 13757 44 CorrespodingPointsAlignment_False_1_3_True_10000_True 11199 12532 45 CorrespodingPointsAlignment_False_1_3_False_100_True 11474 15841 44 CorrespodingPointsAlignment_False_1_3_False_10000_True 10384 13188 49 CorrespodingPointsAlignment_False_10_3_True_100_True 36599 47340 14 CorrespodingPointsAlignment_False_10_3_True_10000_True 40702 50754 13 CorrespodingPointsAlignment_False_10_3_False_100_True 41277 52149 13 CorrespodingPointsAlignment_False_10_3_False_10000_True 34286 37091 15 CorrespodingPointsAlignment_False_100_3_True_100_True 254991 258578 2 CorrespodingPointsAlignment_False_100_3_True_10000_True 257999 261285 2 CorrespodingPointsAlignment_False_100_3_False_100_True 247511 248693 3 CorrespodingPointsAlignment_False_100_3_False_10000_True 251807 263865 3 ``` Reviewed By: gkioxari Differential Revision: D19808389 fbshipit-source-id: 83305a58627d2fc5dcaf3c3015132d8148f28c29
This commit is contained in:
committed by
Facebook GitHub Bot
parent
745aaf3915
commit
e5b1d6d3a3
@@ -6,6 +6,7 @@ from .graph_conv import GraphConv
|
||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||
from .nearest_neighbor_points import nn_points_idx
|
||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||
from .points_alignment import corresponding_points_alignment
|
||||
from .sample_points_from_meshes import sample_points_from_meshes
|
||||
from .subdivide_meshes import SubdivideMeshes
|
||||
from .vert_align import vert_align
|
||||
|
||||
151
pytorch3d/ops/points_alignment.py
Normal file
151
pytorch3d/ops/points_alignment.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
|
||||
def corresponding_points_alignment(
|
||||
X: Union[torch.Tensor, Pointclouds],
|
||||
Y: Union[torch.Tensor, Pointclouds],
|
||||
estimate_scale: bool = False,
|
||||
allow_reflection: bool = False,
|
||||
eps: float = 1e-8,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Finds a similarity transformation (rotation `R`, translation `T`
|
||||
and optionally scale `s`) between two given sets of corresponding
|
||||
`d`-dimensional points `X` and `Y` such that:
|
||||
|
||||
`s[i] X[i] R[i] + T[i] = Y[i]`,
|
||||
|
||||
for all batch indexes `i` in the least squares sense.
|
||||
|
||||
The algorithm is also known as Umeyama [1].
|
||||
|
||||
Args:
|
||||
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
estimate_scale: If `True`, also estimates a scaling component `s`
|
||||
of the transformation. Otherwise assumes an identity
|
||||
scale and returns a tensor of ones.
|
||||
allow_reflection: If `True`, allows the algorithm to return `R`
|
||||
which is orthonormal but has determinant==-1.
|
||||
eps: A scalar for clamping to avoid dividing by zero. Active for the
|
||||
code that estimates the output scale `s`.
|
||||
|
||||
Returns:
|
||||
3-element tuple containing
|
||||
- **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
|
||||
- **T**: Batch of translations of shape `(minibatch, d)`.
|
||||
- **s**: batch of scaling factors of shape `(minibatch, )`.
|
||||
|
||||
References:
|
||||
[1] Shinji Umeyama: Least-Suqares Estimation of
|
||||
Transformation Parameters Between Two Point Patterns
|
||||
"""
|
||||
|
||||
# make sure we convert input Pointclouds structures to tensors
|
||||
Xt, num_points = _convert_point_cloud_to_tensor(X)
|
||||
Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)
|
||||
|
||||
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
|
||||
raise ValueError(
|
||||
"Point sets X and Y have to have the same \
|
||||
number of batches, points and dimensions."
|
||||
)
|
||||
|
||||
b, n, dim = Xt.shape
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu[:, None]
|
||||
Yc = Yt - Ymu[:, None]
|
||||
|
||||
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
|
||||
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
|
||||
mask = (
|
||||
torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
|
||||
< num_points[:, None]
|
||||
).type_as(Xc)
|
||||
Xc *= mask[:, :, None]
|
||||
Yc *= mask[:, :, None]
|
||||
|
||||
if (num_points < (dim + 1)).any():
|
||||
warnings.warn(
|
||||
"The size of one of the point clouds is <= dim+1. "
|
||||
+ "corresponding_points_alignment can't return a unique solution."
|
||||
)
|
||||
|
||||
# compute the covariance XYcov between the point sets Xc, Yc
|
||||
XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
|
||||
XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
|
||||
|
||||
# decompose the covariance matrix XYcov
|
||||
U, S, V = torch.svd(XYcov)
|
||||
|
||||
# identity matrix used for fixing reflections
|
||||
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(
|
||||
b, 1, 1
|
||||
)
|
||||
|
||||
if not allow_reflection:
|
||||
# reflection test:
|
||||
# checks whether the estimated rotation has det==1,
|
||||
# if not, finds the nearest rotation s.t. det==1 by
|
||||
# flipping the sign of the last singular vector U
|
||||
R_test = torch.bmm(U, V.transpose(2, 1))
|
||||
E[:, -1, -1] = torch.det(R_test)
|
||||
|
||||
# find the rotation matrix by composing U and V again
|
||||
R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
|
||||
|
||||
if estimate_scale:
|
||||
# estimate the scaling component of the transformation
|
||||
trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
|
||||
Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
|
||||
|
||||
# the scaling component
|
||||
s = trace_ES / torch.clamp(Xcov, eps)
|
||||
|
||||
# translation component
|
||||
T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
|
||||
|
||||
else:
|
||||
# translation component
|
||||
T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
|
||||
|
||||
# unit scaling since we do not estimate scale
|
||||
s = T.new_ones(b)
|
||||
|
||||
return R, T, s
|
||||
|
||||
|
||||
def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
|
||||
"""
|
||||
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
|
||||
padded representation and returns it together with the number of points
|
||||
per batch. Otherwise, returns the input itself with the number of points
|
||||
set to the size of the second dimension of `pcl`.
|
||||
"""
|
||||
if isinstance(pcl, Pointclouds):
|
||||
X = pcl.points_padded()
|
||||
num_points = pcl.num_points_per_cloud()
|
||||
elif torch.is_tensor(pcl):
|
||||
X = pcl
|
||||
num_points = X.shape[1] * torch.ones(
|
||||
X.shape[0], device=X.device, dtype=torch.int64
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The inputs X, Y should be either Pointclouds objects or tensors."
|
||||
)
|
||||
return X, num_points
|
||||
Reference in New Issue
Block a user