Weighted Umeyama.

Summary:
1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own.
2. Refactored to use the same code for the Pointclouds mask and passed weights.
3. Added test cases with random weights.
4. Fixed a bug in tests that calls the function with 0 points (fails randomly in Pytorch 1.3, will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421 ).

Reviewed By: gkioxari

Differential Revision: D20070293

fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
This commit is contained in:
Roman Shapovalov
2020-04-03 02:57:01 -07:00
committed by Facebook GitHub Bot
parent e5b1d6d3a3
commit e37085d999
6 changed files with 278 additions and 50 deletions

View File

@@ -1,16 +1,18 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures import utils as strutil
from pytorch3d.ops import utils as oputil
def corresponding_points_alignment(
X: Union[torch.Tensor, Pointclouds],
Y: Union[torch.Tensor, Pointclouds],
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
estimate_scale: bool = False,
allow_reflection: bool = False,
eps: float = 1e-8,
@@ -28,9 +30,14 @@ def corresponding_points_alignment(
Args:
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
or a `Pointclouds` object.
or a `Pointclouds` object.
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
or a `Pointclouds` object.
or a `Pointclouds` object.
weights: Batch of non-negative weights of
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
tensors that may have different shapes; in that case, the length of
i-th tensor should be equal to the number of points in X_i and Y_i.
Passing `None` means uniform weights.
estimate_scale: If `True`, also estimates a scaling component `s`
of the transformation. Otherwise assumes an identity
scale and returns a tensor of ones.
@@ -59,25 +66,45 @@ def corresponding_points_alignment(
"Point sets X and Y have to have the same \
number of batches, points and dimensions."
)
if weights is not None:
if isinstance(weights, list):
if any(np != w.shape[0] for np, w in zip(num_points, weights)):
raise ValueError(
"number of weights should equal to the "
+ "number of points in the point cloud."
)
weights = [w[..., None] for w in weights]
weights = strutil.list_to_padded(weights)[..., 0]
if Xt.shape[:2] != weights.shape:
raise ValueError(
"weights should have the same first two dimensions as X."
)
b, n, dim = Xt.shape
# compute the centroids of the point sets
Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
# mean-center the point sets
Xc = Xt - Xmu[:, None]
Yc = Yt - Ymu[:, None]
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
mask = (
torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
< num_points[:, None]
).type_as(Xc)
Xc *= mask[:, :, None]
Yc *= mask[:, :, None]
).type_as(Xt)
weights = mask if weights is None else mask * weights.type_as(Xt)
# compute the centroids of the point sets
Xmu = oputil.wmean(Xt, weights, eps=eps)
Ymu = oputil.wmean(Yt, weights, eps=eps)
# mean-center the point sets
Xc = Xt - Xmu
Yc = Yt - Ymu
total_weight = torch.clamp(num_points, 1)
# special handling for heterogeneous point clouds and/or input weights
if weights is not None:
Xc *= weights[:, :, None]
Yc *= weights[:, :, None]
total_weight = torch.clamp(weights.sum(1), eps)
if (num_points < (dim + 1)).any():
warnings.warn(
@@ -87,7 +114,7 @@ def corresponding_points_alignment(
# compute the covariance XYcov between the point sets Xc, Yc
XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
XYcov = XYcov / total_weight[:, None, None]
# decompose the covariance matrix XYcov
U, S, V = torch.svd(XYcov)
@@ -111,17 +138,16 @@ def corresponding_points_alignment(
if estimate_scale:
# estimate the scaling component of the transformation
trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
Xcov = (Xc * Xc).sum((1, 2)) / total_weight
# the scaling component
s = trace_ES / torch.clamp(Xcov, eps)
# translation component
T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
else:
# translation component
T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]
# unit scaling since we do not estimate scale
s = T.new_ones(b)

44
pytorch3d/ops/utils.py Normal file
View File

@@ -0,0 +1,44 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional, Tuple, Union
import torch
def wmean(
x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
dim: Union[int, Tuple[int]] = -2,
keepdim: bool = True,
eps: float = 1e-9,
) -> torch.Tensor:
"""
Finds the mean of the input tensor across the specified dimension.
If the `weight` argument is provided, computes weighted mean.
Args:
x: tensor of shape `(*, D)`, where D is assumed to be spatial;
weights: if given, non-negative tensor of shape `(*,)`. It must be
broadcastable to `x.shape[:-1]`. Note that the weights for
the last (spatial) dimension are assumed same;
dim: dimension(s) in `x` to average over;
keepdim: tells whether to keep the resulting singleton dimension.
eps: minumum clamping value in the denominator.
Returns:
the mean tensor:
* if `weights` is None => `mean(x, dim)`,
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
"""
args = dict(dim=dim, keepdim=keepdim)
if weight is None:
return x.mean(**args)
if any(
xd != wd and xd != 1 and wd != 1
for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
):
raise ValueError("wmean: weights are not compatible with the tensor")
return (
(x * weight[..., None]).sum(**args)
/ weight[..., None].sum(**args).clamp(eps)
)