# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch

from .knn import knn_points


if TYPE_CHECKING:
    from pytorch3d.structures import Pointclouds


def wmean(
    x: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    dim: Union[int, Tuple[int]] = -2,
    keepdim: bool = True,
    eps: float = 1e-9,
) -> torch.Tensor:
    """
    Finds the mean of the input tensor across the specified dimension.
    If the `weight` argument is provided, computes a weighted mean.
    Args:
        x: tensor of shape `(*, D)`, where D is assumed to be spatial;
        weight: if given, non-negative tensor of shape `(*,)`. It must be
            broadcastable to `x.shape[:-1]`. Note that the weights for
            the last (spatial) dimension are assumed to be the same;
        dim: dimension(s) in `x` to average over;
        keepdim: tells whether to keep the resulting singleton dimension.
        eps: minimum clamping value in the denominator.
    Returns:
        the mean tensor:
        * if `weight` is None => `mean(x, dim)`,
        * otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
    """
    args = {"dim": dim, "keepdim": keepdim}

    if weight is None:
        return x.mean(**args)

    if any(
        xd != wd and xd != 1 and wd != 1
        for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
    ):
        raise ValueError("wmean: weights are not compatible with the tensor")

    return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
        eps
    )
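
# A minimal usage sketch (hypothetical helper, not part of this module's API):
# computes a weighted centroid of a batch of padded 3D point clouds with
# `wmean`. The shapes and weights below are illustrative assumptions.
def _example_wmean() -> torch.Tensor:
    pts = torch.randn(4, 100, 3)  # (minibatch, num_points, dim)
    w = torch.rand(4, 100)  # non-negative per-point weights
    # Averages over dim=-2 (the point dimension) and keeps it as a singleton,
    # yielding a weighted centroid of shape (4, 1, 3).
    return wmean(pts, weight=w)
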
def eyes(
    dim: int,
    N: int,
    device: Optional[torch.device] = None,
    # pyre-fixme[9]: dtype has type `dtype`; used as `Type[torch.float32]`.
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Generates a batch of `N` identity matrices of shape `(N, dim, dim)`.
    Args:
        **dim**: The dimensionality of the identity matrices.
        **N**: The number of identity matrices.
        **device**: The device to be used for allocating the matrices.
        **dtype**: The datatype of the matrices.
    Returns:
        **identities**: A batch of identity matrices of shape `(N, dim, dim)`.
    """
    identities = torch.eye(dim, device=device, dtype=dtype)
    return identities[None].repeat(N, 1, 1)
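
# A minimal usage sketch (hypothetical helper, not part of this module's API):
# allocates a batch of identity matrices with `eyes`, e.g. as initial rotation
# estimates. The batch size and dtype below are illustrative assumptions.
def _example_eyes() -> torch.Tensor:
    R = eyes(3, 8, dtype=torch.float64)  # -> (8, 3, 3)
    assert R.shape == (8, 3, 3)
    return R
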
def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
    """
    If `type(pcl)==Pointclouds`, converts a `pcl` object to a
    padded representation and returns it together with the number of points
    per batch. Otherwise, returns the input itself with the number of points
    set to the size of the second dimension of `pcl`.
    """
    if is_pointclouds(pcl):
        X = pcl.points_padded()  # type: ignore
        num_points = pcl.num_points_per_cloud()  # type: ignore
    elif torch.is_tensor(pcl):
        X = pcl
        num_points = X.shape[1] * torch.ones(  # type: ignore
            X.shape[0], device=X.device, dtype=torch.int64
        )
    else:
        raise ValueError(
            "The inputs X, Y should be either Pointclouds objects or tensors."
        )
    return X, num_points
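
# A minimal usage sketch (hypothetical helper, not part of this module's API):
# with a plain padded tensor, `convert_pointclouds_to_tensor` returns the input
# unchanged and assumes every cloud holds `num_points` valid points. The input
# shape below is an illustrative assumption.
def _example_convert_pointclouds_to_tensor():
    pcl = torch.randn(2, 50, 3)  # (minibatch, num_points, dim)
    X, num_points = convert_pointclouds_to_tensor(pcl)
    assert X is pcl
    assert num_points.tolist() == [50, 50]
    return X, num_points
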
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
    """Checks whether the input `pcl` is an instance of `Pointclouds`
    by checking for the existence of the `points_padded` and
    `num_points_per_cloud` methods.
    """
    return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")
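
# A minimal usage sketch (hypothetical helper, not part of this module's API):
# `is_pointclouds` duck-types, so any object exposing both methods is treated
# as a Pointclouds instance. The fake class below is an illustrative assumption.
def _example_is_pointclouds() -> bool:
    class _FakeCloud:
        def points_padded(self):
            return torch.zeros(1, 10, 3)

        def num_points_per_cloud(self):
            return torch.full((1,), 10, dtype=torch.int64)

    return is_pointclouds(_FakeCloud()) and not is_pointclouds(torch.zeros(1, 10, 3))
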
def get_point_covariances(
    points_padded: torch.Tensor,
    num_points_per_cloud: torch.Tensor,
    neighborhood_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the per-point covariance matrices of the 3D locations of the
    K-nearest neighbors of each point.
    Args:
        **points_padded**: Input point clouds as a padded tensor
            of shape `(minibatch, num_points, dim)`.
        **num_points_per_cloud**: Number of points per cloud
            of shape `(minibatch,)`.
        **neighborhood_size**: Number of nearest neighbors for each point
            used to estimate the covariance matrices.
    Returns:
        **covariances**: A batch of per-point covariance matrices
            of shape `(minibatch, num_points, dim, dim)`.
        **k_nearest_neighbors**: A batch of `neighborhood_size` nearest
            neighbors for each of the point cloud points
            of shape `(minibatch, num_points, neighborhood_size, dim)`.
    """
    # get the K nearest neighbors of each point in the point cloud
    k_nearest_neighbors = knn_points(
        points_padded,
        points_padded,
        lengths1=num_points_per_cloud,
        lengths2=num_points_per_cloud,
        K=neighborhood_size,
        return_nn=True,
    ).knn
    # obtain the mean of each neighborhood
    pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
    # compute the difference between each neighborhood and its mean
    central_diff = k_nearest_neighbors - pt_mean
    # per-neighbor outer products of the centered coordinates
    per_pt_cov = central_diff.unsqueeze(4) * central_diff.unsqueeze(3)
    # average the outer products over the neighborhood to get per-point covariances
    covariances = per_pt_cov.mean(2)
    return covariances, k_nearest_neighbors
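
# A minimal usage sketch (hypothetical helper, not part of this module's API):
# estimates per-point covariances for two padded clouds with different numbers
# of valid points. The shapes and neighborhood size are illustrative assumptions.
def _example_get_point_covariances() -> Tuple[torch.Tensor, torch.Tensor]:
    pts = torch.randn(2, 128, 3)  # (minibatch, num_points, dim)
    lengths = torch.tensor([128, 96], dtype=torch.int64)  # valid points per cloud
    cov, nn = get_point_covariances(pts, lengths, neighborhood_size=16)
    assert cov.shape == (2, 128, 3, 3)
    assert nn.shape == (2, 128, 16, 3)
    return cov, nn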