diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py
index bd7c420c..58ac3daa 100644
--- a/pytorch3d/loss/chamfer.py
+++ b/pytorch3d/loss/chamfer.py
@@ -1,54 +1,101 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
-from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
+from pytorch3d.ops.knn import knn_gather, knn_points
+from pytorch3d.structures.pointclouds import Pointclouds
 
 
-def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
+def _validate_chamfer_reduction_inputs(
+    batch_reduction: Union[str, None], point_reduction: str
+):
     """Check the requested reductions are valid.
 
     Args:
         batch_reduction: Reduction operation to apply for the loss across the
-            batch, can be one of ["none", "mean", "sum"].
+            batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["none", "mean", "sum"].
+            points, can be one of ["mean", "sum"].
     """
-    if batch_reduction not in ["none", "mean", "sum"]:
-        raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
-    if point_reduction not in ["none", "mean", "sum"]:
-        raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
-    if batch_reduction == "none" and point_reduction == "none":
-        raise ValueError('batch_reduction and point_reduction cannot both be "none".')
+    if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
+        raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
+    if point_reduction not in ["mean", "sum"]:
+        raise ValueError('point_reduction must be one of ["mean", "sum"]')
+
+
+def _handle_pointcloud_input(
+    points: Union[torch.Tensor, Pointclouds],
+    lengths: Union[torch.Tensor, None],
+    normals: Union[torch.Tensor, None],
+):
+    """
+    If points is an instance of Pointclouds, retrieve the padded points tensor
+    along with the number of points per batch and the padded normals.
+    Otherwise, return the input points (and normals) with the number of points per cloud
+    set to the size of the second dimension of `points`.
+    """
+    if isinstance(points, Pointclouds):
+        X = points.points_padded()
+        lengths = points.num_points_per_cloud()
+        normals = points.normals_padded()  # either a tensor or None
+    elif torch.is_tensor(points):
+        if points.ndim != 3:
+            raise ValueError("Expected points to be of shape (N, P, D)")
+        X = points
+        if lengths is not None and (
+            lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
+        ):
+            raise ValueError("Expected lengths to be of shape (N,)")
+        if lengths is None:
+            lengths = torch.full(
+                (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
+            )
+        if normals is not None and normals.ndim != 3:
+            raise ValueError("Expected normals to be of shape (N, P, 3)")
+    else:
+        raise ValueError(
+            "The input pointclouds should be either "
+            + "Pointclouds objects or torch.Tensor of shape "
+            + "(minibatch, num_points, 3)."
+        )
+    return X, lengths, normals
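Note: the private helper above gives chamfer_distance one internal representation, padded points plus per-cloud lengths plus optional padded normals, for both input types. A minimal sketch of that contract (calling the helper directly, purely for illustration):

    import torch

    from pytorch3d.loss.chamfer import _handle_pointcloud_input
    from pytorch3d.structures.pointclouds import Pointclouds

    # Ragged input: two clouds with 2 and 3 points.
    clouds = Pointclouds(points=[torch.rand(2, 3), torch.rand(3, 3)])
    X, lengths, normals = _handle_pointcloud_input(clouds, None, None)
    # X is a zero-padded (2, 3, 3) tensor, lengths == tensor([2, 3]), normals is None.

    # A plain tensor without lengths is treated as fully packed.
    Y, y_lengths, _ = _handle_pointcloud_input(torch.rand(2, 5, 3), None, None)
    # y_lengths == tensor([5, 5])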
 
 
 def chamfer_distance(
     x,
     y,
+    x_lengths=None,
+    y_lengths=None,
     x_normals=None,
     y_normals=None,
     weights=None,
-    batch_reduction: str = "mean",
+    batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
 ):
     """
     Chamfer distance between two pointclouds x and y.
 
     Args:
-        x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
-            with P1 points in each batch element, batch size N and feature
-            dimension D.
-        y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
-            with P2 points in each batch element, batch size N and feature
-            dimension D.
+        x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
+            a batch of point clouds with at most P1 points in each batch element,
+            batch size N and feature dimension D.
+        y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
+            a batch of point clouds with at most P2 points in each batch element,
+            batch size N and feature dimension D.
+        x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
+            cloud in x.
+        y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
+            cloud in y.
         x_normals: Optional FloatTensor of shape (N, P1, D).
         y_normals: Optional FloatTensor of shape (N, P2, D).
         weights: Optional FloatTensor of shape (N,) giving weights for
             batch elements for reduction operation.
         batch_reduction: Reduction operation to apply for the loss across the
-            batch, can be one of ["none", "mean", "sum"].
+            batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["none", "mean", "sum"].
+            points, can be one of ["mean", "sum"].
 
     Returns:
         2-element tuple containing
@@ -61,16 +108,31 @@ def chamfer_distance(
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 
+    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
+    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
+
+    return_normals = x_normals is not None and y_normals is not None
+
     N, P1, D = x.shape
     P2 = y.shape[1]
 
+    # Check if inputs are heterogeneous and create a lengths mask.
+    is_x_heterogeneous = ~(x_lengths == P1).all()
+    is_y_heterogeneous = ~(y_lengths == P2).all()
+    x_mask = (
+        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
+    )  # shape [N, P1]
+    y_mask = (
+        torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
+    )  # shape [N, P2]
+
     if y.shape[0] != N or y.shape[2] != D:
         raise ValueError("y does not have the correct shape.")
     if weights is not None:
         if weights.size(0) != N:
             raise ValueError("weights must be of shape (N,).")
         if not (weights >= 0).all():
-            raise ValueError("weights can not be nonnegative.")
+            raise ValueError("weights cannot be negative.")
         if weights.sum() == 0.0:
             weights = weights.view(N, 1)
             if batch_reduction in ["mean", "sum"]:
@@ -80,46 +142,60 @@ def chamfer_distance(
                 )
             return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
 
-    return_normals = x_normals is not None and y_normals is not None
-
     cham_norm_x = x.new_zeros(())
     cham_norm_y = x.new_zeros(())
 
-    x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
-    y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
+    x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
+    y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
 
-    cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0  # (N, P1)
-    cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0  # (N, P2)
+    cham_x = x_dists[..., 0]  # (N, P1)
+    cham_y = y_dists[..., 0]  # (N, P2)
+
+    if is_x_heterogeneous:
+        cham_x[x_mask] = 0.0
+    if is_y_heterogeneous:
+        cham_y[y_mask] = 0.0
 
     if weights is not None:
         cham_x *= weights.view(N, 1)
         cham_y *= weights.view(N, 1)
 
     if return_normals:
+        # Gather the normals using the indices and keep only the value for k=0
+        x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :]
+        y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :]
+
         cham_norm_x = 1 - torch.abs(
             F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
         )
         cham_norm_y = 1 - torch.abs(
             F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
         )
+
+        if is_x_heterogeneous:
+            cham_norm_x[x_mask] = 0.0
+        if is_y_heterogeneous:
+            cham_norm_y[y_mask] = 0.0
+
         if weights is not None:
             cham_norm_x *= weights.view(N, 1)
             cham_norm_y *= weights.view(N, 1)
 
-    if point_reduction != "none":
-        # If not 'none' then either 'sum' or 'mean'.
-        cham_x = cham_x.sum(1)  # (N,)
-        cham_y = cham_y.sum(1)  # (N,)
+    # Apply point reduction
+    cham_x = cham_x.sum(1)  # (N,)
+    cham_y = cham_y.sum(1)  # (N,)
+    if return_normals:
+        cham_norm_x = cham_norm_x.sum(1)  # (N,)
+        cham_norm_y = cham_norm_y.sum(1)  # (N,)
+    if point_reduction == "mean":
+        cham_x /= x_lengths
+        cham_y /= y_lengths
         if return_normals:
-            cham_norm_x = cham_norm_x.sum(1)  # (N,)
-            cham_norm_y = cham_norm_y.sum(1)  # (N,)
-        if point_reduction == "mean":
-            cham_x /= P1
-            cham_y /= P2
-            if return_normals:
-                cham_norm_x /= P1
-                cham_norm_y /= P2
+            cham_norm_x /= x_lengths
+            cham_norm_y /= y_lengths
 
-    if batch_reduction != "none":
+    if batch_reduction is not None:
+        # batch_reduction == "sum"
         cham_x = cham_x.sum()
         cham_y = cham_y.sum()
         if return_normals:
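Note: with the new signature, ragged batches can be passed either as padded tensors with explicit lengths or as Pointclouds objects. A minimal usage sketch (variable names are illustrative; a CUDA device is assumed, as in the tests):

    import torch

    from pytorch3d.loss import chamfer_distance
    from pytorch3d.structures.pointclouds import Pointclouds

    device = torch.device("cuda:0")
    x = torch.rand((2, 100, 3), device=device)
    y = torch.rand((2, 80, 3), device=device)
    x_lengths = torch.tensor([100, 60], device=device)  # cloud 2 has 40 padded rows
    y_lengths = torch.tensor([80, 80], device=device)

    loss, _ = chamfer_distance(x, y, x_lengths=x_lengths, y_lengths=y_lengths)

    # The same loss from Pointclouds objects; lengths are inferred internally.
    loss2, _ = chamfer_distance(
        Pointclouds(points=[x[0], x[1, :60]]), Pointclouds(points=[y[0], y[1]])
    )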
diff --git a/tests/bm_chamfer.py b/tests/bm_chamfer.py
index 500019d9..0dcdb803 100644
--- a/tests/bm_chamfer.py
+++ b/tests/bm_chamfer.py
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+from itertools import product
 
 import torch
 from fvcore.common.benchmark import benchmark
@@ -20,8 +21,23 @@ def bm_chamfer() -> None:
         )
 
     if torch.cuda.is_available():
-        kwargs_list = kwargs_list_naive + [
-            {"batch_size": 1, "P1": 1000, "P2": 3000, "return_normals": False},
-            {"batch_size": 1, "P1": 1000, "P2": 30000, "return_normals": True},
-        ]
+        kwargs_list = []
+        batch_size = [1, 32]
+        P1 = [32, 1000, 10000]
+        P2 = [64, 3000, 30000]
+        return_normals = [True, False]
+        homogeneous = [True, False]
+        test_cases = product(batch_size, P1, P2, return_normals, homogeneous)
+
+        for case in test_cases:
+            b, p1, p2, n, h = case
+            kwargs_list.append(
+                {
+                    "batch_size": b,
+                    "P1": p1,
+                    "P2": p2,
+                    "return_normals": n,
+                    "homogeneous": h,
+                }
+            )
     benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
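Note: itertools.product expands the five parameter lists into their full Cartesian grid, so the CUDA benchmark now covers 2 * 3 * 3 * 2 * 2 = 72 configurations instead of the previous handful. A quick sketch of the expansion:

    from itertools import product

    grid = list(
        product([1, 32], [32, 1000, 10000], [64, 3000, 30000], [True, False], [True, False])
    )
    assert len(grid) == 72
    assert grid[0] == (1, 32, 64, True, True)  # first benchmark case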
- """ - device = torch.device("cuda:0") - p1 = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device) - p1_normals = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device) - p1_normals = p1_normals / p1_normals.norm(dim=2, p=2, keepdim=True) - p2 = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device) - p2_normals = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device) - p2_normals = p2_normals / p2_normals.norm(dim=2, p=2, keepdim=True) - weights = torch.rand((batch_size,), dtype=torch.float32, device=device) - - return p1, p2, p1_normals, p2_normals, weights + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) @staticmethod - def chamfer_distance_naive(p1, p2, p1_normals=None, p2_normals=None): + def init_pointclouds(N, P1, P2, device, requires_grad: bool = True): + """ + Create 2 pointclouds object and associated padded points/normals tensors by + starting from lists. The clouds and tensors have the same data. The + leaf nodes for the clouds are a list of tensors. The padded tensor can be + used directly as a leaf node. + """ + p1_lengths = torch.randint(P1, size=(N,), dtype=torch.int64, device=device) + p2_lengths = torch.randint(P2, size=(N,), dtype=torch.int64, device=device) + weights = torch.rand((N,), dtype=torch.float32, device=device) + + # list of points and normals tensors + p1_list = [] + p2_list = [] + n1_list = [] + n2_list = [] + for i in range(N): + l1 = p1_lengths[i] + l2 = p2_lengths[i] + p1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device)) + p2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device)) + n1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device)) + n2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device)) + + n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list] + n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list] + + # Clone the lists and initialize padded tensors. + p1 = list_to_padded([p.clone() for p in p1_list]) + p2 = list_to_padded([p.clone() for p in p2_list]) + n1 = list_to_padded([p.clone() for p in n1_list]) + n2 = list_to_padded([p.clone() for p in n2_list]) + + # Set requires_grad for all tensors in the lists and + # padded tensors. + if requires_grad: + for p in p2_list + p1_list + n1_list + n2_list + [p1, p2, n1, n2]: + p.requires_grad = True + + # Create pointclouds objects + cloud1 = Pointclouds(points=p1_list, normals=n1_list) + cloud2 = Pointclouds(points=p2_list, normals=n2_list) + + # Return pointclouds objects and padded tensors + return points_normals( + p1_lengths=p1_lengths, + p2_lengths=p2_lengths, + cloud1=cloud1, + cloud2=cloud2, + p1=p1, + p2=p2, + n1=n1, + n2=n2, + weights=weights, + ) + + @staticmethod + def chamfer_distance_naive_pointclouds(p1, p2): """ Naive iterative implementation of nearest neighbor and chamfer distance. + x and y are assumed to be pointclouds objects with points and optionally normals. + This functions supports heterogeneous pointclouds in a batch. Returns lists of the unreduced loss and loss_normals. 
""" - N, P1, D = p1.shape - P2 = p2.size(1) + x = p1.points_padded() + y = p2.points_padded() + N, P1, D = x.shape + P2 = y.size(1) + x_lengths = p1.num_points_per_cloud() + y_lengths = p2.num_points_per_cloud() + x_normals = p1.normals_padded() + y_normals = p2.normals_padded() + device = torch.device("cuda:0") - return_normals = p1_normals is not None and p2_normals is not None + return_normals = x_normals is not None and y_normals is not None + + # Initialize all distances to + inf + dist = torch.ones((N, P1, P2), dtype=torch.float32, device=device) * np.inf + + x_mask = ( + torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] + ) # shape [N, P1] + y_mask = ( + torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] + ) # shape [N, P2] + + is_x_heterogeneous = ~(x_lengths == P1).all() + is_y_heterogeneous = ~(y_lengths == P2).all() + + # Only calculate the distances for the points which are not masked + for n in range(N): + for i1 in range(x_lengths[n]): + for i2 in range(y_lengths[n]): + dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2) + + x_dist = torch.min(dist, dim=2)[0] # (N, P1) + y_dist = torch.min(dist, dim=1)[0] # (N, P2) + + if is_x_heterogeneous: + x_dist[x_mask] = 0.0 + if is_y_heterogeneous: + y_dist[y_mask] = 0.0 + + loss = [x_dist, y_dist] + + lnorm = [x.new_zeros(()), x.new_zeros(())] + + if return_normals: + x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3) + y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3) + lnorm1 = 1 - torch.abs( + F.cosine_similarity( + x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6 + ) + ) + lnorm2 = 1 - torch.abs( + F.cosine_similarity( + y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6 + ) + ) + + if is_x_heterogeneous: + lnorm1[x_mask] = 0.0 + if is_y_heterogeneous: + lnorm2[y_mask] = 0.0 + + lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)] + + return loss, lnorm + + @staticmethod + def chamfer_distance_naive(x, y, x_normals=None, y_normals=None): + """ + Naive iterative implementation of nearest neighbor and chamfer distance. + Returns lists of the unreduced loss and loss_normals. This naive + version only supports homogeneous pointcouds in a batch. 
+ """ + N, P1, D = x.shape + P2 = y.size(1) + device = torch.device("cuda:0") + return_normals = x_normals is not None and y_normals is not None dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device) for n in range(N): for i1 in range(P1): for i2 in range(P2): - dist[n, i1, i2] = torch.sum((p1[n, i1, :] - p2[n, i2, :]) ** 2) + dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2) loss = [ torch.min(dist, dim=2)[0], # (N, P1) torch.min(dist, dim=1)[0], # (N, P2) ] - - lnorm = [p1.new_zeros(()), p1.new_zeros(())] + lnorm = [x.new_zeros(()), x.new_zeros(())] if return_normals: - p1_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3) - p2_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3) + x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3) + y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3) lnorm1 = 1 - torch.abs( F.cosine_similarity( - p1_normals, p2_normals.gather(1, p1_index), dim=2, eps=1e-6 + x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6 ) ) lnorm2 = 1 - torch.abs( F.cosine_similarity( - p2_normals, p1_normals.gather(1, p2_index), dim=2, eps=1e-6 + y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6 ) ) lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)] return loss, lnorm - def test_chamfer_default_no_normals(self): + def test_chamfer_point_batch_reduction_mean(self): """ - Compare chamfer loss with naive implementation using default - input values and no normals. + Compare output of vectorized chamfer loss with naive implementation + for the default settings (point_reduction = "mean" and batch_reduction = "mean") + and no normals. + This tests only uses homogeneous pointclouds. """ - N, P1, P2 = 7, 10, 18 - p1, p2, _, _, weights = TestChamfer.init_pointclouds(N, P1, P2) - pred_loss, _ = TestChamfer.chamfer_distance_naive(p1, p2) - loss, loss_norm = chamfer_distance(p1, p2, weights=weights) + N, max_P1, max_P2 = 7, 10, 18 + device = "cuda:0" + points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + P1 = p1.shape[1] + P2 = p2.shape[1] + + pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2) + + # point_reduction = "mean". + loss, loss_norm = chamfer_distance(p11, p22, weights=weights) pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2 pred_loss *= weights pred_loss = pred_loss.sum() / weights.sum() + self.assertClose(loss, pred_loss) self.assertTrue(loss_norm is None) - def test_chamfer_point_reduction(self): + # Check gradients + self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22) + + def test_chamfer_vs_naive_pointcloud(self): """ - Compare output of vectorized chamfer loss with naive implementation - for point_reduction in ["mean", "sum", "none"] and - batch_reduction = "none". + Test the default settings for chamfer_distance + (point reduction = "mean" and batch_reduction="mean") but with heterogeneous + pointclouds as input. Compare with the naive implementation of chamfer + which supports heterogeneous pointcloud objects. 
""" - N, P1, P2 = 7, 10, 18 - p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds( - N, P1, P2 + N, max_P1, max_P2 = 3, 70, 70 + device = "cuda:0" + points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) + weights = points_normals.weights + x_lengths = points_normals.p1_lengths + y_lengths = points_normals.p2_lengths + + # Chamfer with tensors as input for heterogeneous pointclouds. + cham_tensor, norm_tensor = chamfer_distance( + points_normals.p1, + points_normals.p2, + x_normals=points_normals.n1, + y_normals=points_normals.n2, + x_lengths=points_normals.p1_lengths, + y_lengths=points_normals.p2_lengths, + weights=weights, ) + # Chamfer with pointclouds as input. + pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds( + points_normals.cloud1, points_normals.cloud2 + ) + + # Mean reduction point loss. + pred_loss[0] *= weights.view(N, 1) + pred_loss[1] *= weights.view(N, 1) + pred_loss_mean = ( + pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths + ) + pred_loss_mean = pred_loss_mean.sum() + pred_loss_mean /= weights.sum() + + # Mean reduction norm loss. + pred_norm_loss[0] *= weights.view(N, 1) + pred_norm_loss[1] *= weights.view(N, 1) + pred_norm_loss_mean = ( + pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths + ) + pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum() + + self.assertClose(pred_loss_mean, cham_tensor) + self.assertClose(pred_norm_loss_mean, norm_tensor) + + self._check_gradients( + cham_tensor, + norm_tensor, + pred_loss_mean, + pred_norm_loss_mean, + points_normals.cloud1.points_list(), + points_normals.p1, + points_normals.cloud2.points_list(), + points_normals.p2, + points_normals.cloud1.normals_list(), + points_normals.n1, + points_normals.cloud2.normals_list(), + points_normals.n2, + x_lengths, + y_lengths, + ) + + def test_chamfer_pointcloud_object_withnormals(self): + N = 5 + P1, P2 = 100, 100 + device = "cuda:0" + + reductions = [ + ("sum", "sum"), + ("mean", "sum"), + ("sum", "mean"), + ("mean", "mean"), + ("sum", None), + ("mean", None), + ] + for (point_reduction, batch_reduction) in reductions: + + # Reinitialize all the tensors so that the + # backward pass can be computed. + points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) + + # Chamfer with pointclouds as input. + cham_cloud, norm_cloud = chamfer_distance( + points_normals.cloud1, + points_normals.cloud2, + point_reduction=point_reduction, + batch_reduction=batch_reduction, + ) + + # Chamfer with tensors as input. 
+
+    def test_chamfer_pointcloud_object_withnormals(self):
+        N = 5
+        P1, P2 = 100, 100
+        device = "cuda:0"
+
+        reductions = [
+            ("sum", "sum"),
+            ("mean", "sum"),
+            ("sum", "mean"),
+            ("mean", "mean"),
+            ("sum", None),
+            ("mean", None),
+        ]
+        for (point_reduction, batch_reduction) in reductions:
+
+            # Reinitialize all the tensors so that the
+            # backward pass can be computed.
+            points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+
+            # Chamfer with pointclouds as input.
+            cham_cloud, norm_cloud = chamfer_distance(
+                points_normals.cloud1,
+                points_normals.cloud2,
+                point_reduction=point_reduction,
+                batch_reduction=batch_reduction,
+            )
+
+            # Chamfer with tensors as input.
+            cham_tensor, norm_tensor = chamfer_distance(
+                points_normals.p1,
+                points_normals.p2,
+                x_lengths=points_normals.p1_lengths,
+                y_lengths=points_normals.p2_lengths,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
+                point_reduction=point_reduction,
+                batch_reduction=batch_reduction,
+            )
+
+            self.assertClose(cham_cloud, cham_tensor)
+            self.assertClose(norm_cloud, norm_tensor)
+            self._check_gradients(
+                cham_tensor,
+                norm_tensor,
+                cham_cloud,
+                norm_cloud,
+                points_normals.cloud1.points_list(),
+                points_normals.p1,
+                points_normals.cloud2.points_list(),
+                points_normals.p2,
+                points_normals.cloud1.normals_list(),
+                points_normals.n1,
+                points_normals.cloud2.normals_list(),
+                points_normals.n2,
+                points_normals.p1_lengths,
+                points_normals.p2_lengths,
+            )
+
+    def test_chamfer_pointcloud_object_nonormals(self):
+        N = 5
+        P1, P2 = 100, 100
+        device = "cuda:0"
+
+        reductions = [
+            ("sum", "sum"),
+            ("mean", "sum"),
+            ("sum", "mean"),
+            ("mean", "mean"),
+            ("sum", None),
+            ("mean", None),
+        ]
+        for (point_reduction, batch_reduction) in reductions:
+
+            # Reinitialize all the tensors so that the
+            # backward pass can be computed.
+            points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+
+            # Chamfer with pointclouds as input.
+            cham_cloud, _ = chamfer_distance(
+                points_normals.cloud1,
+                points_normals.cloud2,
+                point_reduction=point_reduction,
+                batch_reduction=batch_reduction,
+            )
+
+            # Chamfer with tensors as input.
+            cham_tensor, _ = chamfer_distance(
+                points_normals.p1,
+                points_normals.p2,
+                x_lengths=points_normals.p1_lengths,
+                y_lengths=points_normals.p2_lengths,
+                point_reduction=point_reduction,
+                batch_reduction=batch_reduction,
+            )
+
+            self.assertClose(cham_cloud, cham_tensor)
+            self._check_gradients(
+                cham_tensor,
+                None,
+                cham_cloud,
+                None,
+                points_normals.cloud1.points_list(),
+                points_normals.p1,
+                points_normals.cloud2.points_list(),
+                points_normals.p2,
+                lengths1=points_normals.p1_lengths,
+                lengths2=points_normals.p2_lengths,
+            )
+
+    def test_chamfer_point_reduction_mean(self):
+        """
+        Compare output of vectorized chamfer loss with naive implementation
+        for point_reduction = "mean" and batch_reduction = None.
+        """
+        N, max_P1, max_P2 = 7, 10, 18
+        device = "cuda:0"
+        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        weights = points_normals.weights
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+        P1 = p1.shape[1]
+        P2 = p2.shape[1]
+
         pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
-            p1, p2, p1_normals, p2_normals
+            p1, p2, x_normals=p1_normals, y_normals=p2_normals
         )
 
         # point_reduction = "mean".
         loss, loss_norm = chamfer_distance(
-            p1,
-            p2,
-            p1_normals,
-            p2_normals,
+            p11,
+            p22,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
             weights=weights,
-            batch_reduction="none",
+            batch_reduction=None,
             point_reduction="mean",
         )
         pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
@@ -118,14 +449,40 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         pred_loss_norm_mean *= weights
         self.assertClose(loss_norm, pred_loss_norm_mean)
 
-        # point_reduction = "sum".
+        # Check gradients
+        self._check_gradients(
+            loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
+        )
+
+    def test_chamfer_point_reduction_sum(self):
+        """
+        Compare output of vectorized chamfer loss with naive implementation
+        for point_reduction = "sum" and batch_reduction = None.
+        """
+        N, P1, P2 = 7, 10, 18
+        device = "cuda:0"
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        weights = points_normals.weights
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=p1_normals, y_normals=p2_normals
+        )
+
         loss, loss_norm = chamfer_distance(
-            p1,
-            p2,
-            p1_normals,
-            p2_normals,
+            p11,
+            p22,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
             weights=weights,
-            batch_reduction="none",
+            batch_reduction=None,
             point_reduction="sum",
         )
         pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
@@ -136,92 +493,110 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         pred_loss_norm_sum *= weights
         self.assertClose(loss_norm, pred_loss_norm_sum)
 
-        # Error when point_reduction = "none" and batch_reduction = "none".
-        with self.assertRaises(ValueError):
-            chamfer_distance(
-                p1, p2, weights=weights, batch_reduction="none", point_reduction="none"
+        # Check gradients
+        self._check_gradients(
+            loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
+        )
+
+    def _check_gradients(
+        self,
+        loss,
+        loss_norm,
+        pred_loss,
+        pred_loss_norm,
+        x1,
+        x2,
+        y1,
+        y2,
+        xn1=None,  # normals
+        xn2=None,  # normals
+        yn1=None,  # normals
+        yn2=None,  # normals
+        lengths1=None,
+        lengths2=None,
+    ):
+        """
+        x1 and x2 can have different types based on the leaf node used in the calculation:
+        e.g. x1 may be a list of tensors whereas x2 is a padded tensor.
+        This also applies to the pairs: (y1, y2), (xn1, xn2), (yn1, yn2).
+        """
+        grad_loss = torch.rand(loss.shape, device=loss.device, dtype=loss.dtype)
+
+        # Loss for normals is optional. Initialize to 0.
+        norm_loss_term = pred_norm_loss_term = 0.0
+        if loss_norm is not None and pred_loss_norm is not None:
+            grad_normals = torch.rand(
+                loss_norm.shape, device=loss.device, dtype=loss.dtype
             )
+            norm_loss_term = loss_norm * grad_normals
+            pred_norm_loss_term = pred_loss_norm * grad_normals
 
-        # Error when batch_reduction is not in ["none", "mean", "sum"].
-        with self.assertRaises(ValueError):
-            chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
+        l1 = (loss * grad_loss) + norm_loss_term
+        l1.sum().backward()
+        l2 = (pred_loss * grad_loss) + pred_norm_loss_term
+        l2.sum().backward()
 
-    def test_chamfer_batch_reduction(self):
+        self._check_grad_by_type(x1, x2, lengths1)
+        self._check_grad_by_type(y1, y2, lengths2)
+
+        # If leaf nodes for normals are passed in, check their gradients.
+        if all(n is not None for n in [xn1, xn2, yn1, yn2]):
+            self._check_grad_by_type(xn1, xn2, lengths1)
+            self._check_grad_by_type(yn1, yn2, lengths2)
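Note: _check_gradients contracts both losses with the same random upstream gradient before calling backward, so matching leaf gradients mean the two backward passes agree along that random direction. The idea in isolation:

    import torch

    a = torch.rand(5, requires_grad=True)
    b = a.detach().clone().requires_grad_(True)

    out1 = a ** 2  # implementation 1
    out2 = b * b   # implementation 2, same function

    g = torch.rand(5)  # shared random upstream gradient
    (out1 * g).sum().backward()
    (out2 * g).sum().backward()
    assert torch.allclose(a.grad, b.grad)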
""" - N, P1, P2 = 7, 10, 18 - p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds( - N, P1, P2 - ) + error_msg = "All values for gradient checks must be tensors or lists of tensors" - pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive( - p1, p2, p1_normals, p2_normals - ) + if all(isinstance(p, list) for p in [x1, x2]): + # Lists of tensors + for i in range(len(x1)): + self.assertClose(x1[i].grad, x2[i].grad) + elif isinstance(x1, list) and torch.is_tensor(x2): + self.assertIsNotNone(lengths) # lengths is required - # batch_reduction = "sum". - loss, loss_norm = chamfer_distance( - p1, - p2, - p1_normals, - p2_normals, - weights=weights, - batch_reduction="sum", - point_reduction="none", - ) - pred_loss[0] *= weights.view(N, 1) - pred_loss[1] *= weights.view(N, 1) - pred_loss = pred_loss[0].sum() + pred_loss[1].sum() - self.assertClose(loss, pred_loss) - - pred_loss_norm[0] *= weights.view(N, 1) - pred_loss_norm[1] *= weights.view(N, 1) - pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum() - self.assertClose(loss_norm, pred_loss_norm) - - # batch_reduction = "mean". - loss, loss_norm = chamfer_distance( - p1, - p2, - p1_normals, - p2_normals, - weights=weights, - batch_reduction="mean", - point_reduction="none", - ) - - pred_loss /= weights.sum() - self.assertClose(loss, pred_loss) - - pred_loss_norm /= weights.sum() - self.assertClose(loss_norm, pred_loss_norm) - - # Error when point_reduction is not in ["none", "mean", "sum"]. - with self.assertRaises(ValueError): - chamfer_distance(p1, p2, weights=weights, point_reduction="max") + # List of tensors vs padded tensor + for i in range(len(x1)): + self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]]) + self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0) + elif all(torch.is_tensor(p) for p in [x1, x2]): + # Two tensors + self.assertClose(x1.grad, x2.grad) + else: + raise ValueError(error_msg) def test_chamfer_joint_reduction(self): """ Compare output of vectorized chamfer loss with naive implementation - for batch_reduction in ["mean", "sum"] and + when batch_reduction in ["mean", "sum"] and point_reduction in ["mean", "sum"]. """ - N, P1, P2 = 7, 10, 18 - p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds( - N, P1, P2 - ) + N, max_P1, max_P2 = 7, 10, 18 + device = "cuda:0" + + points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + p1_normals = points_normals.n1 + p2_normals = points_normals.n2 + weights = points_normals.weights + + P1 = p1.shape[1] + P2 = p2.shape[1] pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive( - p1, p2, p1_normals, p2_normals + p1, p2, x_normals=p1_normals, y_normals=p2_normals ) # batch_reduction = "sum", point_reduction = "sum". 
         loss, loss_norm = chamfer_distance(
             p1,
             p2,
-            p1_normals,
-            p2_normals,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
             weights=weights,
             batch_reduction="sum",
             point_reduction="sum",
@@ -244,8 +619,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         loss, loss_norm = chamfer_distance(
             p1,
             p2,
-            p1_normals,
-            p2_normals,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
             weights=weights,
             batch_reduction="mean",
             point_reduction="sum",
@@ -260,8 +635,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         loss, loss_norm = chamfer_distance(
             p1,
             p2,
-            p1_normals,
-            p2_normals,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
             weights=weights,
             batch_reduction="sum",
             point_reduction="mean",
@@ -280,8 +655,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         loss, loss_norm = chamfer_distance(
             p1,
             p2,
-            p1_normals,
-            p2_normals,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
             weights=weights,
             batch_reduction="mean",
             point_reduction="mean",
@@ -292,6 +667,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         pred_loss_norm_mean /= weights.sum()
         self.assertClose(loss_norm, pred_loss_norm_mean)
 
+        # Error when batch_reduction is not in ["mean", "sum"] or None.
+        with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
+            chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
+
+        # Error when point_reduction is not in ["mean", "sum"].
+        with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
+            chamfer_distance(p1, p2, weights=weights, point_reduction=None)
+
     def test_incorrect_weights(self):
         N, P1, P2 = 16, 64, 128
         device = torch.device("cuda:0")
@@ -312,7 +695,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         self.assertTrue(loss_norm.requires_grad)
 
         loss, loss_norm = chamfer_distance(
-            p1, p2, weights=weights, batch_reduction="none"
+            p1, p2, weights=weights, batch_reduction=None
         )
         self.assertClose(loss.cpu(), torch.zeros((N, N)))
         self.assertTrue(loss.requires_grad)
@@ -327,16 +710,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaises(ValueError):
             loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
 
+    def test_incorrect_inputs(self):
+        N, P1, P2 = 7, 10, 18
+        device = "cuda:0"
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+
+        # Normals of wrong shape
+        with self.assertRaisesRegex(ValueError, "Expected normals to be of shape"):
+            chamfer_distance(p1, p2, x_normals=p1_normals[None])
+
+        # Points of wrong shape
+        with self.assertRaisesRegex(ValueError, "Expected points to be of shape"):
+            chamfer_distance(p1[None], p2)
+
+        # Lengths of wrong shape
+        with self.assertRaisesRegex(ValueError, "Expected lengths to be of shape"):
+            chamfer_distance(p1, p2, x_lengths=torch.tensor([1, 2, 3], device=device))
+
+        # Points are not a tensor or Pointclouds
+        with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
+            chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
+
     @staticmethod
-    def chamfer_with_init(batch_size: int, P1: int, P2: int, return_normals: bool):
-        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
-            batch_size, P1, P2
-        )
+    def chamfer_with_init(
+        batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
+    ):
+        points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, "cuda:0")
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        weights = points_normals.weights
+        l1 = points_normals.p1_lengths
+        l2 = points_normals.p2_lengths
+        if homogeneous:
+            # Set the lengths to None so chamfer_distance assumes
+            # there is no padding.
+            l1 = l2 = None
+
         torch.cuda.synchronize()
 
         def loss():
             loss, loss_normals = chamfer_distance(
-                p1, p2, p1_normals, p2_normals, weights=weights
+                p1,
+                p2,
+                x_lengths=l1,
+                y_lengths=l2,
+                x_normals=p1_normals,
+                y_normals=p2_normals,
+                weights=weights,
             )
             torch.cuda.synchronize()
 
@@ -346,14 +766,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     def chamfer_naive_with_init(
         batch_size: int, P1: int, P2: int, return_normals: bool
    ):
-        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
-            batch_size, P1, P2
-        )
+        points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, "cuda:0")
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        weights = points_normals.weights
         torch.cuda.synchronize()
 
         def loss():
             loss, loss_normals = TestChamfer.chamfer_distance_naive(
-                p1, p2, p1_normals, p2_normals
+                p1, p2, x_normals=p1_normals, y_normals=p2_normals
             )
             torch.cuda.synchronize()