mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
Chamfer for Pointclouds object
Summary: Allow Pointclouds objects and heterogenous data to be provided for Chamfer loss. Remove "none" as an option for point_reduction because it doesn't make sense and in the current implementation is effectively the same as "sum". Possible improvement: create specialised operations for sum and cosine_similarity of padded tensors, to avoid having to create masks. sum would be useful elsewhere. Reviewed By: gkioxari Differential Revision: D20816301 fbshipit-source-id: 0f32073210225d157c029d80de450eecdb64f4d2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
677b0bd5ae
commit
790eb8c402
@@ -1,54 +1,101 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
|
||||
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
|
||||
def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
|
||||
def _validate_chamfer_reduction_inputs(
|
||||
batch_reduction: Union[str, None], point_reduction: str
|
||||
):
|
||||
"""Check the requested reductions are valid.
|
||||
|
||||
Args:
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["none", "mean", "sum"].
|
||||
batch, can be one of ["mean", "sum"] or None.
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["none", "mean", "sum"].
|
||||
points, can be one of ["mean", "sum"].
|
||||
"""
|
||||
if batch_reduction not in ["none", "mean", "sum"]:
|
||||
raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
|
||||
if point_reduction not in ["none", "mean", "sum"]:
|
||||
raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
|
||||
if batch_reduction == "none" and point_reduction == "none":
|
||||
raise ValueError('batch_reduction and point_reduction cannot both be "none".')
|
||||
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
|
||||
if point_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('point_reduction must be one of ["mean", "sum"]')
|
||||
|
||||
|
||||
def _handle_pointcloud_input(
|
||||
points: Union[torch.Tensor, Pointclouds],
|
||||
lengths: Union[torch.Tensor, None],
|
||||
normals: Union[torch.Tensor, None],
|
||||
):
|
||||
"""
|
||||
If points is an instance of Pointclouds, retrieve the padded points tensor
|
||||
along with the number of points per batch and the padded normals.
|
||||
Otherwise, return the input points (and normals) with the number of points per cloud
|
||||
set to the size of the second dimension of `points`.
|
||||
"""
|
||||
if isinstance(points, Pointclouds):
|
||||
X = points.points_padded()
|
||||
lengths = points.num_points_per_cloud()
|
||||
normals = points.normals_padded() # either a tensor or None
|
||||
elif torch.is_tensor(points):
|
||||
if points.ndim != 3:
|
||||
raise ValueError("Expected points to be of shape (N, P, D)")
|
||||
X = points
|
||||
if lengths is not None and (
|
||||
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
|
||||
):
|
||||
raise ValueError("Expected lengths to be of shape (N,)")
|
||||
if lengths is None:
|
||||
lengths = torch.full(
|
||||
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
|
||||
)
|
||||
if normals is not None and normals.ndim != 3:
|
||||
raise ValueError("Expected normals to be of shape (N, P, 3")
|
||||
else:
|
||||
raise ValueError(
|
||||
"The input pointclouds should be either "
|
||||
+ "Pointclouds objects or torch.Tensor of shape "
|
||||
+ "(minibatch, num_points, 3)."
|
||||
)
|
||||
return X, lengths, normals
|
||||
|
||||
|
||||
def chamfer_distance(
|
||||
x,
|
||||
y,
|
||||
x_lengths=None,
|
||||
y_lengths=None,
|
||||
x_normals=None,
|
||||
y_normals=None,
|
||||
weights=None,
|
||||
batch_reduction: str = "mean",
|
||||
batch_reduction: Union[str, None] = "mean",
|
||||
point_reduction: str = "mean",
|
||||
):
|
||||
"""
|
||||
Chamfer distance between two pointclouds x and y.
|
||||
|
||||
Args:
|
||||
x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
|
||||
with P1 points in each batch element, batch size N and feature
|
||||
dimension D.
|
||||
y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
|
||||
with P2 points in each batch element, batch size N and feature
|
||||
dimension D.
|
||||
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
|
||||
a batch of point clouds with at most P1 points in each batch element,
|
||||
batch size N and feature dimension D.
|
||||
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
|
||||
a batch of point clouds with at most P2 points in each batch element,
|
||||
batch size N and feature dimension D.
|
||||
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
|
||||
cloud in x.
|
||||
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
|
||||
cloud in x.
|
||||
x_normals: Optional FloatTensor of shape (N, P1, D).
|
||||
y_normals: Optional FloatTensor of shape (N, P2, D).
|
||||
weights: Optional FloatTensor of shape (N,) giving weights for
|
||||
batch elements for reduction operation.
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["none", "mean", "sum"].
|
||||
batch, can be one of ["mean", "sum"] or None.
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["none", "mean", "sum"].
|
||||
points, can be one of ["mean", "sum"].
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
@@ -61,16 +108,31 @@ def chamfer_distance(
|
||||
"""
|
||||
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
|
||||
|
||||
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
|
||||
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
|
||||
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
|
||||
N, P1, D = x.shape
|
||||
P2 = y.shape[1]
|
||||
|
||||
# Check if inputs are heterogeneous and create a lengths mask.
|
||||
is_x_heterogeneous = ~(x_lengths == P1).all()
|
||||
is_y_heterogeneous = ~(y_lengths == P2).all()
|
||||
x_mask = (
|
||||
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
|
||||
) # shape [N, P1]
|
||||
y_mask = (
|
||||
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
|
||||
) # shape [N, P2]
|
||||
|
||||
if y.shape[0] != N or y.shape[2] != D:
|
||||
raise ValueError("y does not have the correct shape.")
|
||||
if weights is not None:
|
||||
if weights.size(0) != N:
|
||||
raise ValueError("weights must be of shape (N,).")
|
||||
if not (weights >= 0).all():
|
||||
raise ValueError("weights can not be nonnegative.")
|
||||
raise ValueError("weights cannot be negative.")
|
||||
if weights.sum() == 0.0:
|
||||
weights = weights.view(N, 1)
|
||||
if batch_reduction in ["mean", "sum"]:
|
||||
@@ -80,46 +142,60 @@ def chamfer_distance(
|
||||
)
|
||||
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
|
||||
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
cham_norm_x = x.new_zeros(())
|
||||
cham_norm_y = x.new_zeros(())
|
||||
|
||||
x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
|
||||
y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
|
||||
x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
|
||||
y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
|
||||
|
||||
cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0 # (N, P1)
|
||||
cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0 # (N, P2)
|
||||
cham_x = x_dists[..., 0] # (N, P1)
|
||||
cham_y = y_dists[..., 0] # (N, P2)
|
||||
|
||||
if is_x_heterogeneous:
|
||||
cham_x[x_mask] = 0.0
|
||||
if is_y_heterogeneous:
|
||||
cham_y[y_mask] = 0.0
|
||||
|
||||
if weights is not None:
|
||||
cham_x *= weights.view(N, 1)
|
||||
cham_y *= weights.view(N, 1)
|
||||
|
||||
if return_normals:
|
||||
# Gather the normals using the indices and keep only value for k=0
|
||||
x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :]
|
||||
y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :]
|
||||
|
||||
cham_norm_x = 1 - torch.abs(
|
||||
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
|
||||
)
|
||||
cham_norm_y = 1 - torch.abs(
|
||||
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
|
||||
)
|
||||
|
||||
if is_x_heterogeneous:
|
||||
cham_norm_x[x_mask] = 0.0
|
||||
if is_y_heterogeneous:
|
||||
cham_norm_y[y_mask] = 0.0
|
||||
|
||||
if weights is not None:
|
||||
cham_norm_x *= weights.view(N, 1)
|
||||
cham_norm_y *= weights.view(N, 1)
|
||||
|
||||
if point_reduction != "none":
|
||||
# If not 'none' then either 'sum' or 'mean'.
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
cham_y = cham_y.sum(1) # (N,)
|
||||
# Apply point reduction
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
cham_y = cham_y.sum(1) # (N,)
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
cham_x /= x_lengths
|
||||
cham_y /= y_lengths
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
cham_x /= P1
|
||||
cham_y /= P2
|
||||
if return_normals:
|
||||
cham_norm_x /= P1
|
||||
cham_norm_y /= P2
|
||||
cham_norm_x /= x_lengths
|
||||
cham_norm_y /= y_lengths
|
||||
|
||||
if batch_reduction != "none":
|
||||
if batch_reduction is not None:
|
||||
# batch_reduction == "sum"
|
||||
cham_x = cham_x.sum()
|
||||
cham_y = cham_y.sum()
|
||||
if return_normals:
|
||||
|
||||
Reference in New Issue
Block a user