mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Chamfer for Pointclouds object
Summary: Allow Pointclouds objects and heterogenous data to be provided for Chamfer loss. Remove "none" as an option for point_reduction because it doesn't make sense and in the current implementation is effectively the same as "sum". Possible improvement: create specialised operations for sum and cosine_similarity of padded tensors, to avoid having to create masks. sum would be useful elsewhere. Reviewed By: gkioxari Differential Revision: D20816301 fbshipit-source-id: 0f32073210225d157c029d80de450eecdb64f4d2
This commit is contained in:
parent
677b0bd5ae
commit
790eb8c402
@ -1,54 +1,101 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
|
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||||
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
|
||||||
|
|
||||||
def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
|
def _validate_chamfer_reduction_inputs(
|
||||||
|
batch_reduction: Union[str, None], point_reduction: str
|
||||||
|
):
|
||||||
"""Check the requested reductions are valid.
|
"""Check the requested reductions are valid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch_reduction: Reduction operation to apply for the loss across the
|
batch_reduction: Reduction operation to apply for the loss across the
|
||||||
batch, can be one of ["none", "mean", "sum"].
|
batch, can be one of ["mean", "sum"] or None.
|
||||||
point_reduction: Reduction operation to apply for the loss across the
|
point_reduction: Reduction operation to apply for the loss across the
|
||||||
points, can be one of ["none", "mean", "sum"].
|
points, can be one of ["mean", "sum"].
|
||||||
"""
|
"""
|
||||||
if batch_reduction not in ["none", "mean", "sum"]:
|
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
|
||||||
raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
|
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
|
||||||
if point_reduction not in ["none", "mean", "sum"]:
|
if point_reduction not in ["mean", "sum"]:
|
||||||
raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
|
raise ValueError('point_reduction must be one of ["mean", "sum"]')
|
||||||
if batch_reduction == "none" and point_reduction == "none":
|
|
||||||
raise ValueError('batch_reduction and point_reduction cannot both be "none".')
|
|
||||||
|
def _handle_pointcloud_input(
|
||||||
|
points: Union[torch.Tensor, Pointclouds],
|
||||||
|
lengths: Union[torch.Tensor, None],
|
||||||
|
normals: Union[torch.Tensor, None],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
If points is an instance of Pointclouds, retrieve the padded points tensor
|
||||||
|
along with the number of points per batch and the padded normals.
|
||||||
|
Otherwise, return the input points (and normals) with the number of points per cloud
|
||||||
|
set to the size of the second dimension of `points`.
|
||||||
|
"""
|
||||||
|
if isinstance(points, Pointclouds):
|
||||||
|
X = points.points_padded()
|
||||||
|
lengths = points.num_points_per_cloud()
|
||||||
|
normals = points.normals_padded() # either a tensor or None
|
||||||
|
elif torch.is_tensor(points):
|
||||||
|
if points.ndim != 3:
|
||||||
|
raise ValueError("Expected points to be of shape (N, P, D)")
|
||||||
|
X = points
|
||||||
|
if lengths is not None and (
|
||||||
|
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
|
||||||
|
):
|
||||||
|
raise ValueError("Expected lengths to be of shape (N,)")
|
||||||
|
if lengths is None:
|
||||||
|
lengths = torch.full(
|
||||||
|
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
|
||||||
|
)
|
||||||
|
if normals is not None and normals.ndim != 3:
|
||||||
|
raise ValueError("Expected normals to be of shape (N, P, 3")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The input pointclouds should be either "
|
||||||
|
+ "Pointclouds objects or torch.Tensor of shape "
|
||||||
|
+ "(minibatch, num_points, 3)."
|
||||||
|
)
|
||||||
|
return X, lengths, normals
|
||||||
|
|
||||||
|
|
||||||
def chamfer_distance(
|
def chamfer_distance(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
|
x_lengths=None,
|
||||||
|
y_lengths=None,
|
||||||
x_normals=None,
|
x_normals=None,
|
||||||
y_normals=None,
|
y_normals=None,
|
||||||
weights=None,
|
weights=None,
|
||||||
batch_reduction: str = "mean",
|
batch_reduction: Union[str, None] = "mean",
|
||||||
point_reduction: str = "mean",
|
point_reduction: str = "mean",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Chamfer distance between two pointclouds x and y.
|
Chamfer distance between two pointclouds x and y.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
|
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
|
||||||
with P1 points in each batch element, batch size N and feature
|
a batch of point clouds with at most P1 points in each batch element,
|
||||||
dimension D.
|
batch size N and feature dimension D.
|
||||||
y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
|
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
|
||||||
with P2 points in each batch element, batch size N and feature
|
a batch of point clouds with at most P2 points in each batch element,
|
||||||
dimension D.
|
batch size N and feature dimension D.
|
||||||
|
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
|
||||||
|
cloud in x.
|
||||||
|
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
|
||||||
|
cloud in x.
|
||||||
x_normals: Optional FloatTensor of shape (N, P1, D).
|
x_normals: Optional FloatTensor of shape (N, P1, D).
|
||||||
y_normals: Optional FloatTensor of shape (N, P2, D).
|
y_normals: Optional FloatTensor of shape (N, P2, D).
|
||||||
weights: Optional FloatTensor of shape (N,) giving weights for
|
weights: Optional FloatTensor of shape (N,) giving weights for
|
||||||
batch elements for reduction operation.
|
batch elements for reduction operation.
|
||||||
batch_reduction: Reduction operation to apply for the loss across the
|
batch_reduction: Reduction operation to apply for the loss across the
|
||||||
batch, can be one of ["none", "mean", "sum"].
|
batch, can be one of ["mean", "sum"] or None.
|
||||||
point_reduction: Reduction operation to apply for the loss across the
|
point_reduction: Reduction operation to apply for the loss across the
|
||||||
points, can be one of ["none", "mean", "sum"].
|
points, can be one of ["mean", "sum"].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
2-element tuple containing
|
2-element tuple containing
|
||||||
@ -61,16 +108,31 @@ def chamfer_distance(
|
|||||||
"""
|
"""
|
||||||
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
|
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
|
||||||
|
|
||||||
|
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
|
||||||
|
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
|
||||||
|
|
||||||
|
return_normals = x_normals is not None and y_normals is not None
|
||||||
|
|
||||||
N, P1, D = x.shape
|
N, P1, D = x.shape
|
||||||
P2 = y.shape[1]
|
P2 = y.shape[1]
|
||||||
|
|
||||||
|
# Check if inputs are heterogeneous and create a lengths mask.
|
||||||
|
is_x_heterogeneous = ~(x_lengths == P1).all()
|
||||||
|
is_y_heterogeneous = ~(y_lengths == P2).all()
|
||||||
|
x_mask = (
|
||||||
|
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
|
||||||
|
) # shape [N, P1]
|
||||||
|
y_mask = (
|
||||||
|
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
|
||||||
|
) # shape [N, P2]
|
||||||
|
|
||||||
if y.shape[0] != N or y.shape[2] != D:
|
if y.shape[0] != N or y.shape[2] != D:
|
||||||
raise ValueError("y does not have the correct shape.")
|
raise ValueError("y does not have the correct shape.")
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
if weights.size(0) != N:
|
if weights.size(0) != N:
|
||||||
raise ValueError("weights must be of shape (N,).")
|
raise ValueError("weights must be of shape (N,).")
|
||||||
if not (weights >= 0).all():
|
if not (weights >= 0).all():
|
||||||
raise ValueError("weights can not be nonnegative.")
|
raise ValueError("weights cannot be negative.")
|
||||||
if weights.sum() == 0.0:
|
if weights.sum() == 0.0:
|
||||||
weights = weights.view(N, 1)
|
weights = weights.view(N, 1)
|
||||||
if batch_reduction in ["mean", "sum"]:
|
if batch_reduction in ["mean", "sum"]:
|
||||||
@ -80,46 +142,60 @@ def chamfer_distance(
|
|||||||
)
|
)
|
||||||
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
|
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
|
||||||
|
|
||||||
return_normals = x_normals is not None and y_normals is not None
|
|
||||||
cham_norm_x = x.new_zeros(())
|
cham_norm_x = x.new_zeros(())
|
||||||
cham_norm_y = x.new_zeros(())
|
cham_norm_y = x.new_zeros(())
|
||||||
|
|
||||||
x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
|
x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
|
||||||
y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
|
y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
|
||||||
|
|
||||||
cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0 # (N, P1)
|
cham_x = x_dists[..., 0] # (N, P1)
|
||||||
cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0 # (N, P2)
|
cham_y = y_dists[..., 0] # (N, P2)
|
||||||
|
|
||||||
|
if is_x_heterogeneous:
|
||||||
|
cham_x[x_mask] = 0.0
|
||||||
|
if is_y_heterogeneous:
|
||||||
|
cham_y[y_mask] = 0.0
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
cham_x *= weights.view(N, 1)
|
cham_x *= weights.view(N, 1)
|
||||||
cham_y *= weights.view(N, 1)
|
cham_y *= weights.view(N, 1)
|
||||||
|
|
||||||
if return_normals:
|
if return_normals:
|
||||||
|
# Gather the normals using the indices and keep only value for k=0
|
||||||
|
x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :]
|
||||||
|
y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :]
|
||||||
|
|
||||||
cham_norm_x = 1 - torch.abs(
|
cham_norm_x = 1 - torch.abs(
|
||||||
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
|
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
|
||||||
)
|
)
|
||||||
cham_norm_y = 1 - torch.abs(
|
cham_norm_y = 1 - torch.abs(
|
||||||
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
|
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_x_heterogeneous:
|
||||||
|
cham_norm_x[x_mask] = 0.0
|
||||||
|
if is_y_heterogeneous:
|
||||||
|
cham_norm_y[y_mask] = 0.0
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
cham_norm_x *= weights.view(N, 1)
|
cham_norm_x *= weights.view(N, 1)
|
||||||
cham_norm_y *= weights.view(N, 1)
|
cham_norm_y *= weights.view(N, 1)
|
||||||
|
|
||||||
if point_reduction != "none":
|
# Apply point reduction
|
||||||
# If not 'none' then either 'sum' or 'mean'.
|
cham_x = cham_x.sum(1) # (N,)
|
||||||
cham_x = cham_x.sum(1) # (N,)
|
cham_y = cham_y.sum(1) # (N,)
|
||||||
cham_y = cham_y.sum(1) # (N,)
|
if return_normals:
|
||||||
|
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||||
|
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||||
|
if point_reduction == "mean":
|
||||||
|
cham_x /= x_lengths
|
||||||
|
cham_y /= y_lengths
|
||||||
if return_normals:
|
if return_normals:
|
||||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
cham_norm_x /= x_lengths
|
||||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
cham_norm_y /= y_lengths
|
||||||
if point_reduction == "mean":
|
|
||||||
cham_x /= P1
|
|
||||||
cham_y /= P2
|
|
||||||
if return_normals:
|
|
||||||
cham_norm_x /= P1
|
|
||||||
cham_norm_y /= P2
|
|
||||||
|
|
||||||
if batch_reduction != "none":
|
if batch_reduction is not None:
|
||||||
|
# batch_reduction == "sum"
|
||||||
cham_x = cham_x.sum()
|
cham_x = cham_x.sum()
|
||||||
cham_y = cham_y.sum()
|
cham_y = cham_y.sum()
|
||||||
if return_normals:
|
if return_normals:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fvcore.common.benchmark import benchmark
|
from fvcore.common.benchmark import benchmark
|
||||||
@ -20,8 +21,23 @@ def bm_chamfer() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
kwargs_list = kwargs_list_naive + [
|
kwargs_list = []
|
||||||
{"batch_size": 1, "P1": 1000, "P2": 3000, "return_normals": False},
|
batch_size = [1, 32]
|
||||||
{"batch_size": 1, "P1": 1000, "P2": 30000, "return_normals": True},
|
P1 = [32, 1000, 10000]
|
||||||
]
|
P2 = [64, 3000, 30000]
|
||||||
|
return_normals = [True, False]
|
||||||
|
homogeneous = [True, False]
|
||||||
|
test_cases = product(batch_size, P1, P2, return_normals, homogeneous)
|
||||||
|
|
||||||
|
for case in test_cases:
|
||||||
|
b, p1, p2, n, h = case
|
||||||
|
kwargs_list.append(
|
||||||
|
{
|
||||||
|
"batch_size": b,
|
||||||
|
"P1": p1,
|
||||||
|
"P2": p2,
|
||||||
|
"return_normals": n,
|
||||||
|
"homogeneous": h,
|
||||||
|
}
|
||||||
|
)
|
||||||
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
|
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
|
||||||
|
@ -1,111 +1,442 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.loss import chamfer_distance
|
from pytorch3d.loss import chamfer_distance
|
||||||
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
from pytorch3d.structures.utils import list_to_padded
|
||||||
|
|
||||||
|
|
||||||
|
# Output of init_pointclouds
|
||||||
|
points_normals = namedtuple(
|
||||||
|
"points_normals", "p1_lengths p2_lengths cloud1 cloud2 p1 p2 n1 n2 weights"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestChamfer(TestCaseMixin, unittest.TestCase):
|
class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||||
@staticmethod
|
def setUp(self) -> None:
|
||||||
def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
|
super().setUp()
|
||||||
"""
|
torch.manual_seed(1)
|
||||||
Randomly initialize two batches of point clouds of sizes
|
|
||||||
(N, P1, D) and (N, P2, D) and return random normal vectors for
|
|
||||||
each batch of size (N, P1, 3) and (N, P2, 3).
|
|
||||||
"""
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
p1 = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
|
|
||||||
p1_normals = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
|
|
||||||
p1_normals = p1_normals / p1_normals.norm(dim=2, p=2, keepdim=True)
|
|
||||||
p2 = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
|
|
||||||
p2_normals = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
|
|
||||||
p2_normals = p2_normals / p2_normals.norm(dim=2, p=2, keepdim=True)
|
|
||||||
weights = torch.rand((batch_size,), dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
return p1, p2, p1_normals, p2_normals, weights
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def chamfer_distance_naive(p1, p2, p1_normals=None, p2_normals=None):
|
def init_pointclouds(N, P1, P2, device, requires_grad: bool = True):
|
||||||
|
"""
|
||||||
|
Create 2 pointclouds object and associated padded points/normals tensors by
|
||||||
|
starting from lists. The clouds and tensors have the same data. The
|
||||||
|
leaf nodes for the clouds are a list of tensors. The padded tensor can be
|
||||||
|
used directly as a leaf node.
|
||||||
|
"""
|
||||||
|
p1_lengths = torch.randint(P1, size=(N,), dtype=torch.int64, device=device)
|
||||||
|
p2_lengths = torch.randint(P2, size=(N,), dtype=torch.int64, device=device)
|
||||||
|
weights = torch.rand((N,), dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
# list of points and normals tensors
|
||||||
|
p1_list = []
|
||||||
|
p2_list = []
|
||||||
|
n1_list = []
|
||||||
|
n2_list = []
|
||||||
|
for i in range(N):
|
||||||
|
l1 = p1_lengths[i]
|
||||||
|
l2 = p2_lengths[i]
|
||||||
|
p1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
|
||||||
|
p2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
|
||||||
|
n1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
|
||||||
|
n2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
|
||||||
|
|
||||||
|
n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list]
|
||||||
|
n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list]
|
||||||
|
|
||||||
|
# Clone the lists and initialize padded tensors.
|
||||||
|
p1 = list_to_padded([p.clone() for p in p1_list])
|
||||||
|
p2 = list_to_padded([p.clone() for p in p2_list])
|
||||||
|
n1 = list_to_padded([p.clone() for p in n1_list])
|
||||||
|
n2 = list_to_padded([p.clone() for p in n2_list])
|
||||||
|
|
||||||
|
# Set requires_grad for all tensors in the lists and
|
||||||
|
# padded tensors.
|
||||||
|
if requires_grad:
|
||||||
|
for p in p2_list + p1_list + n1_list + n2_list + [p1, p2, n1, n2]:
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
# Create pointclouds objects
|
||||||
|
cloud1 = Pointclouds(points=p1_list, normals=n1_list)
|
||||||
|
cloud2 = Pointclouds(points=p2_list, normals=n2_list)
|
||||||
|
|
||||||
|
# Return pointclouds objects and padded tensors
|
||||||
|
return points_normals(
|
||||||
|
p1_lengths=p1_lengths,
|
||||||
|
p2_lengths=p2_lengths,
|
||||||
|
cloud1=cloud1,
|
||||||
|
cloud2=cloud2,
|
||||||
|
p1=p1,
|
||||||
|
p2=p2,
|
||||||
|
n1=n1,
|
||||||
|
n2=n2,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chamfer_distance_naive_pointclouds(p1, p2):
|
||||||
"""
|
"""
|
||||||
Naive iterative implementation of nearest neighbor and chamfer distance.
|
Naive iterative implementation of nearest neighbor and chamfer distance.
|
||||||
|
x and y are assumed to be pointclouds objects with points and optionally normals.
|
||||||
|
This functions supports heterogeneous pointclouds in a batch.
|
||||||
Returns lists of the unreduced loss and loss_normals.
|
Returns lists of the unreduced loss and loss_normals.
|
||||||
"""
|
"""
|
||||||
N, P1, D = p1.shape
|
x = p1.points_padded()
|
||||||
P2 = p2.size(1)
|
y = p2.points_padded()
|
||||||
|
N, P1, D = x.shape
|
||||||
|
P2 = y.size(1)
|
||||||
|
x_lengths = p1.num_points_per_cloud()
|
||||||
|
y_lengths = p2.num_points_per_cloud()
|
||||||
|
x_normals = p1.normals_padded()
|
||||||
|
y_normals = p2.normals_padded()
|
||||||
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
return_normals = p1_normals is not None and p2_normals is not None
|
return_normals = x_normals is not None and y_normals is not None
|
||||||
|
|
||||||
|
# Initialize all distances to + inf
|
||||||
|
dist = torch.ones((N, P1, P2), dtype=torch.float32, device=device) * np.inf
|
||||||
|
|
||||||
|
x_mask = (
|
||||||
|
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
|
||||||
|
) # shape [N, P1]
|
||||||
|
y_mask = (
|
||||||
|
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
|
||||||
|
) # shape [N, P2]
|
||||||
|
|
||||||
|
is_x_heterogeneous = ~(x_lengths == P1).all()
|
||||||
|
is_y_heterogeneous = ~(y_lengths == P2).all()
|
||||||
|
|
||||||
|
# Only calculate the distances for the points which are not masked
|
||||||
|
for n in range(N):
|
||||||
|
for i1 in range(x_lengths[n]):
|
||||||
|
for i2 in range(y_lengths[n]):
|
||||||
|
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
|
||||||
|
|
||||||
|
x_dist = torch.min(dist, dim=2)[0] # (N, P1)
|
||||||
|
y_dist = torch.min(dist, dim=1)[0] # (N, P2)
|
||||||
|
|
||||||
|
if is_x_heterogeneous:
|
||||||
|
x_dist[x_mask] = 0.0
|
||||||
|
if is_y_heterogeneous:
|
||||||
|
y_dist[y_mask] = 0.0
|
||||||
|
|
||||||
|
loss = [x_dist, y_dist]
|
||||||
|
|
||||||
|
lnorm = [x.new_zeros(()), x.new_zeros(())]
|
||||||
|
|
||||||
|
if return_normals:
|
||||||
|
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
|
||||||
|
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
|
||||||
|
lnorm1 = 1 - torch.abs(
|
||||||
|
F.cosine_similarity(
|
||||||
|
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
|
||||||
|
)
|
||||||
|
)
|
||||||
|
lnorm2 = 1 - torch.abs(
|
||||||
|
F.cosine_similarity(
|
||||||
|
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_x_heterogeneous:
|
||||||
|
lnorm1[x_mask] = 0.0
|
||||||
|
if is_y_heterogeneous:
|
||||||
|
lnorm2[y_mask] = 0.0
|
||||||
|
|
||||||
|
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
|
||||||
|
|
||||||
|
return loss, lnorm
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
|
||||||
|
"""
|
||||||
|
Naive iterative implementation of nearest neighbor and chamfer distance.
|
||||||
|
Returns lists of the unreduced loss and loss_normals. This naive
|
||||||
|
version only supports homogeneous pointcouds in a batch.
|
||||||
|
"""
|
||||||
|
N, P1, D = x.shape
|
||||||
|
P2 = y.size(1)
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
return_normals = x_normals is not None and y_normals is not None
|
||||||
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
|
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for n in range(N):
|
for n in range(N):
|
||||||
for i1 in range(P1):
|
for i1 in range(P1):
|
||||||
for i2 in range(P2):
|
for i2 in range(P2):
|
||||||
dist[n, i1, i2] = torch.sum((p1[n, i1, :] - p2[n, i2, :]) ** 2)
|
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
|
||||||
|
|
||||||
loss = [
|
loss = [
|
||||||
torch.min(dist, dim=2)[0], # (N, P1)
|
torch.min(dist, dim=2)[0], # (N, P1)
|
||||||
torch.min(dist, dim=1)[0], # (N, P2)
|
torch.min(dist, dim=1)[0], # (N, P2)
|
||||||
]
|
]
|
||||||
|
lnorm = [x.new_zeros(()), x.new_zeros(())]
|
||||||
lnorm = [p1.new_zeros(()), p1.new_zeros(())]
|
|
||||||
|
|
||||||
if return_normals:
|
if return_normals:
|
||||||
p1_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
|
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
|
||||||
p2_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
|
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
|
||||||
lnorm1 = 1 - torch.abs(
|
lnorm1 = 1 - torch.abs(
|
||||||
F.cosine_similarity(
|
F.cosine_similarity(
|
||||||
p1_normals, p2_normals.gather(1, p1_index), dim=2, eps=1e-6
|
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
lnorm2 = 1 - torch.abs(
|
lnorm2 = 1 - torch.abs(
|
||||||
F.cosine_similarity(
|
F.cosine_similarity(
|
||||||
p2_normals, p1_normals.gather(1, p2_index), dim=2, eps=1e-6
|
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
|
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
|
||||||
|
|
||||||
return loss, lnorm
|
return loss, lnorm
|
||||||
|
|
||||||
def test_chamfer_default_no_normals(self):
|
def test_chamfer_point_batch_reduction_mean(self):
|
||||||
"""
|
"""
|
||||||
Compare chamfer loss with naive implementation using default
|
Compare output of vectorized chamfer loss with naive implementation
|
||||||
input values and no normals.
|
for the default settings (point_reduction = "mean" and batch_reduction = "mean")
|
||||||
|
and no normals.
|
||||||
|
This tests only uses homogeneous pointclouds.
|
||||||
"""
|
"""
|
||||||
N, P1, P2 = 7, 10, 18
|
N, max_P1, max_P2 = 7, 10, 18
|
||||||
p1, p2, _, _, weights = TestChamfer.init_pointclouds(N, P1, P2)
|
device = "cuda:0"
|
||||||
pred_loss, _ = TestChamfer.chamfer_distance_naive(p1, p2)
|
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||||
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
|
p1 = points_normals.p1
|
||||||
|
p2 = points_normals.p2
|
||||||
|
weights = points_normals.weights
|
||||||
|
p11 = p1.detach().clone()
|
||||||
|
p22 = p2.detach().clone()
|
||||||
|
p11.requires_grad = True
|
||||||
|
p22.requires_grad = True
|
||||||
|
P1 = p1.shape[1]
|
||||||
|
P2 = p2.shape[1]
|
||||||
|
|
||||||
|
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
|
||||||
|
|
||||||
|
# point_reduction = "mean".
|
||||||
|
loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
|
||||||
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||||
pred_loss *= weights
|
pred_loss *= weights
|
||||||
pred_loss = pred_loss.sum() / weights.sum()
|
pred_loss = pred_loss.sum() / weights.sum()
|
||||||
|
|
||||||
self.assertClose(loss, pred_loss)
|
self.assertClose(loss, pred_loss)
|
||||||
self.assertTrue(loss_norm is None)
|
self.assertTrue(loss_norm is None)
|
||||||
|
|
||||||
def test_chamfer_point_reduction(self):
|
# Check gradients
|
||||||
|
self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
|
||||||
|
|
||||||
|
def test_chamfer_vs_naive_pointcloud(self):
|
||||||
"""
|
"""
|
||||||
Compare output of vectorized chamfer loss with naive implementation
|
Test the default settings for chamfer_distance
|
||||||
for point_reduction in ["mean", "sum", "none"] and
|
(point reduction = "mean" and batch_reduction="mean") but with heterogeneous
|
||||||
batch_reduction = "none".
|
pointclouds as input. Compare with the naive implementation of chamfer
|
||||||
|
which supports heterogeneous pointcloud objects.
|
||||||
"""
|
"""
|
||||||
N, P1, P2 = 7, 10, 18
|
N, max_P1, max_P2 = 3, 70, 70
|
||||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
device = "cuda:0"
|
||||||
N, P1, P2
|
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||||
|
weights = points_normals.weights
|
||||||
|
x_lengths = points_normals.p1_lengths
|
||||||
|
y_lengths = points_normals.p2_lengths
|
||||||
|
|
||||||
|
# Chamfer with tensors as input for heterogeneous pointclouds.
|
||||||
|
cham_tensor, norm_tensor = chamfer_distance(
|
||||||
|
points_normals.p1,
|
||||||
|
points_normals.p2,
|
||||||
|
x_normals=points_normals.n1,
|
||||||
|
y_normals=points_normals.n2,
|
||||||
|
x_lengths=points_normals.p1_lengths,
|
||||||
|
y_lengths=points_normals.p2_lengths,
|
||||||
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Chamfer with pointclouds as input.
|
||||||
|
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
|
||||||
|
points_normals.cloud1, points_normals.cloud2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mean reduction point loss.
|
||||||
|
pred_loss[0] *= weights.view(N, 1)
|
||||||
|
pred_loss[1] *= weights.view(N, 1)
|
||||||
|
pred_loss_mean = (
|
||||||
|
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
|
||||||
|
)
|
||||||
|
pred_loss_mean = pred_loss_mean.sum()
|
||||||
|
pred_loss_mean /= weights.sum()
|
||||||
|
|
||||||
|
# Mean reduction norm loss.
|
||||||
|
pred_norm_loss[0] *= weights.view(N, 1)
|
||||||
|
pred_norm_loss[1] *= weights.view(N, 1)
|
||||||
|
pred_norm_loss_mean = (
|
||||||
|
pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
|
||||||
|
)
|
||||||
|
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
|
||||||
|
|
||||||
|
self.assertClose(pred_loss_mean, cham_tensor)
|
||||||
|
self.assertClose(pred_norm_loss_mean, norm_tensor)
|
||||||
|
|
||||||
|
self._check_gradients(
|
||||||
|
cham_tensor,
|
||||||
|
norm_tensor,
|
||||||
|
pred_loss_mean,
|
||||||
|
pred_norm_loss_mean,
|
||||||
|
points_normals.cloud1.points_list(),
|
||||||
|
points_normals.p1,
|
||||||
|
points_normals.cloud2.points_list(),
|
||||||
|
points_normals.p2,
|
||||||
|
points_normals.cloud1.normals_list(),
|
||||||
|
points_normals.n1,
|
||||||
|
points_normals.cloud2.normals_list(),
|
||||||
|
points_normals.n2,
|
||||||
|
x_lengths,
|
||||||
|
y_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_chamfer_pointcloud_object_withnormals(self):
|
||||||
|
N = 5
|
||||||
|
P1, P2 = 100, 100
|
||||||
|
device = "cuda:0"
|
||||||
|
|
||||||
|
reductions = [
|
||||||
|
("sum", "sum"),
|
||||||
|
("mean", "sum"),
|
||||||
|
("sum", "mean"),
|
||||||
|
("mean", "mean"),
|
||||||
|
("sum", None),
|
||||||
|
("mean", None),
|
||||||
|
]
|
||||||
|
for (point_reduction, batch_reduction) in reductions:
|
||||||
|
|
||||||
|
# Reinitialize all the tensors so that the
|
||||||
|
# backward pass can be computed.
|
||||||
|
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||||
|
|
||||||
|
# Chamfer with pointclouds as input.
|
||||||
|
cham_cloud, norm_cloud = chamfer_distance(
|
||||||
|
points_normals.cloud1,
|
||||||
|
points_normals.cloud2,
|
||||||
|
point_reduction=point_reduction,
|
||||||
|
batch_reduction=batch_reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Chamfer with tensors as input.
|
||||||
|
cham_tensor, norm_tensor = chamfer_distance(
|
||||||
|
points_normals.p1,
|
||||||
|
points_normals.p2,
|
||||||
|
x_lengths=points_normals.p1_lengths,
|
||||||
|
y_lengths=points_normals.p2_lengths,
|
||||||
|
x_normals=points_normals.n1,
|
||||||
|
y_normals=points_normals.n2,
|
||||||
|
point_reduction=point_reduction,
|
||||||
|
batch_reduction=batch_reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertClose(cham_cloud, cham_tensor)
|
||||||
|
self.assertClose(norm_cloud, norm_tensor)
|
||||||
|
self._check_gradients(
|
||||||
|
cham_tensor,
|
||||||
|
norm_tensor,
|
||||||
|
cham_cloud,
|
||||||
|
norm_cloud,
|
||||||
|
points_normals.cloud1.points_list(),
|
||||||
|
points_normals.p1,
|
||||||
|
points_normals.cloud2.points_list(),
|
||||||
|
points_normals.p2,
|
||||||
|
points_normals.cloud1.normals_list(),
|
||||||
|
points_normals.n1,
|
||||||
|
points_normals.cloud2.normals_list(),
|
||||||
|
points_normals.n2,
|
||||||
|
points_normals.p1_lengths,
|
||||||
|
points_normals.p2_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_chamfer_pointcloud_object_nonormals(self):
|
||||||
|
N = 5
|
||||||
|
P1, P2 = 100, 100
|
||||||
|
device = "cuda:0"
|
||||||
|
|
||||||
|
reductions = [
|
||||||
|
("sum", "sum"),
|
||||||
|
("mean", "sum"),
|
||||||
|
("sum", "mean"),
|
||||||
|
("mean", "mean"),
|
||||||
|
("sum", None),
|
||||||
|
("mean", None),
|
||||||
|
]
|
||||||
|
for (point_reduction, batch_reduction) in reductions:
|
||||||
|
|
||||||
|
# Reinitialize all the tensors so that the
|
||||||
|
# backward pass can be computed.
|
||||||
|
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||||
|
|
||||||
|
# Chamfer with pointclouds as input.
|
||||||
|
cham_cloud, _ = chamfer_distance(
|
||||||
|
points_normals.cloud1,
|
||||||
|
points_normals.cloud2,
|
||||||
|
point_reduction=point_reduction,
|
||||||
|
batch_reduction=batch_reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Chamfer with tensors as input.
|
||||||
|
cham_tensor, _ = chamfer_distance(
|
||||||
|
points_normals.p1,
|
||||||
|
points_normals.p2,
|
||||||
|
x_lengths=points_normals.p1_lengths,
|
||||||
|
y_lengths=points_normals.p2_lengths,
|
||||||
|
point_reduction=point_reduction,
|
||||||
|
batch_reduction=batch_reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertClose(cham_cloud, cham_tensor)
|
||||||
|
self._check_gradients(
|
||||||
|
cham_tensor,
|
||||||
|
None,
|
||||||
|
cham_cloud,
|
||||||
|
None,
|
||||||
|
points_normals.cloud1.points_list(),
|
||||||
|
points_normals.p1,
|
||||||
|
points_normals.cloud2.points_list(),
|
||||||
|
points_normals.p2,
|
||||||
|
lengths1=points_normals.p1_lengths,
|
||||||
|
lengths2=points_normals.p2_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_chamfer_point_reduction_mean(self):
|
||||||
|
"""
|
||||||
|
Compare output of vectorized chamfer loss with naive implementation
|
||||||
|
for point_reduction = "mean" and batch_reduction = None.
|
||||||
|
"""
|
||||||
|
N, max_P1, max_P2 = 7, 10, 18
|
||||||
|
device = "cuda:0"
|
||||||
|
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||||
|
p1 = points_normals.p1
|
||||||
|
p2 = points_normals.p2
|
||||||
|
p1_normals = points_normals.n1
|
||||||
|
p2_normals = points_normals.n2
|
||||||
|
weights = points_normals.weights
|
||||||
|
p11 = p1.detach().clone()
|
||||||
|
p22 = p2.detach().clone()
|
||||||
|
p11.requires_grad = True
|
||||||
|
p22.requires_grad = True
|
||||||
|
P1 = p1.shape[1]
|
||||||
|
P2 = p2.shape[1]
|
||||||
|
|
||||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||||
p1, p2, p1_normals, p2_normals
|
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||||
)
|
)
|
||||||
|
|
||||||
# point_reduction = "mean".
|
# point_reduction = "mean".
|
||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1,
|
p11,
|
||||||
p2,
|
p22,
|
||||||
p1_normals,
|
x_normals=p1_normals,
|
||||||
p2_normals,
|
y_normals=p2_normals,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
batch_reduction="none",
|
batch_reduction=None,
|
||||||
point_reduction="mean",
|
point_reduction="mean",
|
||||||
)
|
)
|
||||||
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||||
@ -118,14 +449,40 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
pred_loss_norm_mean *= weights
|
pred_loss_norm_mean *= weights
|
||||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||||
|
|
||||||
# point_reduction = "sum".
|
# Check gradients
|
||||||
|
self._check_gradients(
|
||||||
|
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_chamfer_point_reduction_sum(self):
|
||||||
|
"""
|
||||||
|
Compare output of vectorized chamfer loss with naive implementation
|
||||||
|
for point_reduction = "sum" and batch_reduction = None.
|
||||||
|
"""
|
||||||
|
N, P1, P2 = 7, 10, 18
|
||||||
|
device = "cuda:0"
|
||||||
|
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||||
|
p1 = points_normals.p1
|
||||||
|
p2 = points_normals.p2
|
||||||
|
p1_normals = points_normals.n1
|
||||||
|
p2_normals = points_normals.n2
|
||||||
|
weights = points_normals.weights
|
||||||
|
p11 = p1.detach().clone()
|
||||||
|
p22 = p2.detach().clone()
|
||||||
|
p11.requires_grad = True
|
||||||
|
p22.requires_grad = True
|
||||||
|
|
||||||
|
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||||
|
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||||
|
)
|
||||||
|
|
||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1,
|
p11,
|
||||||
p2,
|
p22,
|
||||||
p1_normals,
|
x_normals=p1_normals,
|
||||||
p2_normals,
|
y_normals=p2_normals,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
batch_reduction="none",
|
batch_reduction=None,
|
||||||
point_reduction="sum",
|
point_reduction="sum",
|
||||||
)
|
)
|
||||||
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
|
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
|
||||||
@ -136,92 +493,110 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
pred_loss_norm_sum *= weights
|
pred_loss_norm_sum *= weights
|
||||||
self.assertClose(loss_norm, pred_loss_norm_sum)
|
self.assertClose(loss_norm, pred_loss_norm_sum)
|
||||||
|
|
||||||
# Error when point_reduction = "none" and batch_reduction = "none".
|
# Check gradients
|
||||||
with self.assertRaises(ValueError):
|
self._check_gradients(
|
||||||
chamfer_distance(
|
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
|
||||||
p1, p2, weights=weights, batch_reduction="none", point_reduction="none"
|
)
|
||||||
|
|
||||||
|
def _check_gradients(
|
||||||
|
self,
|
||||||
|
loss,
|
||||||
|
loss_norm,
|
||||||
|
pred_loss,
|
||||||
|
pred_loss_norm,
|
||||||
|
x1,
|
||||||
|
x2,
|
||||||
|
y1,
|
||||||
|
y2,
|
||||||
|
xn1=None, # normals
|
||||||
|
xn2=None, # normals
|
||||||
|
yn1=None, # normals
|
||||||
|
yn2=None, # normals
|
||||||
|
lengths1=None,
|
||||||
|
lengths2=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
x1 and x2 can have different types based on the leaf node used in the calculation:
|
||||||
|
e.g. x1 may be a list of tensors whereas x2 is a padded tensor.
|
||||||
|
This also applies for the pairs: (y1, y2), (xn1, xn2), (yn1, yn2).
|
||||||
|
"""
|
||||||
|
grad_loss = torch.rand(loss.shape, device=loss.device, dtype=loss.dtype)
|
||||||
|
|
||||||
|
# Loss for normals is optional. Iniitalize to 0.
|
||||||
|
norm_loss_term = pred_norm_loss_term = 0.0
|
||||||
|
if loss_norm is not None and pred_loss_norm is not None:
|
||||||
|
grad_normals = torch.rand(
|
||||||
|
loss_norm.shape, device=loss.device, dtype=loss.dtype
|
||||||
)
|
)
|
||||||
|
norm_loss_term = loss_norm * grad_normals
|
||||||
|
pred_norm_loss_term = pred_loss_norm * grad_normals
|
||||||
|
|
||||||
# Error when batch_reduction is not in ["none", "mean", "sum"].
|
l1 = (loss * grad_loss) + norm_loss_term
|
||||||
with self.assertRaises(ValueError):
|
l1.sum().backward()
|
||||||
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
|
l2 = (pred_loss * grad_loss) + pred_norm_loss_term
|
||||||
|
l2.sum().backward()
|
||||||
|
|
||||||
def test_chamfer_batch_reduction(self):
|
self._check_grad_by_type(x1, x2, lengths1)
|
||||||
|
self._check_grad_by_type(y1, y2, lengths2)
|
||||||
|
|
||||||
|
# If leaf nodes for normals are passed in, check their gradients.
|
||||||
|
if all(n is not None for n in [xn1, xn2, yn1, yn2]):
|
||||||
|
self._check_grad_by_type(xn1, xn2, lengths1)
|
||||||
|
self._check_grad_by_type(yn1, yn2, lengths2)
|
||||||
|
|
||||||
|
def _check_grad_by_type(self, x1, x2, lengths=None):
|
||||||
"""
|
"""
|
||||||
Compare output of vectorized chamfer loss with naive implementation
|
x1 and x2 can be of different types e.g. list or tensor - compare appropriately
|
||||||
for batch_reduction in ["mean", "sum"] and point_reduction = "none".
|
based on the types.
|
||||||
"""
|
"""
|
||||||
N, P1, P2 = 7, 10, 18
|
error_msg = "All values for gradient checks must be tensors or lists of tensors"
|
||||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
|
||||||
N, P1, P2
|
|
||||||
)
|
|
||||||
|
|
||||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
if all(isinstance(p, list) for p in [x1, x2]):
|
||||||
p1, p2, p1_normals, p2_normals
|
# Lists of tensors
|
||||||
)
|
for i in range(len(x1)):
|
||||||
|
self.assertClose(x1[i].grad, x2[i].grad)
|
||||||
|
elif isinstance(x1, list) and torch.is_tensor(x2):
|
||||||
|
self.assertIsNotNone(lengths) # lengths is required
|
||||||
|
|
||||||
# batch_reduction = "sum".
|
# List of tensors vs padded tensor
|
||||||
loss, loss_norm = chamfer_distance(
|
for i in range(len(x1)):
|
||||||
p1,
|
self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]])
|
||||||
p2,
|
self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0)
|
||||||
p1_normals,
|
elif all(torch.is_tensor(p) for p in [x1, x2]):
|
||||||
p2_normals,
|
# Two tensors
|
||||||
weights=weights,
|
self.assertClose(x1.grad, x2.grad)
|
||||||
batch_reduction="sum",
|
else:
|
||||||
point_reduction="none",
|
raise ValueError(error_msg)
|
||||||
)
|
|
||||||
pred_loss[0] *= weights.view(N, 1)
|
|
||||||
pred_loss[1] *= weights.view(N, 1)
|
|
||||||
pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
|
|
||||||
self.assertClose(loss, pred_loss)
|
|
||||||
|
|
||||||
pred_loss_norm[0] *= weights.view(N, 1)
|
|
||||||
pred_loss_norm[1] *= weights.view(N, 1)
|
|
||||||
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
|
|
||||||
self.assertClose(loss_norm, pred_loss_norm)
|
|
||||||
|
|
||||||
# batch_reduction = "mean".
|
|
||||||
loss, loss_norm = chamfer_distance(
|
|
||||||
p1,
|
|
||||||
p2,
|
|
||||||
p1_normals,
|
|
||||||
p2_normals,
|
|
||||||
weights=weights,
|
|
||||||
batch_reduction="mean",
|
|
||||||
point_reduction="none",
|
|
||||||
)
|
|
||||||
|
|
||||||
pred_loss /= weights.sum()
|
|
||||||
self.assertClose(loss, pred_loss)
|
|
||||||
|
|
||||||
pred_loss_norm /= weights.sum()
|
|
||||||
self.assertClose(loss_norm, pred_loss_norm)
|
|
||||||
|
|
||||||
# Error when point_reduction is not in ["none", "mean", "sum"].
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
chamfer_distance(p1, p2, weights=weights, point_reduction="max")
|
|
||||||
|
|
||||||
def test_chamfer_joint_reduction(self):
|
def test_chamfer_joint_reduction(self):
|
||||||
"""
|
"""
|
||||||
Compare output of vectorized chamfer loss with naive implementation
|
Compare output of vectorized chamfer loss with naive implementation
|
||||||
for batch_reduction in ["mean", "sum"] and
|
when batch_reduction in ["mean", "sum"] and
|
||||||
point_reduction in ["mean", "sum"].
|
point_reduction in ["mean", "sum"].
|
||||||
"""
|
"""
|
||||||
N, P1, P2 = 7, 10, 18
|
N, max_P1, max_P2 = 7, 10, 18
|
||||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
device = "cuda:0"
|
||||||
N, P1, P2
|
|
||||||
)
|
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||||
|
p1 = points_normals.p1
|
||||||
|
p2 = points_normals.p2
|
||||||
|
p1_normals = points_normals.n1
|
||||||
|
p2_normals = points_normals.n2
|
||||||
|
weights = points_normals.weights
|
||||||
|
|
||||||
|
P1 = p1.shape[1]
|
||||||
|
P2 = p2.shape[1]
|
||||||
|
|
||||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||||
p1, p2, p1_normals, p2_normals
|
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||||
)
|
)
|
||||||
|
|
||||||
# batch_reduction = "sum", point_reduction = "sum".
|
# batch_reduction = "sum", point_reduction = "sum".
|
||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1,
|
p1,
|
||||||
p2,
|
p2,
|
||||||
p1_normals,
|
x_normals=p1_normals,
|
||||||
p2_normals,
|
y_normals=p2_normals,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
batch_reduction="sum",
|
batch_reduction="sum",
|
||||||
point_reduction="sum",
|
point_reduction="sum",
|
||||||
@ -244,8 +619,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1,
|
p1,
|
||||||
p2,
|
p2,
|
||||||
p1_normals,
|
x_normals=p1_normals,
|
||||||
p2_normals,
|
y_normals=p2_normals,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
batch_reduction="mean",
|
batch_reduction="mean",
|
||||||
point_reduction="sum",
|
point_reduction="sum",
|
||||||
@ -260,8 +635,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1,
|
p1,
|
||||||
p2,
|
p2,
|
||||||
p1_normals,
|
x_normals=p1_normals,
|
||||||
p2_normals,
|
y_normals=p2_normals,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
batch_reduction="sum",
|
batch_reduction="sum",
|
||||||
point_reduction="mean",
|
point_reduction="mean",
|
||||||
@ -280,8 +655,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1,
|
p1,
|
||||||
p2,
|
p2,
|
||||||
p1_normals,
|
x_normals=p1_normals,
|
||||||
p2_normals,
|
y_normals=p2_normals,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
batch_reduction="mean",
|
batch_reduction="mean",
|
||||||
point_reduction="mean",
|
point_reduction="mean",
|
||||||
@ -292,6 +667,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
pred_loss_norm_mean /= weights.sum()
|
pred_loss_norm_mean /= weights.sum()
|
||||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||||
|
|
||||||
|
# Error when batch_reduction is not in ["mean", "sum"] or None.
|
||||||
|
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
|
||||||
|
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
|
||||||
|
|
||||||
|
# Error when point_reduction is not in ["mean", "sum"].
|
||||||
|
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
|
||||||
|
chamfer_distance(p1, p2, weights=weights, point_reduction=None)
|
||||||
|
|
||||||
def test_incorrect_weights(self):
|
def test_incorrect_weights(self):
|
||||||
N, P1, P2 = 16, 64, 128
|
N, P1, P2 = 16, 64, 128
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@ -312,7 +695,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertTrue(loss_norm.requires_grad)
|
self.assertTrue(loss_norm.requires_grad)
|
||||||
|
|
||||||
loss, loss_norm = chamfer_distance(
|
loss, loss_norm = chamfer_distance(
|
||||||
p1, p2, weights=weights, batch_reduction="none"
|
p1, p2, weights=weights, batch_reduction=None
|
||||||
)
|
)
|
||||||
self.assertClose(loss.cpu(), torch.zeros((N, N)))
|
self.assertClose(loss.cpu(), torch.zeros((N, N)))
|
||||||
self.assertTrue(loss.requires_grad)
|
self.assertTrue(loss.requires_grad)
|
||||||
@ -327,16 +710,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
|
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
|
||||||
|
|
||||||
|
def test_incorrect_inputs(self):
|
||||||
|
N, P1, P2 = 7, 10, 18
|
||||||
|
device = "cuda:0"
|
||||||
|
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||||
|
p1 = points_normals.p1
|
||||||
|
p2 = points_normals.p2
|
||||||
|
p1_normals = points_normals.n1
|
||||||
|
|
||||||
|
# Normals of wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "Expected normals to be of shape"):
|
||||||
|
chamfer_distance(p1, p2, x_normals=p1_normals[None])
|
||||||
|
|
||||||
|
# Points of wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "Expected points to be of shape"):
|
||||||
|
chamfer_distance(p1[None], p2)
|
||||||
|
|
||||||
|
# Lengths of wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "Expected lengths to be of shape"):
|
||||||
|
chamfer_distance(p1, p2, x_lengths=torch.tensor([1, 2, 3], device=device))
|
||||||
|
|
||||||
|
# Points are not a tensor or Pointclouds
|
||||||
|
with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
|
||||||
|
chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def chamfer_with_init(batch_size: int, P1: int, P2: int, return_normals: bool):
|
def chamfer_with_init(
|
||||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
|
||||||
|
):
|
||||||
|
p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
|
||||||
batch_size, P1, P2
|
batch_size, P1, P2
|
||||||
)
|
)
|
||||||
|
if homogeneous:
|
||||||
|
# Set lengths to None so in Chamfer it assumes
|
||||||
|
# there is no padding.
|
||||||
|
l1 = l2 = None
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def loss():
|
def loss():
|
||||||
loss, loss_normals = chamfer_distance(
|
loss, loss_normals = chamfer_distance(
|
||||||
p1, p2, p1_normals, p2_normals, weights=weights
|
p1,
|
||||||
|
p2,
|
||||||
|
x_lengths=l1,
|
||||||
|
y_lengths=l2,
|
||||||
|
x_normals=p1_normals,
|
||||||
|
y_normals=p2_normals,
|
||||||
|
weights=weights,
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
@ -346,14 +766,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
|||||||
def chamfer_naive_with_init(
|
def chamfer_naive_with_init(
|
||||||
batch_size: int, P1: int, P2: int, return_normals: bool
|
batch_size: int, P1: int, P2: int, return_normals: bool
|
||||||
):
|
):
|
||||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
|
||||||
batch_size, P1, P2
|
batch_size, P1, P2
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def loss():
|
def loss():
|
||||||
loss, loss_normals = TestChamfer.chamfer_distance_naive(
|
loss, loss_normals = TestChamfer.chamfer_distance_naive(
|
||||||
p1, p2, p1_normals, p2_normals
|
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user