mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00

Differential Revision: D31042748 fbshipit-source-id: fffb983bd6765d306a407587ddf64e68e57e9ecc
200 lines
6.9 KiB
Python
200 lines
6.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from .knn import knn_points
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from pytorch3d.structures import Pointclouds
|
|
|
|
|
|
def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Gather rows of `points` at the positions listed in `idx`, treating the
    index value -1 as padding.

    Padding indices are temporarily redirected to row 0 so that
    `torch.gather` accepts them; the corresponding output entries are then
    overwritten with 0.0. The input `idx` tensor is not modified.

    Args:
        points: (N, P, D) tensor of points.
        idx: (N, K) or (N, P', K) long tensor of indices into the second
            dimension of `points`, where -1 marks a padded entry. P' does
            not need to equal P since the indices may originate from a
            different point cloud.

    Returns:
        selected_points: tensor of gathered points with shape (N, K, D)
            when `idx` is 2-dimensional and (N, P', K, D) when `idx` is
            3-dimensional. Entries that correspond to padding are 0.0.

    Raises:
        ValueError: if the batch sizes of `points` and `idx` differ, or if
            `idx` is neither 2- nor 3-dimensional.
    """
    if len(idx) != len(points):
        raise ValueError("points and idx must have the same batch dimension")

    # Unpacking also enforces that `points` has exactly three dimensions.
    _, _, feat_dim = points.shape

    if idx.ndim == 2:
        # Farthest point sampling style input: one index list per batch element.
        gather_index = idx.unsqueeze(-1).expand(-1, -1, feat_dim)
        source = points
    elif idx.ndim == 3:
        # KNN / ball query style input: K neighbor indices per query point.
        num_neighbors = idx.shape[2]
        gather_index = idx.unsqueeze(-1).expand(-1, -1, -1, feat_dim)
        # Add a neighbor axis to the points so they line up with the indices.
        source = points.unsqueeze(2).expand(-1, -1, num_neighbors, -1)
    else:
        raise ValueError("idx format is not supported %s" % repr(idx.shape))

    is_padding = gather_index == -1
    # Out-of-place fill: yields a fresh index tensor with -1 replaced by 0.
    safe_index = gather_index.masked_fill(is_padding, 0)

    gathered = source.gather(dim=1, index=safe_index)
    # Blank out whatever was fetched for the padded slots.
    gathered[is_padding] = 0.0
    return gathered
|
|
|
|
|
|
def wmean(
    x: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    dim: Union[int, Tuple[int]] = -2,
    keepdim: bool = True,
    eps: float = 1e-9,
) -> torch.Tensor:
    """
    Average `x` over the given dimension(s), optionally weighting each entry.

    Args:
        x: tensor of shape `(*, D)`, where the last dimension D is treated
            as the spatial dimension.
        weight: optional non-negative tensor of shape `(*,)` that must be
            broadcastable to `x.shape[:-1]`. A single weight is shared by
            all D entries of the last (spatial) dimension.
        dim: dimension(s) of `x` to reduce.
        keepdim: whether the reduced dimension(s) are retained with size 1.
        eps: lower bound applied to the sum of weights before dividing.

    Returns:
        The mean tensor:
            * `mean(x, dim)` if `weight` is None,
            * `sum(x*w, dim) / max{sum(w, dim), eps}` otherwise.

    Raises:
        ValueError: if `weight` cannot be broadcast against `x.shape[:-1]`.
    """
    if weight is None:
        return x.mean(dim=dim, keepdim=keepdim)

    # Walk the trailing dimensions of both shapes (skipping the spatial
    # dimension of `x`) and apply the usual broadcasting rule pairwise.
    for x_size, w_size in zip(x.shape[-2::-1], weight.shape[::-1]):
        if x_size != w_size and x_size != 1 and w_size != 1:
            raise ValueError("wmean: weights are not compatible with the tensor")

    # Trailing singleton axis lets the weights broadcast over the spatial dim.
    w = weight[..., None]
    weighted_sum = (x * w).sum(dim=dim, keepdim=keepdim)
    total_weight = w.sum(dim=dim, keepdim=keepdim).clamp(eps)
    return weighted_sum / total_weight
|
|
|
|
|
|
def eyes(
    dim: int,
    N: int,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Builds a batch of `N` identity matrices, each of size `dim x dim`.

    Args:
        **dim**: Side length of every identity matrix.
        **N**: How many identity matrices to stack into the batch.
        **device**: Device on which the result is allocated.
        **dtype**: Data type of the result.

    Returns:
        **identities**: Tensor of shape `(N, dim, dim)` holding `N` copies
            of the identity matrix.
    """
    single_identity = torch.eye(dim, dtype=dtype, device=device)
    # Prepend a batch axis and tile the matrix along it.
    return single_identity.unsqueeze(0).repeat(N, 1, 1)
|
|
|
|
|
|
def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
    """
    Normalizes a point cloud input to a `(points, num_points)` pair.

    If `pcl` passes the `is_pointclouds` check, the result of its
    `points_padded()` method is returned together with the result of its
    `num_points_per_cloud()` method. If `pcl` is a tensor, it is returned
    unchanged together with an int64 tensor of length `pcl.shape[0]` in
    which every entry equals `pcl.shape[1]`.

    Raises:
        ValueError: if `pcl` is neither a Pointclouds-like object nor a tensor.
    """
    if is_pointclouds(pcl):
        return pcl.points_padded(), pcl.num_points_per_cloud()  # type: ignore

    if not torch.is_tensor(pcl):
        raise ValueError(
            "The inputs X, Y should be either Pointclouds objects or tensors."
        )

    # For a plain tensor, every batch element spans the full second dimension.
    batch_size, points_per_cloud = pcl.shape[0], pcl.shape[1]
    num_points = torch.full(
        (batch_size,), points_per_cloud, dtype=torch.int64, device=pcl.device
    )
    return pcl, num_points
|
|
|
|
|
|
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
    """Duck-typing test for `Pointclouds` inputs.

    Returns True when `pcl` exposes both a `points_padded` and a
    `num_points_per_cloud` attribute, and False otherwise. No `isinstance`
    check is performed.
    """
    required_attrs = ("points_padded", "num_points_per_cloud")
    return all(hasattr(pcl, attr) for attr in required_attrs)
|
|
|
|
|
|
def get_point_covariances(
    points_padded: torch.Tensor,
    num_points_per_cloud: torch.Tensor,
    neighborhood_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes, for every point, the covariance matrix of the locations of
    its K nearest neighbors within the same cloud.

    Args:
        **points_padded**: Input point clouds as a padded tensor
            of shape `(minibatch, num_points, dim)`.
        **num_points_per_cloud**: Number of points per cloud
            of shape `(minibatch,)`.
        **neighborhood_size**: Number of nearest neighbors for each point
            used to estimate the covariance matrices.

    Returns:
        **covariances**: A batch of per-point covariance matrices
            of shape `(minibatch, num_points, dim, dim)`.
        **k_nearest_neighbors**: A batch of `neighborhood_size` nearest
            neighbors for each of the point cloud points
            of shape `(minibatch, num_points, neighborhood_size, dim)`.
    """
    # Query each cloud against itself to fetch the neighbor coordinates.
    knn_result = knn_points(
        points_padded,
        points_padded,
        lengths1=num_points_per_cloud,
        lengths2=num_points_per_cloud,
        K=neighborhood_size,
        return_nn=True,
    )
    neighbors = knn_result.knn

    # Center every neighborhood on its own mean (axis 2 indexes neighbors).
    centered = neighbors - neighbors.mean(dim=2, keepdim=True)

    # Outer product of each centered neighbor with itself; for the 4-D
    # neighbor tensor this yields one dim x dim matrix per neighbor.
    outer_products = centered.unsqueeze(-1) * centered.unsqueeze(-2)

    # Averaging the outer products over the neighborhood gives the covariance.
    covariances = outer_products.mean(dim=2)

    return covariances, neighbors
|