mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Chamfer for Pointclouds object
Summary: Allow Pointclouds objects and heterogenous data to be provided for Chamfer loss. Remove "none" as an option for point_reduction because it doesn't make sense and in the current implementation is effectively the same as "sum". Possible improvement: create specialised operations for sum and cosine_similarity of padded tensors, to avoid having to create masks. sum would be useful elsewhere. Reviewed By: gkioxari Differential Revision: D20816301 fbshipit-source-id: 0f32073210225d157c029d80de450eecdb64f4d2
This commit is contained in:
parent
677b0bd5ae
commit
790eb8c402
@ -1,54 +1,101 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
|
||||
from pytorch3d.ops.knn import knn_gather, knn_points
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
|
||||
def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
|
||||
def _validate_chamfer_reduction_inputs(
|
||||
batch_reduction: Union[str, None], point_reduction: str
|
||||
):
|
||||
"""Check the requested reductions are valid.
|
||||
|
||||
Args:
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["none", "mean", "sum"].
|
||||
batch, can be one of ["mean", "sum"] or None.
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["none", "mean", "sum"].
|
||||
points, can be one of ["mean", "sum"].
|
||||
"""
|
||||
if batch_reduction not in ["none", "mean", "sum"]:
|
||||
raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
|
||||
if point_reduction not in ["none", "mean", "sum"]:
|
||||
raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
|
||||
if batch_reduction == "none" and point_reduction == "none":
|
||||
raise ValueError('batch_reduction and point_reduction cannot both be "none".')
|
||||
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
|
||||
if point_reduction not in ["mean", "sum"]:
|
||||
raise ValueError('point_reduction must be one of ["mean", "sum"]')
|
||||
|
||||
|
||||
def _handle_pointcloud_input(
|
||||
points: Union[torch.Tensor, Pointclouds],
|
||||
lengths: Union[torch.Tensor, None],
|
||||
normals: Union[torch.Tensor, None],
|
||||
):
|
||||
"""
|
||||
If points is an instance of Pointclouds, retrieve the padded points tensor
|
||||
along with the number of points per batch and the padded normals.
|
||||
Otherwise, return the input points (and normals) with the number of points per cloud
|
||||
set to the size of the second dimension of `points`.
|
||||
"""
|
||||
if isinstance(points, Pointclouds):
|
||||
X = points.points_padded()
|
||||
lengths = points.num_points_per_cloud()
|
||||
normals = points.normals_padded() # either a tensor or None
|
||||
elif torch.is_tensor(points):
|
||||
if points.ndim != 3:
|
||||
raise ValueError("Expected points to be of shape (N, P, D)")
|
||||
X = points
|
||||
if lengths is not None and (
|
||||
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
|
||||
):
|
||||
raise ValueError("Expected lengths to be of shape (N,)")
|
||||
if lengths is None:
|
||||
lengths = torch.full(
|
||||
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
|
||||
)
|
||||
if normals is not None and normals.ndim != 3:
|
||||
raise ValueError("Expected normals to be of shape (N, P, 3")
|
||||
else:
|
||||
raise ValueError(
|
||||
"The input pointclouds should be either "
|
||||
+ "Pointclouds objects or torch.Tensor of shape "
|
||||
+ "(minibatch, num_points, 3)."
|
||||
)
|
||||
return X, lengths, normals
|
||||
|
||||
|
||||
def chamfer_distance(
|
||||
x,
|
||||
y,
|
||||
x_lengths=None,
|
||||
y_lengths=None,
|
||||
x_normals=None,
|
||||
y_normals=None,
|
||||
weights=None,
|
||||
batch_reduction: str = "mean",
|
||||
batch_reduction: Union[str, None] = "mean",
|
||||
point_reduction: str = "mean",
|
||||
):
|
||||
"""
|
||||
Chamfer distance between two pointclouds x and y.
|
||||
|
||||
Args:
|
||||
x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
|
||||
with P1 points in each batch element, batch size N and feature
|
||||
dimension D.
|
||||
y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
|
||||
with P2 points in each batch element, batch size N and feature
|
||||
dimension D.
|
||||
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
|
||||
a batch of point clouds with at most P1 points in each batch element,
|
||||
batch size N and feature dimension D.
|
||||
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
|
||||
a batch of point clouds with at most P2 points in each batch element,
|
||||
batch size N and feature dimension D.
|
||||
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
|
||||
cloud in x.
|
||||
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
|
||||
cloud in x.
|
||||
x_normals: Optional FloatTensor of shape (N, P1, D).
|
||||
y_normals: Optional FloatTensor of shape (N, P2, D).
|
||||
weights: Optional FloatTensor of shape (N,) giving weights for
|
||||
batch elements for reduction operation.
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["none", "mean", "sum"].
|
||||
batch, can be one of ["mean", "sum"] or None.
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["none", "mean", "sum"].
|
||||
points, can be one of ["mean", "sum"].
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
@ -61,16 +108,31 @@ def chamfer_distance(
|
||||
"""
|
||||
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
|
||||
|
||||
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
|
||||
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
|
||||
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
|
||||
N, P1, D = x.shape
|
||||
P2 = y.shape[1]
|
||||
|
||||
# Check if inputs are heterogeneous and create a lengths mask.
|
||||
is_x_heterogeneous = ~(x_lengths == P1).all()
|
||||
is_y_heterogeneous = ~(y_lengths == P2).all()
|
||||
x_mask = (
|
||||
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
|
||||
) # shape [N, P1]
|
||||
y_mask = (
|
||||
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
|
||||
) # shape [N, P2]
|
||||
|
||||
if y.shape[0] != N or y.shape[2] != D:
|
||||
raise ValueError("y does not have the correct shape.")
|
||||
if weights is not None:
|
||||
if weights.size(0) != N:
|
||||
raise ValueError("weights must be of shape (N,).")
|
||||
if not (weights >= 0).all():
|
||||
raise ValueError("weights can not be nonnegative.")
|
||||
raise ValueError("weights cannot be negative.")
|
||||
if weights.sum() == 0.0:
|
||||
weights = weights.view(N, 1)
|
||||
if batch_reduction in ["mean", "sum"]:
|
||||
@ -80,46 +142,60 @@ def chamfer_distance(
|
||||
)
|
||||
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
|
||||
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
cham_norm_x = x.new_zeros(())
|
||||
cham_norm_y = x.new_zeros(())
|
||||
|
||||
x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
|
||||
y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
|
||||
x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
|
||||
y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
|
||||
|
||||
cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0 # (N, P1)
|
||||
cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0 # (N, P2)
|
||||
cham_x = x_dists[..., 0] # (N, P1)
|
||||
cham_y = y_dists[..., 0] # (N, P2)
|
||||
|
||||
if is_x_heterogeneous:
|
||||
cham_x[x_mask] = 0.0
|
||||
if is_y_heterogeneous:
|
||||
cham_y[y_mask] = 0.0
|
||||
|
||||
if weights is not None:
|
||||
cham_x *= weights.view(N, 1)
|
||||
cham_y *= weights.view(N, 1)
|
||||
|
||||
if return_normals:
|
||||
# Gather the normals using the indices and keep only value for k=0
|
||||
x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :]
|
||||
y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :]
|
||||
|
||||
cham_norm_x = 1 - torch.abs(
|
||||
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
|
||||
)
|
||||
cham_norm_y = 1 - torch.abs(
|
||||
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
|
||||
)
|
||||
|
||||
if is_x_heterogeneous:
|
||||
cham_norm_x[x_mask] = 0.0
|
||||
if is_y_heterogeneous:
|
||||
cham_norm_y[y_mask] = 0.0
|
||||
|
||||
if weights is not None:
|
||||
cham_norm_x *= weights.view(N, 1)
|
||||
cham_norm_y *= weights.view(N, 1)
|
||||
|
||||
if point_reduction != "none":
|
||||
# If not 'none' then either 'sum' or 'mean'.
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
cham_y = cham_y.sum(1) # (N,)
|
||||
# Apply point reduction
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
cham_y = cham_y.sum(1) # (N,)
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
cham_x /= x_lengths
|
||||
cham_y /= y_lengths
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
cham_x /= P1
|
||||
cham_y /= P2
|
||||
if return_normals:
|
||||
cham_norm_x /= P1
|
||||
cham_norm_y /= P2
|
||||
cham_norm_x /= x_lengths
|
||||
cham_norm_y /= y_lengths
|
||||
|
||||
if batch_reduction != "none":
|
||||
if batch_reduction is not None:
|
||||
# batch_reduction == "sum"
|
||||
cham_x = cham_x.sum()
|
||||
cham_y = cham_y.sum()
|
||||
if return_normals:
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
@ -20,8 +21,23 @@ def bm_chamfer() -> None:
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
kwargs_list = kwargs_list_naive + [
|
||||
{"batch_size": 1, "P1": 1000, "P2": 3000, "return_normals": False},
|
||||
{"batch_size": 1, "P1": 1000, "P2": 30000, "return_normals": True},
|
||||
]
|
||||
kwargs_list = []
|
||||
batch_size = [1, 32]
|
||||
P1 = [32, 1000, 10000]
|
||||
P2 = [64, 3000, 30000]
|
||||
return_normals = [True, False]
|
||||
homogeneous = [True, False]
|
||||
test_cases = product(batch_size, P1, P2, return_normals, homogeneous)
|
||||
|
||||
for case in test_cases:
|
||||
b, p1, p2, n, h = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"batch_size": b,
|
||||
"P1": p1,
|
||||
"P2": p2,
|
||||
"return_normals": n,
|
||||
"homogeneous": h,
|
||||
}
|
||||
)
|
||||
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
|
||||
|
@ -1,111 +1,442 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.loss import chamfer_distance
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.structures.utils import list_to_padded
|
||||
|
||||
|
||||
# Output of init_pointclouds
|
||||
points_normals = namedtuple(
|
||||
"points_normals", "p1_lengths p2_lengths cloud1 cloud2 p1 p2 n1 n2 weights"
|
||||
)
|
||||
|
||||
|
||||
class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
@staticmethod
|
||||
def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
|
||||
"""
|
||||
Randomly initialize two batches of point clouds of sizes
|
||||
(N, P1, D) and (N, P2, D) and return random normal vectors for
|
||||
each batch of size (N, P1, 3) and (N, P2, 3).
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
p1 = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
|
||||
p1_normals = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
|
||||
p1_normals = p1_normals / p1_normals.norm(dim=2, p=2, keepdim=True)
|
||||
p2 = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
|
||||
p2_normals = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
|
||||
p2_normals = p2_normals / p2_normals.norm(dim=2, p=2, keepdim=True)
|
||||
weights = torch.rand((batch_size,), dtype=torch.float32, device=device)
|
||||
|
||||
return p1, p2, p1_normals, p2_normals, weights
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
@staticmethod
|
||||
def chamfer_distance_naive(p1, p2, p1_normals=None, p2_normals=None):
|
||||
def init_pointclouds(N, P1, P2, device, requires_grad: bool = True):
|
||||
"""
|
||||
Create 2 pointclouds object and associated padded points/normals tensors by
|
||||
starting from lists. The clouds and tensors have the same data. The
|
||||
leaf nodes for the clouds are a list of tensors. The padded tensor can be
|
||||
used directly as a leaf node.
|
||||
"""
|
||||
p1_lengths = torch.randint(P1, size=(N,), dtype=torch.int64, device=device)
|
||||
p2_lengths = torch.randint(P2, size=(N,), dtype=torch.int64, device=device)
|
||||
weights = torch.rand((N,), dtype=torch.float32, device=device)
|
||||
|
||||
# list of points and normals tensors
|
||||
p1_list = []
|
||||
p2_list = []
|
||||
n1_list = []
|
||||
n2_list = []
|
||||
for i in range(N):
|
||||
l1 = p1_lengths[i]
|
||||
l2 = p2_lengths[i]
|
||||
p1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
|
||||
p2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
|
||||
n1_list.append(torch.rand((l1, 3), dtype=torch.float32, device=device))
|
||||
n2_list.append(torch.rand((l2, 3), dtype=torch.float32, device=device))
|
||||
|
||||
n1_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n1_list]
|
||||
n2_list = [n / n.norm(dim=-1, p=2, keepdim=True) for n in n2_list]
|
||||
|
||||
# Clone the lists and initialize padded tensors.
|
||||
p1 = list_to_padded([p.clone() for p in p1_list])
|
||||
p2 = list_to_padded([p.clone() for p in p2_list])
|
||||
n1 = list_to_padded([p.clone() for p in n1_list])
|
||||
n2 = list_to_padded([p.clone() for p in n2_list])
|
||||
|
||||
# Set requires_grad for all tensors in the lists and
|
||||
# padded tensors.
|
||||
if requires_grad:
|
||||
for p in p2_list + p1_list + n1_list + n2_list + [p1, p2, n1, n2]:
|
||||
p.requires_grad = True
|
||||
|
||||
# Create pointclouds objects
|
||||
cloud1 = Pointclouds(points=p1_list, normals=n1_list)
|
||||
cloud2 = Pointclouds(points=p2_list, normals=n2_list)
|
||||
|
||||
# Return pointclouds objects and padded tensors
|
||||
return points_normals(
|
||||
p1_lengths=p1_lengths,
|
||||
p2_lengths=p2_lengths,
|
||||
cloud1=cloud1,
|
||||
cloud2=cloud2,
|
||||
p1=p1,
|
||||
p2=p2,
|
||||
n1=n1,
|
||||
n2=n2,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def chamfer_distance_naive_pointclouds(p1, p2):
|
||||
"""
|
||||
Naive iterative implementation of nearest neighbor and chamfer distance.
|
||||
x and y are assumed to be pointclouds objects with points and optionally normals.
|
||||
This functions supports heterogeneous pointclouds in a batch.
|
||||
Returns lists of the unreduced loss and loss_normals.
|
||||
"""
|
||||
N, P1, D = p1.shape
|
||||
P2 = p2.size(1)
|
||||
x = p1.points_padded()
|
||||
y = p2.points_padded()
|
||||
N, P1, D = x.shape
|
||||
P2 = y.size(1)
|
||||
x_lengths = p1.num_points_per_cloud()
|
||||
y_lengths = p2.num_points_per_cloud()
|
||||
x_normals = p1.normals_padded()
|
||||
y_normals = p2.normals_padded()
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
return_normals = p1_normals is not None and p2_normals is not None
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
|
||||
# Initialize all distances to + inf
|
||||
dist = torch.ones((N, P1, P2), dtype=torch.float32, device=device) * np.inf
|
||||
|
||||
x_mask = (
|
||||
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
|
||||
) # shape [N, P1]
|
||||
y_mask = (
|
||||
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
|
||||
) # shape [N, P2]
|
||||
|
||||
is_x_heterogeneous = ~(x_lengths == P1).all()
|
||||
is_y_heterogeneous = ~(y_lengths == P2).all()
|
||||
|
||||
# Only calculate the distances for the points which are not masked
|
||||
for n in range(N):
|
||||
for i1 in range(x_lengths[n]):
|
||||
for i2 in range(y_lengths[n]):
|
||||
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
|
||||
|
||||
x_dist = torch.min(dist, dim=2)[0] # (N, P1)
|
||||
y_dist = torch.min(dist, dim=1)[0] # (N, P2)
|
||||
|
||||
if is_x_heterogeneous:
|
||||
x_dist[x_mask] = 0.0
|
||||
if is_y_heterogeneous:
|
||||
y_dist[y_mask] = 0.0
|
||||
|
||||
loss = [x_dist, y_dist]
|
||||
|
||||
lnorm = [x.new_zeros(()), x.new_zeros(())]
|
||||
|
||||
if return_normals:
|
||||
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
|
||||
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
|
||||
lnorm1 = 1 - torch.abs(
|
||||
F.cosine_similarity(
|
||||
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
|
||||
)
|
||||
)
|
||||
lnorm2 = 1 - torch.abs(
|
||||
F.cosine_similarity(
|
||||
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
|
||||
)
|
||||
)
|
||||
|
||||
if is_x_heterogeneous:
|
||||
lnorm1[x_mask] = 0.0
|
||||
if is_y_heterogeneous:
|
||||
lnorm2[y_mask] = 0.0
|
||||
|
||||
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
|
||||
|
||||
return loss, lnorm
|
||||
|
||||
@staticmethod
|
||||
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
|
||||
"""
|
||||
Naive iterative implementation of nearest neighbor and chamfer distance.
|
||||
Returns lists of the unreduced loss and loss_normals. This naive
|
||||
version only supports homogeneous pointcouds in a batch.
|
||||
"""
|
||||
N, P1, D = x.shape
|
||||
P2 = y.size(1)
|
||||
device = torch.device("cuda:0")
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
|
||||
|
||||
for n in range(N):
|
||||
for i1 in range(P1):
|
||||
for i2 in range(P2):
|
||||
dist[n, i1, i2] = torch.sum((p1[n, i1, :] - p2[n, i2, :]) ** 2)
|
||||
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
|
||||
|
||||
loss = [
|
||||
torch.min(dist, dim=2)[0], # (N, P1)
|
||||
torch.min(dist, dim=1)[0], # (N, P2)
|
||||
]
|
||||
|
||||
lnorm = [p1.new_zeros(()), p1.new_zeros(())]
|
||||
lnorm = [x.new_zeros(()), x.new_zeros(())]
|
||||
|
||||
if return_normals:
|
||||
p1_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
|
||||
p2_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
|
||||
x_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
|
||||
y_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
|
||||
lnorm1 = 1 - torch.abs(
|
||||
F.cosine_similarity(
|
||||
p1_normals, p2_normals.gather(1, p1_index), dim=2, eps=1e-6
|
||||
x_normals, y_normals.gather(1, x_index), dim=2, eps=1e-6
|
||||
)
|
||||
)
|
||||
lnorm2 = 1 - torch.abs(
|
||||
F.cosine_similarity(
|
||||
p2_normals, p1_normals.gather(1, p2_index), dim=2, eps=1e-6
|
||||
y_normals, x_normals.gather(1, y_index), dim=2, eps=1e-6
|
||||
)
|
||||
)
|
||||
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
|
||||
|
||||
return loss, lnorm
|
||||
|
||||
def test_chamfer_default_no_normals(self):
|
||||
def test_chamfer_point_batch_reduction_mean(self):
|
||||
"""
|
||||
Compare chamfer loss with naive implementation using default
|
||||
input values and no normals.
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for the default settings (point_reduction = "mean" and batch_reduction = "mean")
|
||||
and no normals.
|
||||
This tests only uses homogeneous pointclouds.
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
p1, p2, _, _, weights = TestChamfer.init_pointclouds(N, P1, P2)
|
||||
pred_loss, _ = TestChamfer.chamfer_distance_naive(p1, p2)
|
||||
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
|
||||
N, max_P1, max_P2 = 7, 10, 18
|
||||
device = "cuda:0"
|
||||
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
weights = points_normals.weights
|
||||
p11 = p1.detach().clone()
|
||||
p22 = p2.detach().clone()
|
||||
p11.requires_grad = True
|
||||
p22.requires_grad = True
|
||||
P1 = p1.shape[1]
|
||||
P2 = p2.shape[1]
|
||||
|
||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
|
||||
|
||||
# point_reduction = "mean".
|
||||
loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
|
||||
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||
pred_loss *= weights
|
||||
pred_loss = pred_loss.sum() / weights.sum()
|
||||
|
||||
self.assertClose(loss, pred_loss)
|
||||
self.assertTrue(loss_norm is None)
|
||||
|
||||
def test_chamfer_point_reduction(self):
|
||||
# Check gradients
|
||||
self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
|
||||
|
||||
def test_chamfer_vs_naive_pointcloud(self):
|
||||
"""
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for point_reduction in ["mean", "sum", "none"] and
|
||||
batch_reduction = "none".
|
||||
Test the default settings for chamfer_distance
|
||||
(point reduction = "mean" and batch_reduction="mean") but with heterogeneous
|
||||
pointclouds as input. Compare with the naive implementation of chamfer
|
||||
which supports heterogeneous pointcloud objects.
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
||||
N, P1, P2
|
||||
N, max_P1, max_P2 = 3, 70, 70
|
||||
device = "cuda:0"
|
||||
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||
weights = points_normals.weights
|
||||
x_lengths = points_normals.p1_lengths
|
||||
y_lengths = points_normals.p2_lengths
|
||||
|
||||
# Chamfer with tensors as input for heterogeneous pointclouds.
|
||||
cham_tensor, norm_tensor = chamfer_distance(
|
||||
points_normals.p1,
|
||||
points_normals.p2,
|
||||
x_normals=points_normals.n1,
|
||||
y_normals=points_normals.n2,
|
||||
x_lengths=points_normals.p1_lengths,
|
||||
y_lengths=points_normals.p2_lengths,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
# Chamfer with pointclouds as input.
|
||||
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
|
||||
points_normals.cloud1, points_normals.cloud2
|
||||
)
|
||||
|
||||
# Mean reduction point loss.
|
||||
pred_loss[0] *= weights.view(N, 1)
|
||||
pred_loss[1] *= weights.view(N, 1)
|
||||
pred_loss_mean = (
|
||||
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
|
||||
)
|
||||
pred_loss_mean = pred_loss_mean.sum()
|
||||
pred_loss_mean /= weights.sum()
|
||||
|
||||
# Mean reduction norm loss.
|
||||
pred_norm_loss[0] *= weights.view(N, 1)
|
||||
pred_norm_loss[1] *= weights.view(N, 1)
|
||||
pred_norm_loss_mean = (
|
||||
pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
|
||||
)
|
||||
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
|
||||
|
||||
self.assertClose(pred_loss_mean, cham_tensor)
|
||||
self.assertClose(pred_norm_loss_mean, norm_tensor)
|
||||
|
||||
self._check_gradients(
|
||||
cham_tensor,
|
||||
norm_tensor,
|
||||
pred_loss_mean,
|
||||
pred_norm_loss_mean,
|
||||
points_normals.cloud1.points_list(),
|
||||
points_normals.p1,
|
||||
points_normals.cloud2.points_list(),
|
||||
points_normals.p2,
|
||||
points_normals.cloud1.normals_list(),
|
||||
points_normals.n1,
|
||||
points_normals.cloud2.normals_list(),
|
||||
points_normals.n2,
|
||||
x_lengths,
|
||||
y_lengths,
|
||||
)
|
||||
|
||||
def test_chamfer_pointcloud_object_withnormals(self):
|
||||
N = 5
|
||||
P1, P2 = 100, 100
|
||||
device = "cuda:0"
|
||||
|
||||
reductions = [
|
||||
("sum", "sum"),
|
||||
("mean", "sum"),
|
||||
("sum", "mean"),
|
||||
("mean", "mean"),
|
||||
("sum", None),
|
||||
("mean", None),
|
||||
]
|
||||
for (point_reduction, batch_reduction) in reductions:
|
||||
|
||||
# Reinitialize all the tensors so that the
|
||||
# backward pass can be computed.
|
||||
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||
|
||||
# Chamfer with pointclouds as input.
|
||||
cham_cloud, norm_cloud = chamfer_distance(
|
||||
points_normals.cloud1,
|
||||
points_normals.cloud2,
|
||||
point_reduction=point_reduction,
|
||||
batch_reduction=batch_reduction,
|
||||
)
|
||||
|
||||
# Chamfer with tensors as input.
|
||||
cham_tensor, norm_tensor = chamfer_distance(
|
||||
points_normals.p1,
|
||||
points_normals.p2,
|
||||
x_lengths=points_normals.p1_lengths,
|
||||
y_lengths=points_normals.p2_lengths,
|
||||
x_normals=points_normals.n1,
|
||||
y_normals=points_normals.n2,
|
||||
point_reduction=point_reduction,
|
||||
batch_reduction=batch_reduction,
|
||||
)
|
||||
|
||||
self.assertClose(cham_cloud, cham_tensor)
|
||||
self.assertClose(norm_cloud, norm_tensor)
|
||||
self._check_gradients(
|
||||
cham_tensor,
|
||||
norm_tensor,
|
||||
cham_cloud,
|
||||
norm_cloud,
|
||||
points_normals.cloud1.points_list(),
|
||||
points_normals.p1,
|
||||
points_normals.cloud2.points_list(),
|
||||
points_normals.p2,
|
||||
points_normals.cloud1.normals_list(),
|
||||
points_normals.n1,
|
||||
points_normals.cloud2.normals_list(),
|
||||
points_normals.n2,
|
||||
points_normals.p1_lengths,
|
||||
points_normals.p2_lengths,
|
||||
)
|
||||
|
||||
def test_chamfer_pointcloud_object_nonormals(self):
|
||||
N = 5
|
||||
P1, P2 = 100, 100
|
||||
device = "cuda:0"
|
||||
|
||||
reductions = [
|
||||
("sum", "sum"),
|
||||
("mean", "sum"),
|
||||
("sum", "mean"),
|
||||
("mean", "mean"),
|
||||
("sum", None),
|
||||
("mean", None),
|
||||
]
|
||||
for (point_reduction, batch_reduction) in reductions:
|
||||
|
||||
# Reinitialize all the tensors so that the
|
||||
# backward pass can be computed.
|
||||
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||
|
||||
# Chamfer with pointclouds as input.
|
||||
cham_cloud, _ = chamfer_distance(
|
||||
points_normals.cloud1,
|
||||
points_normals.cloud2,
|
||||
point_reduction=point_reduction,
|
||||
batch_reduction=batch_reduction,
|
||||
)
|
||||
|
||||
# Chamfer with tensors as input.
|
||||
cham_tensor, _ = chamfer_distance(
|
||||
points_normals.p1,
|
||||
points_normals.p2,
|
||||
x_lengths=points_normals.p1_lengths,
|
||||
y_lengths=points_normals.p2_lengths,
|
||||
point_reduction=point_reduction,
|
||||
batch_reduction=batch_reduction,
|
||||
)
|
||||
|
||||
self.assertClose(cham_cloud, cham_tensor)
|
||||
self._check_gradients(
|
||||
cham_tensor,
|
||||
None,
|
||||
cham_cloud,
|
||||
None,
|
||||
points_normals.cloud1.points_list(),
|
||||
points_normals.p1,
|
||||
points_normals.cloud2.points_list(),
|
||||
points_normals.p2,
|
||||
lengths1=points_normals.p1_lengths,
|
||||
lengths2=points_normals.p2_lengths,
|
||||
)
|
||||
|
||||
def test_chamfer_point_reduction_mean(self):
|
||||
"""
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for point_reduction = "mean" and batch_reduction = None.
|
||||
"""
|
||||
N, max_P1, max_P2 = 7, 10, 18
|
||||
device = "cuda:0"
|
||||
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
p1_normals = points_normals.n1
|
||||
p2_normals = points_normals.n2
|
||||
weights = points_normals.weights
|
||||
p11 = p1.detach().clone()
|
||||
p22 = p2.detach().clone()
|
||||
p11.requires_grad = True
|
||||
p22.requires_grad = True
|
||||
P1 = p1.shape[1]
|
||||
P2 = p2.shape[1]
|
||||
|
||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, p1_normals, p2_normals
|
||||
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||
)
|
||||
|
||||
# point_reduction = "mean".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
p11,
|
||||
p22,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="none",
|
||||
batch_reduction=None,
|
||||
point_reduction="mean",
|
||||
)
|
||||
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||
@ -118,14 +449,40 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
pred_loss_norm_mean *= weights
|
||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||
|
||||
# point_reduction = "sum".
|
||||
# Check gradients
|
||||
self._check_gradients(
|
||||
loss, loss_norm, pred_loss_mean, pred_loss_norm_mean, p1, p11, p2, p22
|
||||
)
|
||||
|
||||
def test_chamfer_point_reduction_sum(self):
|
||||
"""
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for point_reduction = "sum" and batch_reduction = None.
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
device = "cuda:0"
|
||||
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
p1_normals = points_normals.n1
|
||||
p2_normals = points_normals.n2
|
||||
weights = points_normals.weights
|
||||
p11 = p1.detach().clone()
|
||||
p22 = p2.detach().clone()
|
||||
p11.requires_grad = True
|
||||
p22.requires_grad = True
|
||||
|
||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||
)
|
||||
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
p11,
|
||||
p22,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="none",
|
||||
batch_reduction=None,
|
||||
point_reduction="sum",
|
||||
)
|
||||
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
|
||||
@ -136,92 +493,110 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
pred_loss_norm_sum *= weights
|
||||
self.assertClose(loss_norm, pred_loss_norm_sum)
|
||||
|
||||
# Error when point_reduction = "none" and batch_reduction = "none".
|
||||
with self.assertRaises(ValueError):
|
||||
chamfer_distance(
|
||||
p1, p2, weights=weights, batch_reduction="none", point_reduction="none"
|
||||
# Check gradients
|
||||
self._check_gradients(
|
||||
loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
|
||||
)
|
||||
|
||||
def _check_gradients(
|
||||
self,
|
||||
loss,
|
||||
loss_norm,
|
||||
pred_loss,
|
||||
pred_loss_norm,
|
||||
x1,
|
||||
x2,
|
||||
y1,
|
||||
y2,
|
||||
xn1=None, # normals
|
||||
xn2=None, # normals
|
||||
yn1=None, # normals
|
||||
yn2=None, # normals
|
||||
lengths1=None,
|
||||
lengths2=None,
|
||||
):
|
||||
"""
|
||||
x1 and x2 can have different types based on the leaf node used in the calculation:
|
||||
e.g. x1 may be a list of tensors whereas x2 is a padded tensor.
|
||||
This also applies for the pairs: (y1, y2), (xn1, xn2), (yn1, yn2).
|
||||
"""
|
||||
grad_loss = torch.rand(loss.shape, device=loss.device, dtype=loss.dtype)
|
||||
|
||||
# Loss for normals is optional. Iniitalize to 0.
|
||||
norm_loss_term = pred_norm_loss_term = 0.0
|
||||
if loss_norm is not None and pred_loss_norm is not None:
|
||||
grad_normals = torch.rand(
|
||||
loss_norm.shape, device=loss.device, dtype=loss.dtype
|
||||
)
|
||||
norm_loss_term = loss_norm * grad_normals
|
||||
pred_norm_loss_term = pred_loss_norm * grad_normals
|
||||
|
||||
# Error when batch_reduction is not in ["none", "mean", "sum"].
|
||||
with self.assertRaises(ValueError):
|
||||
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
|
||||
l1 = (loss * grad_loss) + norm_loss_term
|
||||
l1.sum().backward()
|
||||
l2 = (pred_loss * grad_loss) + pred_norm_loss_term
|
||||
l2.sum().backward()
|
||||
|
||||
def test_chamfer_batch_reduction(self):
|
||||
self._check_grad_by_type(x1, x2, lengths1)
|
||||
self._check_grad_by_type(y1, y2, lengths2)
|
||||
|
||||
# If leaf nodes for normals are passed in, check their gradients.
|
||||
if all(n is not None for n in [xn1, xn2, yn1, yn2]):
|
||||
self._check_grad_by_type(xn1, xn2, lengths1)
|
||||
self._check_grad_by_type(yn1, yn2, lengths2)
|
||||
|
||||
def _check_grad_by_type(self, x1, x2, lengths=None):
|
||||
"""
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for batch_reduction in ["mean", "sum"] and point_reduction = "none".
|
||||
x1 and x2 can be of different types e.g. list or tensor - compare appropriately
|
||||
based on the types.
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
||||
N, P1, P2
|
||||
)
|
||||
error_msg = "All values for gradient checks must be tensors or lists of tensors"
|
||||
|
||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, p1_normals, p2_normals
|
||||
)
|
||||
if all(isinstance(p, list) for p in [x1, x2]):
|
||||
# Lists of tensors
|
||||
for i in range(len(x1)):
|
||||
self.assertClose(x1[i].grad, x2[i].grad)
|
||||
elif isinstance(x1, list) and torch.is_tensor(x2):
|
||||
self.assertIsNotNone(lengths) # lengths is required
|
||||
|
||||
# batch_reduction = "sum".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="sum",
|
||||
point_reduction="none",
|
||||
)
|
||||
pred_loss[0] *= weights.view(N, 1)
|
||||
pred_loss[1] *= weights.view(N, 1)
|
||||
pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
|
||||
self.assertClose(loss, pred_loss)
|
||||
|
||||
pred_loss_norm[0] *= weights.view(N, 1)
|
||||
pred_loss_norm[1] *= weights.view(N, 1)
|
||||
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
|
||||
self.assertClose(loss_norm, pred_loss_norm)
|
||||
|
||||
# batch_reduction = "mean".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="mean",
|
||||
point_reduction="none",
|
||||
)
|
||||
|
||||
pred_loss /= weights.sum()
|
||||
self.assertClose(loss, pred_loss)
|
||||
|
||||
pred_loss_norm /= weights.sum()
|
||||
self.assertClose(loss_norm, pred_loss_norm)
|
||||
|
||||
# Error when point_reduction is not in ["none", "mean", "sum"].
|
||||
with self.assertRaises(ValueError):
|
||||
chamfer_distance(p1, p2, weights=weights, point_reduction="max")
|
||||
# List of tensors vs padded tensor
|
||||
for i in range(len(x1)):
|
||||
self.assertClose(x1[i].grad, x2.grad[i, : lengths[i]])
|
||||
self.assertTrue(x2.grad[i, lengths[i] :].sum().item() == 0.0)
|
||||
elif all(torch.is_tensor(p) for p in [x1, x2]):
|
||||
# Two tensors
|
||||
self.assertClose(x1.grad, x2.grad)
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def test_chamfer_joint_reduction(self):
|
||||
"""
|
||||
Compare output of vectorized chamfer loss with naive implementation
|
||||
for batch_reduction in ["mean", "sum"] and
|
||||
when batch_reduction in ["mean", "sum"] and
|
||||
point_reduction in ["mean", "sum"].
|
||||
"""
|
||||
N, P1, P2 = 7, 10, 18
|
||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
||||
N, P1, P2
|
||||
)
|
||||
N, max_P1, max_P2 = 7, 10, 18
|
||||
device = "cuda:0"
|
||||
|
||||
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
p1_normals = points_normals.n1
|
||||
p2_normals = points_normals.n2
|
||||
weights = points_normals.weights
|
||||
|
||||
P1 = p1.shape[1]
|
||||
P2 = p2.shape[1]
|
||||
|
||||
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, p1_normals, p2_normals
|
||||
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||
)
|
||||
|
||||
# batch_reduction = "sum", point_reduction = "sum".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="sum",
|
||||
point_reduction="sum",
|
||||
@ -244,8 +619,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="mean",
|
||||
point_reduction="sum",
|
||||
@ -260,8 +635,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="sum",
|
||||
point_reduction="mean",
|
||||
@ -280,8 +655,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1,
|
||||
p2,
|
||||
p1_normals,
|
||||
p2_normals,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
batch_reduction="mean",
|
||||
point_reduction="mean",
|
||||
@ -292,6 +667,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
pred_loss_norm_mean /= weights.sum()
|
||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||
|
||||
# Error when batch_reduction is not in ["mean", "sum"] or None.
|
||||
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
|
||||
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
|
||||
|
||||
# Error when point_reduction is not in ["mean", "sum"].
|
||||
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
|
||||
chamfer_distance(p1, p2, weights=weights, point_reduction=None)
|
||||
|
||||
def test_incorrect_weights(self):
|
||||
N, P1, P2 = 16, 64, 128
|
||||
device = torch.device("cuda:0")
|
||||
@ -312,7 +695,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(loss_norm.requires_grad)
|
||||
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1, p2, weights=weights, batch_reduction="none"
|
||||
p1, p2, weights=weights, batch_reduction=None
|
||||
)
|
||||
self.assertClose(loss.cpu(), torch.zeros((N, N)))
|
||||
self.assertTrue(loss.requires_grad)
|
||||
@ -327,16 +710,53 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
|
||||
|
||||
def test_incorrect_inputs(self):
|
||||
N, P1, P2 = 7, 10, 18
|
||||
device = "cuda:0"
|
||||
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
|
||||
p1 = points_normals.p1
|
||||
p2 = points_normals.p2
|
||||
p1_normals = points_normals.n1
|
||||
|
||||
# Normals of wrong shape
|
||||
with self.assertRaisesRegex(ValueError, "Expected normals to be of shape"):
|
||||
chamfer_distance(p1, p2, x_normals=p1_normals[None])
|
||||
|
||||
# Points of wrong shape
|
||||
with self.assertRaisesRegex(ValueError, "Expected points to be of shape"):
|
||||
chamfer_distance(p1[None], p2)
|
||||
|
||||
# Lengths of wrong shape
|
||||
with self.assertRaisesRegex(ValueError, "Expected lengths to be of shape"):
|
||||
chamfer_distance(p1, p2, x_lengths=torch.tensor([1, 2, 3], device=device))
|
||||
|
||||
# Points are not a tensor or Pointclouds
|
||||
with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
|
||||
chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
|
||||
|
||||
@staticmethod
|
||||
def chamfer_with_init(batch_size: int, P1: int, P2: int, return_normals: bool):
|
||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
||||
def chamfer_with_init(
|
||||
batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
|
||||
):
|
||||
p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
|
||||
batch_size, P1, P2
|
||||
)
|
||||
if homogeneous:
|
||||
# Set lengths to None so in Chamfer it assumes
|
||||
# there is no padding.
|
||||
l1 = l2 = None
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def loss():
|
||||
loss, loss_normals = chamfer_distance(
|
||||
p1, p2, p1_normals, p2_normals, weights=weights
|
||||
p1,
|
||||
p2,
|
||||
x_lengths=l1,
|
||||
y_lengths=l2,
|
||||
x_normals=p1_normals,
|
||||
y_normals=p2_normals,
|
||||
weights=weights,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@ -346,14 +766,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
def chamfer_naive_with_init(
|
||||
batch_size: int, P1: int, P2: int, return_normals: bool
|
||||
):
|
||||
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
|
||||
p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
|
||||
batch_size, P1, P2
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def loss():
|
||||
loss, loss_normals = TestChamfer.chamfer_distance_naive(
|
||||
p1, p2, p1_normals, p2_normals
|
||||
p1, p2, x_normals=p1_normals, y_normals=p2_normals
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user