use assertClose

Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph.

Reviewed By: nikhilaravi

Differential Revision: D20556912

fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
Jeremy Reizenstein
2020-03-23 11:33:10 -07:00
committed by Facebook GitHub Bot
parent 744ef0c2c8
commit 595aca27ea
13 changed files with 216 additions and 241 deletions

View File

@@ -6,8 +6,10 @@ import torch.nn.functional as F
from pytorch3d.loss import chamfer_distance
from common_testing import TestCaseMixin
class TestChamfer(unittest.TestCase):
class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod
def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
"""
@@ -85,7 +87,7 @@ class TestChamfer(unittest.TestCase):
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss))
self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None)
def test_chamfer_point_reduction(self):
@@ -115,13 +117,13 @@ class TestChamfer(unittest.TestCase):
)
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss_mean *= weights
self.assertTrue(torch.allclose(loss, pred_loss_mean))
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = (
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
)
pred_loss_norm_mean *= weights
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean))
self.assertClose(loss_norm, pred_loss_norm_mean)
# point_reduction = "sum".
loss, loss_norm = chamfer_distance(
@@ -135,11 +137,11 @@ class TestChamfer(unittest.TestCase):
)
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
pred_loss_sum *= weights
self.assertTrue(torch.allclose(loss, pred_loss_sum))
self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
pred_loss_norm_sum *= weights
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum))
self.assertClose(loss_norm, pred_loss_norm_sum)
# Error when point_reduction = "none" and batch_reduction = "none".
with self.assertRaises(ValueError):
@@ -182,12 +184,12 @@ class TestChamfer(unittest.TestCase):
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
self.assertTrue(torch.allclose(loss, pred_loss))
self.assertClose(loss, pred_loss)
pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1)
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm))
self.assertClose(loss_norm, pred_loss_norm)
# batch_reduction = "mean".
loss, loss_norm = chamfer_distance(
@@ -201,10 +203,10 @@ class TestChamfer(unittest.TestCase):
)
pred_loss /= weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss))
self.assertClose(loss, pred_loss)
pred_loss_norm /= weights.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm))
self.assertClose(loss_norm, pred_loss_norm)
# Error when point_reduction is not in ["none", "mean", "sum"].
with self.assertRaises(ValueError):
@@ -239,7 +241,7 @@ class TestChamfer(unittest.TestCase):
pred_loss[1] *= weights.view(N, 1)
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1) # point sum
pred_loss_sum = pred_loss_sum.sum() # batch sum
self.assertTrue(torch.allclose(loss, pred_loss_sum))
self.assertClose(loss, pred_loss_sum)
pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1)
@@ -247,7 +249,7 @@ class TestChamfer(unittest.TestCase):
1
) # point sum.
pred_loss_norm_sum = pred_loss_norm_sum.sum() # batch sum
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum))
self.assertClose(loss_norm, pred_loss_norm_sum)
# batch_reduction = "mean", point_reduction = "sum".
loss, loss_norm = chamfer_distance(
@@ -260,10 +262,10 @@ class TestChamfer(unittest.TestCase):
point_reduction="sum",
)
pred_loss_sum /= weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss_sum))
self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum /= weights.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum))
self.assertClose(loss_norm, pred_loss_norm_sum)
# batch_reduction = "sum", point_reduction = "mean".
loss, loss_norm = chamfer_distance(
@@ -277,13 +279,13 @@ class TestChamfer(unittest.TestCase):
)
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss_mean = pred_loss_mean.sum()
self.assertTrue(torch.allclose(loss, pred_loss_mean))
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = (
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
)
pred_loss_norm_mean = pred_loss_norm_mean.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean))
self.assertClose(loss_norm, pred_loss_norm_mean)
# batch_reduction = "mean", point_reduction = "mean". This is the default.
loss, loss_norm = chamfer_distance(
@@ -296,10 +298,10 @@ class TestChamfer(unittest.TestCase):
point_reduction="mean",
)
pred_loss_mean /= weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss_mean))
self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean /= weights.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean))
self.assertClose(loss_norm, pred_loss_norm_mean)
def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
@@ -315,17 +317,17 @@ class TestChamfer(unittest.TestCase):
loss, loss_norm = chamfer_distance(
p1, p2, weights=weights, batch_reduction="mean"
)
self.assertTrue(torch.allclose(loss.cpu(), torch.zeros((1,))))
self.assertClose(loss.cpu(), torch.zeros(()))
self.assertTrue(loss.requires_grad)
self.assertTrue(torch.allclose(loss_norm.cpu(), torch.zeros((1,))))
self.assertClose(loss_norm.cpu(), torch.zeros(()))
self.assertTrue(loss_norm.requires_grad)
loss, loss_norm = chamfer_distance(
p1, p2, weights=weights, batch_reduction="none"
)
self.assertTrue(torch.allclose(loss.cpu(), torch.zeros((N,))))
self.assertClose(loss.cpu(), torch.zeros((N, N)))
self.assertTrue(loss.requires_grad)
self.assertTrue(torch.allclose(loss_norm.cpu(), torch.zeros((N,))))
self.assertClose(loss_norm.cpu(), torch.zeros((N, N)))
self.assertTrue(loss_norm.requires_grad)
weights = torch.ones((N,), dtype=torch.float32, device=device) * -1