# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest

import torch
import torch.nn.functional as F

from pytorch3d.loss import chamfer_distance

from common_testing import TestCaseMixin


class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod
def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
"""
Randomly initialize two batches of point clouds of sizes
(N, P1, D) and (N, P2, D) and return random normal vectors for
each batch of size (N, P1, 3) and (N, P2, 3).
"""
device = torch.device("cuda:0")
p1 = torch.rand((batch_size, P1, 3), dtype=torch.float32, device=device)
p1_normals = torch.rand(
(batch_size, P1, 3), dtype=torch.float32, device=device
)
p1_normals = p1_normals / p1_normals.norm(dim=2, p=2, keepdim=True)
p2 = torch.rand((batch_size, P2, 3), dtype=torch.float32, device=device)
p2_normals = torch.rand(
(batch_size, P2, 3), dtype=torch.float32, device=device
)
p2_normals = p2_normals / p2_normals.norm(dim=2, p=2, keepdim=True)
weights = torch.rand((batch_size,), dtype=torch.float32, device=device)
return p1, p2, p1_normals, p2_normals, weights

    @staticmethod
def chamfer_distance_naive(p1, p2, p1_normals=None, p2_normals=None):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
Returns lists of the unreduced loss and loss_normals.
"""
N, P1, D = p1.shape
P2 = p2.size(1)
        device = p1.device  # stay on the same device as the inputs
return_normals = p1_normals is not None and p2_normals is not None
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
for n in range(N):
for i1 in range(P1):
for i2 in range(P2):
dist[n, i1, i2] = torch.sum(
(p1[n, i1, :] - p2[n, i2, :]) ** 2
)
loss = [
torch.min(dist, dim=2)[0], # (N, P1)
torch.min(dist, dim=1)[0], # (N, P2)
]
lnorm = [p1.new_zeros(()), p1.new_zeros(())]
if return_normals:
p1_index = dist.argmin(2).view(N, P1, 1).expand(N, P1, 3)
p2_index = dist.argmin(1).view(N, P2, 1).expand(N, P2, 3)
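            # dist.argmin along one cloud gives, for every point, the index
            # of its nearest neighbor in the other cloud; the indices are
            # expanded over the (x, y, z) channels so they can be used with
            # gather on the normals. The loss below is 1 - |cos(angle)|;
            # the absolute value makes it insensitive to normal orientation.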
lnorm1 = 1 - torch.abs(
F.cosine_similarity(
p1_normals, p2_normals.gather(1, p1_index), dim=2, eps=1e-6
)
)
lnorm2 = 1 - torch.abs(
F.cosine_similarity(
p2_normals, p1_normals.gather(1, p2_index), dim=2, eps=1e-6
)
)
lnorm = [lnorm1, lnorm2] # [(N, P1), (N, P2)]
return loss, lnorm
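
    # The triple loop above is easy to audit but runs in Python time,
    # O(N * P1 * P2). For reference, a vectorized sketch of the same
    # pairwise squared distances (a hypothetical helper, not used by the
    # tests; torch.cdist returns Euclidean distances, hence the square):
    @staticmethod
    def _pairwise_square_distances(p1, p2):
        # (N, P1, P2): squared distance between every pair of points.
        return torch.cdist(p1, p2) ** 2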

    def test_chamfer_default_no_normals(self):
"""
Compare chamfer loss with naive implementation using default
input values and no normals.
"""
N, P1, P2 = 7, 10, 18
p1, p2, _, _, weights = TestChamfer.init_pointclouds(N, P1, P2)
pred_loss, _ = TestChamfer.chamfer_distance_naive(p1, p2)
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)
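        # Expected value with the defaults (point_reduction="mean",
        # batch_reduction="mean"): add the two directed per-point means,
        # scale by each cloud's weight, then normalize by the total weight:
        #   loss = sum_n w_n * (sum_i d1_ni / P1 + sum_j d2_nj / P2) / sum_n w_n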
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
self.assertClose(loss, pred_loss)
        self.assertIsNone(loss_norm)

    def test_chamfer_point_reduction(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for point_reduction in ["mean", "sum", "none"] and
batch_reduction = "none".
"""
N, P1, P2 = 7, 10, 18
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
N, P1, P2
)
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, p1_normals, p2_normals
)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="none",
point_reduction="mean",
)
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss_mean *= weights
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = (
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
)
pred_loss_norm_mean *= weights
self.assertClose(loss_norm, pred_loss_norm_mean)
# point_reduction = "sum".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="none",
point_reduction="sum",
)
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
pred_loss_sum *= weights
self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
pred_loss_norm_sum *= weights
self.assertClose(loss_norm, pred_loss_norm_sum)
# Error when point_reduction = "none" and batch_reduction = "none".
with self.assertRaises(ValueError):
chamfer_distance(
p1,
p2,
weights=weights,
batch_reduction="none",
point_reduction="none",
)
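        # Rationale (inferred from the shapes, not stated by the API): the
        # two directed losses have shapes (N, P1) and (N, P2), so with no
        # point reduction there is no single per-example tensor to return.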

        # Error when batch_reduction is not in ["none", "mean", "sum"].
with self.assertRaises(ValueError):
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

    def test_chamfer_batch_reduction(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for batch_reduction in ["mean", "sum"] and point_reduction = "none".
"""
N, P1, P2 = 7, 10, 18
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
N, P1, P2
)
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, p1_normals, p2_normals
)
# batch_reduction = "sum".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="sum",
point_reduction="none",
)
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
self.assertClose(loss, pred_loss)
pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1)
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
self.assertClose(loss_norm, pred_loss_norm)
# batch_reduction = "mean".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="mean",
point_reduction="none",
)
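        # The weighted batch mean is the weighted batch sum divided by the
        # total weight, so pred_loss from the "sum" case above is reused.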
pred_loss /= weights.sum()
self.assertClose(loss, pred_loss)
pred_loss_norm /= weights.sum()
self.assertClose(loss_norm, pred_loss_norm)

        # Error when point_reduction is not in ["none", "mean", "sum"].
with self.assertRaises(ValueError):
chamfer_distance(p1, p2, weights=weights, point_reduction="max")

    def test_chamfer_joint_reduction(self):
"""
Compare output of vectorized chamfer loss with naive implementation
for batch_reduction in ["mean", "sum"] and
point_reduction in ["mean", "sum"].
"""
N, P1, P2 = 7, 10, 18
p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
N, P1, P2
)
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, p1_normals, p2_normals
)
# batch_reduction = "sum", point_reduction = "sum".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="sum",
point_reduction="sum",
)
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1) # point sum
pred_loss_sum = pred_loss_sum.sum() # batch sum
self.assertClose(loss, pred_loss_sum)
pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1)
        pred_loss_norm_sum = (
            pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
        )  # point sum
        pred_loss_norm_sum = pred_loss_norm_sum.sum()  # batch sum
self.assertClose(loss_norm, pred_loss_norm_sum)
# batch_reduction = "mean", point_reduction = "sum".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="mean",
point_reduction="sum",
)
pred_loss_sum /= weights.sum()
self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum /= weights.sum()
self.assertClose(loss_norm, pred_loss_norm_sum)
# batch_reduction = "sum", point_reduction = "mean".
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="sum",
point_reduction="mean",
)
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss_mean = pred_loss_mean.sum()
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = (
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
)
pred_loss_norm_mean = pred_loss_norm_mean.sum()
self.assertClose(loss_norm, pred_loss_norm_mean)
# batch_reduction = "mean", point_reduction = "mean". This is the default.
loss, loss_norm = chamfer_distance(
p1,
p2,
p1_normals,
p2_normals,
weights=weights,
batch_reduction="mean",
point_reduction="mean",
)
pred_loss_mean /= weights.sum()
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean /= weights.sum()
self.assertClose(loss_norm, pred_loss_norm_mean)

    def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
device = torch.device("cuda:0")
p1 = torch.rand(
(N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
)
p2 = torch.rand(
(N, P2, 3), dtype=torch.float32, device=device, requires_grad=True
)
weights = torch.zeros((N,), dtype=torch.float32, device=device)
loss, loss_norm = chamfer_distance(
p1, p2, weights=weights, batch_reduction="mean"
)
self.assertClose(loss.cpu(), torch.zeros(()))
self.assertTrue(loss.requires_grad)
self.assertClose(loss_norm.cpu(), torch.zeros(()))
self.assertTrue(loss_norm.requires_grad)
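        # Even with all-zero weights, the losses must stay attached to the
        # autograd graph (hence the requires_grad checks) so callers can
        # run backward() unconditionally.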

        loss, loss_norm = chamfer_distance(
p1, p2, weights=weights, batch_reduction="none"
)
self.assertClose(loss.cpu(), torch.zeros((N, N)))
self.assertTrue(loss.requires_grad)
self.assertClose(loss_norm.cpu(), torch.zeros((N, N)))
self.assertTrue(loss_norm.requires_grad)

        weights = torch.ones((N,), dtype=torch.float32, device=device) * -1
with self.assertRaises(ValueError):
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

        weights = torch.zeros((N - 1,), dtype=torch.float32, device=device)
with self.assertRaises(ValueError):
loss, loss_norm = chamfer_distance(p1, p2, weights=weights)

    @staticmethod
    def chamfer_with_init(
        batch_size: int, P1: int, P2: int, return_normals: bool
    ):
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            batch_size, P1, P2
        )
        torch.cuda.synchronize()

        def loss():
            # Run one forward pass and synchronize so an external timer
            # measures the kernels, not just the asynchronous launch.
            loss, loss_normals = chamfer_distance(
                p1, p2, p1_normals, p2_normals, weights=weights
            )
            torch.cuda.synchronize()

        # Return the closure itself, with inputs already initialized, so a
        # benchmark harness can call it repeatedly.
        return loss

    @staticmethod
    def chamfer_naive_with_init(
        batch_size: int, P1: int, P2: int, return_normals: bool
    ):
        p1, p2, p1_normals, p2_normals, weights = TestChamfer.init_pointclouds(
            batch_size, P1, P2
        )
        torch.cuda.synchronize()

        def loss():
            loss, loss_normals = TestChamfer.chamfer_distance_naive(
                p1, p2, p1_normals, p2_normals
            )
            torch.cuda.synchronize()

        return loss
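

# The *_with_init helpers above return zero-argument closures with their
# inputs pre-built, so an external harness can time repeated calls without
# timing the setup. A minimal hypothetical harness (not part of the original
# suite; assumes a CUDA device, as the tests do):
if __name__ == "__main__":
    import timeit

    fn = TestChamfer.chamfer_with_init(
        batch_size=10, P1=32, P2=64, return_normals=True
    )
    print("chamfer_distance x10: %.4f s" % timeit.timeit(fn, number=10))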