Chamfer for Pointclouds object

Summary:
Allow Pointclouds objects and heterogeneous data to be provided for Chamfer loss. Remove "none" as an option for point_reduction because it doesn't make sense and, in the current implementation, is effectively the same as "sum".

Possible improvement: create specialised operations for sum and cosine_similarity of padded tensors, to avoid having to create masks. sum would be useful elsewhere.
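For reference, the updated call signature accepts either Pointclouds objects or padded tensors plus per-cloud lengths. A minimal sketch (the tests in this diff run on cuda:0, so the example keeps all inputs on the GPU):

import torch
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Pointclouds

device = torch.device("cuda:0")
# Heterogeneous batch: two clouds with 3 and 5 points.
a = Pointclouds(points=[torch.rand(3, 3, device=device), torch.rand(5, 3, device=device)])
b = Pointclouds(points=[torch.rand(4, 3, device=device), torch.rand(2, 3, device=device)])
loss, _ = chamfer_distance(a, b)  # Pointclouds accepted directly

# Equivalent padded-tensor form with explicit per-cloud lengths.
loss2, _ = chamfer_distance(
    a.points_padded(),
    b.points_padded(),
    x_lengths=a.num_points_per_cloud(),
    y_lengths=b.num_points_per_cloud(),
)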

Reviewed By: gkioxari

Differential Revision: D20816301

fbshipit-source-id: 0f32073210225d157c029d80de450eecdb64f4d2
Nikhila Ravi 2020-04-15 14:07:47 -07:00 committed by Facebook GitHub Bot
parent 677b0bd5ae
commit 790eb8c402
3 changed files with 685 additions and 173 deletions


@@ -1,54 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Union
import torch
import torch.nn.functional as F
from pytorch3d.ops.knn import knn_gather, knn_points
from pytorch3d.structures.pointclouds import Pointclouds
def _validate_chamfer_reduction_inputs(
    batch_reduction: Union[str, None], point_reduction: str
):
"""Check the requested reductions are valid.
Args:
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["none", "mean", "sum"].
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["none", "mean", "sum"].
points, can be one of ["mean", "sum"].
"""
    if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
        raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
    if point_reduction not in ["mean", "sum"]:
        raise ValueError('point_reduction must be one of ["mean", "sum"]')
def _handle_pointcloud_input(
points: Union[torch.Tensor, Pointclouds],
lengths: Union[torch.Tensor, None],
normals: Union[torch.Tensor, None],
):
"""
If points is an instance of Pointclouds, retrieve the padded points tensor
along with the number of points per batch and the padded normals.
Otherwise, return the input points (and normals) with the number of points per cloud
set to the size of the second dimension of `points`.
"""
if isinstance(points, Pointclouds):
X = points.points_padded()
lengths = points.num_points_per_cloud()
normals = points.normals_padded() # either a tensor or None
elif torch.is_tensor(points):
if points.ndim != 3:
raise ValueError("Expected points to be of shape (N, P, D)")
X = points
if lengths is not None and (
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
):
raise ValueError("Expected lengths to be of shape (N,)")
if lengths is None:
lengths = torch.full(
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
)
if normals is not None and normals.ndim != 3:
raise ValueError("Expected normals to be of shape (N, P, 3")
else:
raise ValueError(
"The input pointclouds should be either "
+ "Pointclouds objects or torch.Tensor of shape "
+ "(minibatch, num_points, 3)."
)
return X, lengths, normals
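To make the helper's contract concrete, a small illustrative sketch of what it returns for a heterogeneous Pointclouds input (reusing the imports at the top of this file; _handle_pointcloud_input is module-internal):

device = torch.device("cuda:0")
clouds = Pointclouds(
    points=[torch.rand(3, 3, device=device), torch.rand(5, 3, device=device)]
)
X, lengths, normals = _handle_pointcloud_input(clouds, None, None)
# X:       (2, 5, 3) padded points tensor (the larger cloud sets the pad size)
# lengths: tensor([3, 5]), the true number of points per cloud
# normals: None, because the clouds were built without normals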
def chamfer_distance(
x,
y,
x_lengths=None,
y_lengths=None,
x_normals=None,
y_normals=None,
weights=None,
    batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
):
"""
Chamfer distance between two pointclouds x and y.
Args:
        x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
            a batch of point clouds with at most P1 points in each batch element,
            batch size N and feature dimension D.
        y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
            a batch of point clouds with at most P2 points in each batch element,
            batch size N and feature dimension D.
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
cloud in x.
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
            cloud in y.
x_normals: Optional FloatTensor of shape (N, P1, D).
y_normals: Optional FloatTensor of shape (N, P2, D).
weights: Optional FloatTensor of shape (N,) giving weights for
batch elements for reduction operation.
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["none", "mean", "sum"].
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["none", "mean", "sum"].
points, can be one of ["mean", "sum"].
Returns:
2-element tuple containing
@@ -61,16 +108,31 @@ def chamfer_distance(
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
return_normals = x_normals is not None and y_normals is not None
N, P1, D = x.shape
P2 = y.shape[1]
# Check if inputs are heterogeneous and create a lengths mask.
is_x_heterogeneous = ~(x_lengths == P1).all()
is_y_heterogeneous = ~(y_lengths == P2).all()
x_mask = (
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
) # shape [N, P1]
y_mask = (
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
) # shape [N, P2]
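    # Example: with P1 = 4 and x_lengths = [2, 4], x_mask is
    # [[False, False, True, True],
    #  [False, False, False, False]]
    # i.e. True marks padded slots; their contributions are zeroed out below.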
if y.shape[0] != N or y.shape[2] != D:
raise ValueError("y does not have the correct shape.")
if weights is not None:
if weights.size(0) != N:
raise ValueError("weights must be of shape (N,).")
if not (weights >= 0).all():
raise ValueError("weights can not be nonnegative.")
raise ValueError("weights cannot be negative.")
if weights.sum() == 0.0:
weights = weights.view(N, 1)
if batch_reduction in ["mean", "sum"]:
@@ -80,46 +142,60 @@ def chamfer_distance(
)
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
    cham_norm_x = x.new_zeros(())
    cham_norm_y = x.new_zeros(())

    x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
    y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)

    cham_x = x_dists[..., 0]  # (N, P1)
    cham_y = y_dists[..., 0]  # (N, P2)
if is_x_heterogeneous:
cham_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_y[y_mask] = 0.0
if weights is not None:
cham_x *= weights.view(N, 1)
cham_y *= weights.view(N, 1)
if return_normals:
        # Gather the normals using the indices and keep only the value for k=0.
x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :]
y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :]
cham_norm_x = 1 - torch.abs(
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
)
cham_norm_y = 1 - torch.abs(
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
)
if is_x_heterogeneous:
cham_norm_x[x_mask] = 0.0
if is_y_heterogeneous:
cham_norm_y[y_mask] = 0.0
if weights is not None:
cham_norm_x *= weights.view(N, 1)
cham_norm_y *= weights.view(N, 1)
if point_reduction != "none":
# If not 'none' then either 'sum' or 'mean'.
cham_x = cham_x.sum(1) # (N,)
cham_y = cham_y.sum(1) # (N,)
# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
cham_y = cham_y.sum(1) # (N,)
if return_normals:
cham_norm_x = cham_norm_x.sum(1) # (N,)
cham_norm_y = cham_norm_y.sum(1) # (N,)
if point_reduction == "mean":
cham_x /= x_lengths
cham_y /= y_lengths
if return_normals:
cham_norm_x = cham_norm_x.sum(1) # (N,)
cham_norm_y = cham_norm_y.sum(1) # (N,)
if point_reduction == "mean":
cham_x /= P1
cham_y /= P2
if return_normals:
cham_norm_x /= P1
cham_norm_y /= P2
cham_norm_x /= x_lengths
cham_norm_y /= y_lengths
if batch_reduction != "none":
if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
cham_y = cham_y.sum()
if return_normals:


@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
@@ -20,8 +21,23 @@ def bm_chamfer() -> None:
)
if torch.cuda.is_available():
        kwargs_list = []
batch_size = [1, 32]
P1 = [32, 1000, 10000]
P2 = [64, 3000, 30000]
return_normals = [True, False]
homogeneous = [True, False]
test_cases = product(batch_size, P1, P2, return_normals, homogeneous)
for case in test_cases:
b, p1, p2, n, h = case
kwargs_list.append(
{
"batch_size": b,
"P1": p1,
"P2": p2,
"return_normals": n,
"homogeneous": h,
}
)
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
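The grid above expands to 2 x 3 x 3 x 2 x 2 = 72 benchmark configurations; a quick sanity check of the size:

from itertools import product

cases = list(
    product([1, 32], [32, 1000, 10000], [64, 3000, 30000], [True, False], [True, False])
)
assert len(cases) == 72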


@@ -1,111 +1,442 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures.utils import list_to_padded
# Output of init_pointclouds
points_normals = namedtuple(
"points_normals", "p1_lengths p2_lengths cloud1 cloud2 p1 p2 n1 n2 weights"
)
class TestChamfer(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(1)

    @staticmethod
    def init_pointclouds(N, P1, P2, device, requires_grad: bool = True):
"""
        Create two Pointclouds objects and associated padded points/normals tensors by
starting from lists. The clouds and tensors have the same data. The
leaf nodes for the clouds are a list of tensors. The padded tensor can be
used directly as a leaf node.
"""
p1_lengths = torch.randint(P1, size=(N,), dtype=torch.int64, device=device)
p2_lengths = torch.randint(P2, size=(N,), dtype=torch.int64, device=device)
weights = torch.rand((N,), dtype=torch.float32, device=device)
# list of points and normals tensors
p1_list = []
p2_list = []
n1_list = []
n2_list = []
for i in range(N):
l1 = p1_lengths[i]
l2 = p2_lengths[i]
p1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
p2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
n1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
n2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list]
n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list]
# Clone the lists and initialize padded tensors.
p1 = list_to_padded([p.clone() for p in p1_list])
p2 = list_to_padded([p.clone() for p in p2_list])
n1 = list_to_padded([p.clone() for p in n1_list])
n2 = list_to_padded([p.clone() for p in n2_list])
# Set requires_grad for all tensors in the lists and
# padded tensors.
if requires_grad:
for p in p2_list + p1_list + n1_list + n2_list + [p1, p2, n1, n2]:
p.requires_grad = True
# Create pointclouds objects
cloud1 = Pointclouds(points=p1_list, normals=n1_list)
cloud2 = Pointclouds(points=p2_list, normals=n2_list)
# Return pointclouds objects and padded tensors
return points_normals(
p1_lengths=p1_lengths,
p2_lengths=p2_lengths,
cloud1=cloud1,
cloud2=cloud2,
p1=p1,
p2=p2,
n1=n1,
n2=n2,
weights=weights,
)
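A sketch of how this fixture is consumed in the tests below (keyword arguments are used here purely for readability): the same data can be fed to chamfer_distance either as Pointclouds objects or as padded tensors plus lengths, and the two paths should agree.

pn = TestChamfer.init_pointclouds(N=4, P1=16, P2=24, device="cuda:0")
loss_cloud, _ = chamfer_distance(pn.cloud1, pn.cloud2)
loss_tensor, _ = chamfer_distance(
    pn.p1,
    pn.p2,
    x_lengths=pn.p1_lengths,
    y_lengths=pn.p2_lengths,
    x_normals=pn.n1,
    y_normals=pn.n2,
)
# loss_cloud and loss_tensor should match, which is what the tests assert.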
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
        p1 and p2 are assumed to be Pointclouds objects with points and optionally normals.
        This function supports heterogeneous pointclouds in a batch.
Returns lists of the unreduced loss and loss_normals.
"""
        x = p1.points_padded()
        y = p2.points_padded()
        N, P1, D = x.shape
        P2 = y.size(1)
x_lengths = p1.num_points_per_cloud()
y_lengths = p2.num_points_per_cloud()
x_normals = p1.normals_padded()
y_normals = p2.normals_padded()
device = torch.device("cuda:0")
        return_normals = x_normals is not None and y_normals is not None
# Initialize all distances to + inf
dist = torch.ones((N, P1, P2), dtype=torch.float32, device=device) * np.inf
x_mask = (
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
) # shape [N, P1]
y_mask = (
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
) # shape [N, P2]
is_x_heterogeneous = ~(x_lengths == P1).all()
is_y_heterogeneous = ~(y_lengths == P2).all()
# Only calculate the distances for the points which are not masked
for n in range(N):
for i1 in range(x_lengths[n]):
for i2 in range(y_lengths[n]):
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
x_dist = torch.min(dist, dim=2)[0] # (N, P1)
y_dist = torch.min(dist, dim=1)[0] # (N, P2)
if is_x_heterogeneous:
x_dist[x_mask] = 0.0
if is_y_heterogeneous:
y_dist[y_mask] = 0.0
loss = [x_dist, y_dist]
lnorm = [x.new_zeros(()), x.new_zeros(())]
if return_normals:
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
)
if is_x_heterogeneous:
lnorm1[x_mask] = 0.0
if is_y_heterogeneous:
lnorm2[y_mask] = 0.0
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
return loss, lnorm
@staticmethod
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
Returns lists of the unreduced loss and loss_normals. This naive
        version only supports homogeneous pointclouds in a batch.
"""
N, P1, D = x.shape
P2 = y.size(1)
device = torch.device("cuda:0")
return_normals = x_normals is not None and y_normals is not None
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
for n in range(N):
for i1 in range(P1):
for i2 in range(P2):
                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
loss = [
torch.min(dist, dim=2)[0], # (N, P1)
torch.min(dist, dim=1)[0], # (N, P2)
]
        lnorm = [x.new_zeros(()), x.new_zeros(())]
if return_normals:
            x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
            y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
                    x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
)
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
                    y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
)
)
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
return loss, lnorm
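For homogeneous batches, the triple loop above just fills a pairwise squared-distance matrix; a compact cross-check of the same quantity (a sketch, not part of the test suite):

import torch

x = torch.rand(2, 10, 3)
y = torch.rand(2, 18, 3)
d2 = torch.cdist(x, y) ** 2    # (2, 10, 18) pairwise squared distances
cham_x = d2.min(dim=2).values  # (2, 10): nearest y point for every x point
cham_y = d2.min(dim=1).values  # (2, 18): nearest x point for every y point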
    def test_chamfer_point_batch_reduction_mean(self):
        """
        Compare output of vectorized chamfer loss with naive implementation
        for the default settings (point_reduction = "mean" and batch_reduction = "mean")
        and no normals.
        This test only uses homogeneous pointclouds.
        """
        N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
P2 = p2.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None)
        # Check gradients
        self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)

    def test_chamfer_vs_naive_pointcloud(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction in ["mean", "sum", "none"] and
batch_reduction = "none".
Test the default settings for chamfer_distance
(point reduction = "mean" and batch_reduction="mean") but with heterogeneous
pointclouds as input. Compare with the naive implementation of chamfer
which supports heterogeneous pointcloud objects.
"""
        N, max_P1, max_P2 = 3, 70, 70
device = "cuda:0"
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
)
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2
)
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss_mean = (
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
)
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss[1] *= weights.view(N, 1)
pred_norm_loss_mean = (
pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
)
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
reductions = [
("sum", "sum"),
("mean", "sum"),
("sum", "mean"),
("mean", "mean"),
("sum", None),
("mean", None),
]
for (point_reduction, batch_reduction) in reductions:
# Reinitialize all the tensors so that the
# backward pass can be computed.
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
# Chamfer with pointclouds as input.
cham_cloud, norm_cloud = chamfer_distance(
points_normals.cloud1,
points_normals.cloud2,
point_reduction=point_reduction,
batch_reduction=batch_reduction,
)
# Chamfer with tensors as input.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
point_reduction=point_reduction,
batch_reduction=batch_reduction,
)
self.assertClose(cham_cloud, cham_tensor)
self.assertClose(norm_cloud, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
cham_cloud,
norm_cloud,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
points_normals.p1_lengths,
points_normals.p2_lengths,
)
def test_chamfer_pointcloud_object_nonormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
reductions = [
("sum", "sum"),
("mean", "sum"),
("sum", "mean"),
("mean", "mean"),
("sum", None),
("mean", None),
]
for (point_reduction, batch_reduction) in reductions:
# Reinitialize all the tensors so that the
# backward pass can be computed.
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
# Chamfer with pointclouds as input.
cham_cloud, _ = chamfer_distance(
points_normals.cloud1,
points_normals.cloud2,
point_reduction=point_reduction,
batch_reduction=batch_reduction,
)
# Chamfer with tensors as input.
cham_tensor, _ = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
point_reduction=point_reduction,
batch_reduction=batch_reduction,
)
self.assertClose(cham_cloud, cham_tensor)
self._check_gradients(
cham_tensor,
None,
cham_cloud,
None,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
lengths1=points_normals.p1_lengths,
lengths2=points_normals.p2_lengths,
)
def test_chamfer_point_reduction_mean(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction = "mean" and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
P2 = p2.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(
            p11,
            p22,
            x_normals=p1_normals,
            y_normals=p2_normals,
weights=weights,
batch_reduction="none",
batch_reduction=None,
point_reduction="mean",
)
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
@@ -118,14 +449,40 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
pred_loss_norm_mean *= weights
self.assertClose(loss_norm, pred_loss_norm_mean)
# point_reduction = "sum".
# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
)
def test_chamfer_point_reduction_sum(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction = "sum" and batch_reduction = None.
"""
N, P1, P2 = 7, 10, 18
device = "cuda:0"
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
loss, loss_norm = chamfer_distance(
            p11,
            p22,
            x_normals=p1_normals,
            y_normals=p2_normals,
weights=weights,
batch_reduction="none",
batch_reduction=None,
point_reduction="sum",
)
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
@@ -136,92 +493,110 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
pred_loss_norm_sum *= weights
self.assertClose(loss_norm, pred_loss_norm_sum)
# Error when point_reduction = "none" and batch_reduction = "none".
with self.assertRaises(ValueError):
chamfer_distance(
p1, p2, weights=weights, batch_reduction="none", point_reduction="none"
# Check gradients
self._check_gradients(
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
)
def _check_gradients(
self,
loss,
loss_norm,
pred_loss,
pred_loss_norm,
x1,
x2,
y1,
y2,
xn1=None, # normals
xn2=None, # normals
yn1=None, # normals
yn2=None, # normals
lengths1=None,
lengths2=None,
):
"""
x1 and x2 can have different types based on the leaf node used in the calculation:
e.g. x1 may be a list of tensors whereas x2 is a padded tensor.
This also applies for the pairs: (y1, y2), (xn1, xn2), (yn1, yn2).
"""
grad_loss = torch.rand(loss.shape, device=loss.device, dtype=loss.dtype)
        # Loss for normals is optional. Initialize to 0.
norm_loss_term = pred_norm_loss_term = 0.0
if loss_norm is not None and pred_loss_norm is not None:
grad_normals = torch.rand(
loss_norm.shape, device=loss.device, dtype=loss.dtype
)
norm_loss_term = loss_norm * grad_normals
pred_norm_loss_term = pred_loss_norm * grad_normals
# Error when batch_reduction is not in ["none", "mean", "sum"].
with self.assertRaises(ValueError):
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
l1 = (loss * grad_loss) + norm_loss_term
l1.sum().backward()
l2 = (pred_loss * grad_loss) + pred_norm_loss_term
l2.sum().backward()
        self._check_grad_by_type(x1, x2, lengths1)
        self._check_grad_by_type(y1, y2, lengths2)
# If leaf nodes for normals are passed in, check their gradients.
if all(n is not None for n in [xn1, xn2, yn1, yn2]):
self._check_grad_by_type(xn1, xn2, lengths1)
self._check_grad_by_type(yn1, yn2, lengths2)
def _check_grad_by_type(self, x1, x2, lengths=None):
"""
Compare output of vectorized chamfer loss with naive implementation
for batch_reduction in ["mean", "sum"] and point_reduction = "none".
x1 and x2 can be of different types e.g. list or tensor - compare appropriately
based on the types.
"""
        error_msg = "All values for gradient checks must be tensors or lists of tensors"
if all(isinstance(p, list) for p in [x1, x2]):
# Lists of tensors
for i in range(len(x1)):
self.assertClose(x1[i].grad, x2[i].grad)
elif isinstance(x1, list) and torch.is_tensor(x2):
self.assertIsNotNone(lengths) # lengths is required
# batch_reduction = "sum".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="sum",
point_reduction="none",
)
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
self.assertClose(loss, pred_loss)
pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1)
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
self.assertClose(loss_norm, pred_loss_norm)
# batch_reduction = "mean".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="mean",
point_reduction="none",
)
pred_loss /= weights.sum()
self.assertClose(loss, pred_loss)
pred_loss_norm /= weights.sum()
self.assertClose(loss_norm, pred_loss_norm)
# Error when point_reduction is not in ["none", "mean", "sum"].
with self.assertRaises(ValueError):
chamfer_distance(p1, p2, weights=weights, point_reduction="max")
# List of tensors vs padded tensor
for i in range(len(x1)):
self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]])
self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0)
elif all(torch.is_tensor(p) for p in [x1, x2]):
# Two tensors
self.assertClose(x1.grad, x2.grad)
else:
raise ValueError(error_msg)
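The padded-leaf branch relies on a general autograd property: gradients only flow to the tensor slices that participate in the loss, so padded rows keep a zero grad. A minimal standalone check:

import torch

p = torch.rand(1, 4, 3, requires_grad=True)  # padded to 4; only 2 real points
loss = p[0, :2].pow(2).sum()  # only the real points contribute
loss.backward()
assert p.grad[0, 2:].abs().sum().item() == 0.0  # padded rows get zero grad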
def test_chamfer_joint_reduction(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for batch_reduction in ["mean", "sum"] and
when batch_reduction in ["mean", "sum"] and
point_reduction in ["mean", "sum"].
"""
        N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
p2_normals = points_normals.n2
weights = points_normals.weights
P1 = p1.shape[1]
P2 = p2.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
            p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
# batch_reduction = "sum", point_reduction = "sum".
loss, loss_norm = chamfer_distance(
p1,
p2,
            x_normals=p1_normals,
            y_normals=p2_normals,
weights=weights,
batch_reduction="sum",
point_reduction="sum",
@@ -244,8 +619,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm = chamfer_distance(
p1,
p2,
            x_normals=p1_normals,
            y_normals=p2_normals,
weights=weights,
batch_reduction="mean",
point_reduction="sum",
@@ -260,8 +635,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm = chamfer_distance(
p1,
p2,
            x_normals=p1_normals,
            y_normals=p2_normals,
weights=weights,
batch_reduction="sum",
point_reduction="mean",
@@ -280,8 +655,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
loss, loss_norm = chamfer_distance(
p1,
p2,
            x_normals=p1_normals,
            y_normals=p2_normals,
weights=weights,
batch_reduction="mean",
point_reduction="mean",
@@ -292,6 +667,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
pred_loss_norm_mean /= weights.sum()
self.assertClose(loss_norm, pred_loss_norm_mean)
# Error when batch_reduction is not in ["mean", "sum"] or None.
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
# Error when point_reduction is not in ["mean", "sum"].
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
chamfer_distance(p1, p2, weights=weights, point_reduction=None)
def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
device = torch.device("cuda:0")
@@ -312,7 +695,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
self.assertTrue(loss_norm.requires_grad)
loss, loss_norm = chamfer_distance(
            p1, p2, weights=weights, batch_reduction=None
)
self.assertClose(loss.cpu(), torch.zeros((N, N)))
self.assertTrue(loss.requires_grad)
@@ -327,16 +710,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError):
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
def test_incorrect_inputs(self):
N, P1, P2 = 7, 10, 18
device = "cuda:0"
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
p1_normals = points_normals.n1
# Normals of wrong shape
with self.assertRaisesRegex(ValueError, "Expected normals to be of shape"):
chamfer_distance(p1, p2, x_normals=p1_normals[None])
# Points of wrong shape
with self.assertRaisesRegex(ValueError, "Expected points to be of shape"):
chamfer_distance(p1[None], p2)
# Lengths of wrong shape
with self.assertRaisesRegex(ValueError, "Expected lengths to be of shape"):
chamfer_distance(p1, p2, x_lengths=torch.tensor([1, 2, 3], device=device))
# Points are not a tensor or Pointclouds
with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
@staticmethod
    def chamfer_with_init(
        batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
    ):
        p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
if homogeneous:
# Set lengths to None so in Chamfer it assumes
# there is no padding.
l1 = l2 = None
torch.cuda.synchronize()
def loss():
loss, loss_normals = chamfer_distance(
                p1,
p2,
x_lengths=l1,
y_lengths=l2,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
)
torch.cuda.synchronize()
@@ -346,14 +766,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def chamfer_naive_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool
):
        p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
torch.cuda.synchronize()
def loss():
loss, loss_normals = TestChamfer.chamfer_distance_naive(
                p1, p2, x_normals=p1_normals, y_normals=p2_normals
)
torch.cuda.synchronize()