mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
Weighted Umeyama.
Summary: 1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own. 2. Refactored to use the same code for the Pointclouds mask and passed weights. 3. Added test cases with random weights. 4. Fixed a bug in tests that calls the function with 0 points (fails randomly in Pytorch 1.3, will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421 ). Reviewed By: gkioxari Differential Revision: D20070293 fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e5b1d6d3a3
commit
e37085d999
@@ -1,16 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.structures import utils as strutil
|
||||
from pytorch3d.ops import utils as oputil
|
||||
|
||||
|
||||
def corresponding_points_alignment(
|
||||
X: Union[torch.Tensor, Pointclouds],
|
||||
Y: Union[torch.Tensor, Pointclouds],
|
||||
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
|
||||
estimate_scale: bool = False,
|
||||
allow_reflection: bool = False,
|
||||
eps: float = 1e-8,
|
||||
@@ -28,9 +30,14 @@ def corresponding_points_alignment(
|
||||
|
||||
Args:
|
||||
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
or a `Pointclouds` object.
|
||||
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
or a `Pointclouds` object.
|
||||
weights: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
|
||||
tensors that may have different shapes; in that case, the length of
|
||||
i-th tensor should be equal to the number of points in X_i and Y_i.
|
||||
Passing `None` means uniform weights.
|
||||
estimate_scale: If `True`, also estimates a scaling component `s`
|
||||
of the transformation. Otherwise assumes an identity
|
||||
scale and returns a tensor of ones.
|
||||
@@ -59,25 +66,45 @@ def corresponding_points_alignment(
|
||||
"Point sets X and Y have to have the same \
|
||||
number of batches, points and dimensions."
|
||||
)
|
||||
if weights is not None:
|
||||
if isinstance(weights, list):
|
||||
if any(np != w.shape[0] for np, w in zip(num_points, weights)):
|
||||
raise ValueError(
|
||||
"number of weights should equal to the "
|
||||
+ "number of points in the point cloud."
|
||||
)
|
||||
weights = [w[..., None] for w in weights]
|
||||
weights = strutil.list_to_padded(weights)[..., 0]
|
||||
|
||||
if Xt.shape[:2] != weights.shape:
|
||||
raise ValueError(
|
||||
"weights should have the same first two dimensions as X."
|
||||
)
|
||||
|
||||
b, n, dim = Xt.shape
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu[:, None]
|
||||
Yc = Yt - Ymu[:, None]
|
||||
|
||||
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
|
||||
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
|
||||
mask = (
|
||||
torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
|
||||
torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
|
||||
< num_points[:, None]
|
||||
).type_as(Xc)
|
||||
Xc *= mask[:, :, None]
|
||||
Yc *= mask[:, :, None]
|
||||
).type_as(Xt)
|
||||
weights = mask if weights is None else mask * weights.type_as(Xt)
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = oputil.wmean(Xt, weights, eps=eps)
|
||||
Ymu = oputil.wmean(Yt, weights, eps=eps)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu
|
||||
Yc = Yt - Ymu
|
||||
|
||||
total_weight = torch.clamp(num_points, 1)
|
||||
# special handling for heterogeneous point clouds and/or input weights
|
||||
if weights is not None:
|
||||
Xc *= weights[:, :, None]
|
||||
Yc *= weights[:, :, None]
|
||||
total_weight = torch.clamp(weights.sum(1), eps)
|
||||
|
||||
if (num_points < (dim + 1)).any():
|
||||
warnings.warn(
|
||||
@@ -87,7 +114,7 @@ def corresponding_points_alignment(
|
||||
|
||||
# compute the covariance XYcov between the point sets Xc, Yc
|
||||
XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
|
||||
XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
|
||||
XYcov = XYcov / total_weight[:, None, None]
|
||||
|
||||
# decompose the covariance matrix XYcov
|
||||
U, S, V = torch.svd(XYcov)
|
||||
@@ -111,17 +138,16 @@ def corresponding_points_alignment(
|
||||
if estimate_scale:
|
||||
# estimate the scaling component of the transformation
|
||||
trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
|
||||
Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
|
||||
Xcov = (Xc * Xc).sum((1, 2)) / total_weight
|
||||
|
||||
# the scaling component
|
||||
s = trace_ES / torch.clamp(Xcov, eps)
|
||||
|
||||
# translation component
|
||||
T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
|
||||
|
||||
T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
|
||||
else:
|
||||
# translation component
|
||||
T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
|
||||
T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]
|
||||
|
||||
# unit scaling since we do not estimate scale
|
||||
s = T.new_ones(b)
|
||||
|
||||
44
pytorch3d/ops/utils.py
Normal file
44
pytorch3d/ops/utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def wmean(
|
||||
x: torch.Tensor,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
dim: Union[int, Tuple[int]] = -2,
|
||||
keepdim: bool = True,
|
||||
eps: float = 1e-9,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finds the mean of the input tensor across the specified dimension.
|
||||
If the `weight` argument is provided, computes weighted mean.
|
||||
Args:
|
||||
x: tensor of shape `(*, D)`, where D is assumed to be spatial;
|
||||
weights: if given, non-negative tensor of shape `(*,)`. It must be
|
||||
broadcastable to `x.shape[:-1]`. Note that the weights for
|
||||
the last (spatial) dimension are assumed same;
|
||||
dim: dimension(s) in `x` to average over;
|
||||
keepdim: tells whether to keep the resulting singleton dimension.
|
||||
eps: minumum clamping value in the denominator.
|
||||
Returns:
|
||||
the mean tensor:
|
||||
* if `weights` is None => `mean(x, dim)`,
|
||||
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
|
||||
"""
|
||||
args = dict(dim=dim, keepdim=keepdim)
|
||||
|
||||
if weight is None:
|
||||
return x.mean(**args)
|
||||
|
||||
if any(
|
||||
xd != wd and xd != 1 and wd != 1
|
||||
for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
|
||||
):
|
||||
raise ValueError("wmean: weights are not compatible with the tensor")
|
||||
|
||||
return (
|
||||
(x * weight[..., None]).sum(**args)
|
||||
/ weight[..., None].sum(**args).clamp(eps)
|
||||
)
|
||||
Reference in New Issue
Block a user