mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Summary: It seemed that even though the chamfer diff was rebased on top of the knn autograd diff, some of the final updates did not get applied. I'm really surprised that the sandcastle tests did not fail and prevent the diff from landing. Reviewed By: gkioxari Differential Revision: D21066156 fbshipit-source-id: 5216efe95180c1b6082d0bac404fa1920cfb7b02
216 lines
7.9 KiB
Python
216 lines
7.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
from typing import Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from pytorch3d.ops.knn import knn_gather, knn_points
|
|
from pytorch3d.structures.pointclouds import Pointclouds
|
|
|
|
|
|
def _validate_chamfer_reduction_inputs(
|
|
batch_reduction: Union[str, None], point_reduction: str
|
|
):
|
|
"""Check the requested reductions are valid.
|
|
|
|
Args:
|
|
batch_reduction: Reduction operation to apply for the loss across the
|
|
batch, can be one of ["mean", "sum"] or None.
|
|
point_reduction: Reduction operation to apply for the loss across the
|
|
points, can be one of ["mean", "sum"].
|
|
"""
|
|
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
|
|
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
|
|
if point_reduction not in ["mean", "sum"]:
|
|
raise ValueError('point_reduction must be one of ["mean", "sum"]')
|
|
|
|
|
|
def _handle_pointcloud_input(
    points: Union[torch.Tensor, Pointclouds],
    lengths: Union[torch.Tensor, None],
    normals: Union[torch.Tensor, None],
):
    """
    Normalize a point cloud argument into padded tensors.

    If ``points`` is a ``Pointclouds`` object, its padded points tensor, the
    number of points per cloud and its padded normals are returned. The
    ``lengths`` and ``normals`` arguments are ignored in that case.

    If ``points`` is a tensor, it is returned unchanged together with
    ``lengths`` and ``normals``. When ``lengths`` is None, every cloud is
    treated as full: lengths is filled with ``points.shape[1]``.

    Args:
        points: A ``Pointclouds`` object, or a tensor of shape (N, P, D).
        lengths: Optional tensor of shape (N,) giving the number of valid
            points in each cloud. Only used when ``points`` is a tensor.
        normals: Optional rank-3 tensor of per-point normals. Only used when
            ``points`` is a tensor.

    Returns:
        Tuple ``(X, lengths, normals)``:
            - ``X``: padded points tensor of shape (N, P, D).
            - ``lengths``: tensor of shape (N,).
            - ``normals``: a tensor, or None when no normals are available.

    Raises:
        ValueError: If ``points`` is neither a ``Pointclouds`` object nor a
            tensor, if a points tensor is not rank 3, if ``lengths`` is not of
            shape (N,), or if ``normals`` is not rank 3.
    """
    if isinstance(points, Pointclouds):
        X = points.points_padded()
        lengths = points.num_points_per_cloud()
        normals = points.normals_padded()  # either a tensor or None
    elif torch.is_tensor(points):
        if points.ndim != 3:
            raise ValueError("Expected points to be of shape (N, P, D)")
        X = points
        if lengths is not None and (
            lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
        ):
            raise ValueError("Expected lengths to be of shape (N,)")
        if lengths is None:
            # No lengths given: assume no padding, i.e. every cloud has P points.
            lengths = torch.full(
                (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
            )
        if normals is not None and normals.ndim != 3:
            raise ValueError("Expected normals to be of shape (N, P, 3)")
    else:
        raise ValueError(
            "The input pointclouds should be either "
            + "Pointclouds objects or torch.Tensor of shape "
            + "(minibatch, num_points, 3)."
        )
    return X, lengths, normals
|
|
|
|
|
|
def chamfer_distance(
    x,
    y,
    x_lengths=None,
    y_lengths=None,
    x_normals=None,
    y_normals=None,
    weights=None,
    batch_reduction: Union[str, None] = "mean",
    point_reduction: str = "mean",
):
    """
    Chamfer distance between two pointclouds x and y.

    Args:
        x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
            a batch of point clouds with at most P1 points in each batch element,
            batch size N and feature dimension D.
        y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
            a batch of point clouds with at most P2 points in each batch element,
            batch size N and feature dimension D.
        x_lengths: Optional LongTensor of shape (N,) giving the number of points
            in each cloud in x.
        y_lengths: Optional LongTensor of shape (N,) giving the number of points
            in each cloud in y.
        x_normals: Optional FloatTensor of shape (N, P1, D).
        y_normals: Optional FloatTensor of shape (N, P2, D).
        weights: Optional FloatTensor of shape (N,) giving non-negative weights
            for batch elements for the reduction operation.
        batch_reduction: Reduction operation to apply for the loss across the
            batch, can be one of ["mean", "sum"] or None.
        point_reduction: Reduction operation to apply for the loss across the
            points, can be one of ["mean", "sum"].

    Returns:
        2-element tuple containing

        - **loss**: Tensor giving the reduced distance between the pointclouds
          in x and the pointclouds in y. It is a scalar if batch_reduction is
          not None, and shape (N,) otherwise.
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
          between pointclouds in x and pointclouds in y, with the same shape
          as ``loss``. It is None unless normals are available for both x and
          y, either passed explicitly or taken from Pointclouds inputs.

    Raises:
        ValueError: If a reduction mode is invalid, if y's batch size or
            feature dimension differs from x's, or if weights has the wrong
            length or contains negative values. It is also raised for malformed
            point/length/normal inputs.
    """
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

    return_normals = x_normals is not None and y_normals is not None

    N, P1, D = x.shape
    P2 = y.shape[1]

    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")

    # True for padded slots beyond each cloud's actual number of points.
    x_mask = torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]  # (N, P1)
    y_mask = torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]  # (N, P2)
    # Only touch the masks if some cloud is actually shorter than the padding.
    is_x_heterogeneous = (x_lengths != P1).any()
    is_y_heterogeneous = (y_lengths != P2).any()

    if weights is not None:
        if weights.size(0) != N:
            raise ValueError("weights must be of shape (N,).")
        if not (weights >= 0).all():
            raise ValueError("weights cannot be negative.")
        if weights.sum() == 0.0:
            # Every batch element has zero weight, so the loss is exactly zero.
            # Build it from x and weights so it stays attached to the autograd
            # graph. Use a (N,) view: a (N, 1) view would broadcast against the
            # (N,) per-cloud sums and yield an (N, N) result.
            zero = x.sum((1, 2)) * weights.view(N) * 0.0  # (N,)
            if batch_reduction is not None:
                zero = zero.sum()
            return zero, (zero.clone() if return_normals else None)

    # K=1 nearest neighbour of each x point in y, and of each y point in x.
    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)

    cham_x = x_nn.dists[..., 0]  # (N, P1)
    cham_y = y_nn.dists[..., 0]  # (N, P2)

    # Out-of-place ops so the knn outputs (and their autograd history) are
    # never modified in place.
    if is_x_heterogeneous:
        cham_x = cham_x.masked_fill(x_mask, 0.0)
    if is_y_heterogeneous:
        cham_y = cham_y.masked_fill(y_mask, 0.0)

    if weights is not None:
        cham_x = cham_x * weights.view(N, 1)
        cham_y = cham_y * weights.view(N, 1)

    cham_norm_x = None
    cham_norm_y = None
    if return_normals:
        # Gather the normals using the indices and keep only value for k=0
        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
        y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]

        # abs() makes the loss insensitive to the sign (orientation) of normals.
        cham_norm_x = 1 - torch.abs(
            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
        )
        cham_norm_y = 1 - torch.abs(
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
        )

        if is_x_heterogeneous:
            cham_norm_x = cham_norm_x.masked_fill(x_mask, 0.0)
        if is_y_heterogeneous:
            cham_norm_y = cham_norm_y.masked_fill(y_mask, 0.0)

        if weights is not None:
            cham_norm_x = cham_norm_x * weights.view(N, 1)
            cham_norm_y = cham_norm_y * weights.view(N, 1)

    # Apply point reduction
    cham_x = cham_x.sum(1)  # (N,)
    cham_y = cham_y.sum(1)  # (N,)
    if return_normals:
        cham_norm_x = cham_norm_x.sum(1)  # (N,)
        cham_norm_y = cham_norm_y.sum(1)  # (N,)
    if point_reduction == "mean":
        # Clamp so an empty cloud (length 0) yields 0 rather than 0/0 = NaN.
        x_div = x_lengths.clamp(min=1)
        y_div = y_lengths.clamp(min=1)
        cham_x = cham_x / x_div
        cham_y = cham_y / y_div
        if return_normals:
            cham_norm_x = cham_norm_x / x_div
            cham_norm_y = cham_norm_y / y_div

    # Apply batch reduction
    if batch_reduction is not None:
        # batch_reduction == "sum"
        cham_x = cham_x.sum()
        cham_y = cham_y.sum()
        if return_normals:
            cham_norm_x = cham_norm_x.sum()
            cham_norm_y = cham_norm_y.sum()
        if batch_reduction == "mean":
            # With weights, this is a weighted mean over the batch.
            div = weights.sum() if weights is not None else N
            cham_x = cham_x / div
            cham_y = cham_y / div
            if return_normals:
                cham_norm_x = cham_norm_x / div
                cham_norm_y = cham_norm_y / div

    cham_dist = cham_x + cham_y
    cham_normals = cham_norm_x + cham_norm_y if return_normals else None

    return cham_dist, cham_normals
|