mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 03:40:34 +08:00
use assertClose
Summary: Use assertClose in some tests, which enforces shape equality. This fixes some small problems, including graph_conv on an empty graph.
Reviewed By: nikhilaravi
Differential Revision: D20556912
fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
744ef0c2c8
commit
595aca27ea
@@ -6,8 +6,10 @@ import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.loss import chamfer_distance
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
class TestChamfer(unittest.TestCase):
|
||||
|
||||
class TestChamfer(TestCaseMixin, unittest.TestCase):
|
||||
@staticmethod
|
||||
def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
|
||||
"""
|
||||
@@ -85,7 +87,7 @@ class TestChamfer(unittest.TestCase):
|
||||
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||
pred_loss *= weights
|
||||
pred_loss = pred_loss.sum() / weights.sum()
|
||||
self.assertTrue(torch.allclose(loss, pred_loss))
|
||||
self.assertClose(loss, pred_loss)
|
||||
self.assertTrue(loss_norm is None)
|
||||
|
||||
def test_chamfer_point_reduction(self):
|
||||
@@ -115,13 +117,13 @@ class TestChamfer(unittest.TestCase):
|
||||
)
|
||||
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||
pred_loss_mean *= weights
|
||||
self.assertTrue(torch.allclose(loss, pred_loss_mean))
|
||||
self.assertClose(loss, pred_loss_mean)
|
||||
|
||||
pred_loss_norm_mean = (
|
||||
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
|
||||
)
|
||||
pred_loss_norm_mean *= weights
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean))
|
||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||
|
||||
# point_reduction = "sum".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
@@ -135,11 +137,11 @@ class TestChamfer(unittest.TestCase):
|
||||
)
|
||||
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
|
||||
pred_loss_sum *= weights
|
||||
self.assertTrue(torch.allclose(loss, pred_loss_sum))
|
||||
self.assertClose(loss, pred_loss_sum)
|
||||
|
||||
pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
|
||||
pred_loss_norm_sum *= weights
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum))
|
||||
self.assertClose(loss_norm, pred_loss_norm_sum)
|
||||
|
||||
# Error when point_reduction = "none" and batch_reduction = "none".
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -182,12 +184,12 @@ class TestChamfer(unittest.TestCase):
|
||||
pred_loss[0] *= weights.view(N, 1)
|
||||
pred_loss[1] *= weights.view(N, 1)
|
||||
pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
|
||||
self.assertTrue(torch.allclose(loss, pred_loss))
|
||||
self.assertClose(loss, pred_loss)
|
||||
|
||||
pred_loss_norm[0] *= weights.view(N, 1)
|
||||
pred_loss_norm[1] *= weights.view(N, 1)
|
||||
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm))
|
||||
self.assertClose(loss_norm, pred_loss_norm)
|
||||
|
||||
# batch_reduction = "mean".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
@@ -201,10 +203,10 @@ class TestChamfer(unittest.TestCase):
|
||||
)
|
||||
|
||||
pred_loss /= weights.sum()
|
||||
self.assertTrue(torch.allclose(loss, pred_loss))
|
||||
self.assertClose(loss, pred_loss)
|
||||
|
||||
pred_loss_norm /= weights.sum()
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm))
|
||||
self.assertClose(loss_norm, pred_loss_norm)
|
||||
|
||||
# Error when point_reduction is not in ["none", "mean", "sum"].
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -239,7 +241,7 @@ class TestChamfer(unittest.TestCase):
|
||||
pred_loss[1] *= weights.view(N, 1)
|
||||
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1) # point sum
|
||||
pred_loss_sum = pred_loss_sum.sum() # batch sum
|
||||
self.assertTrue(torch.allclose(loss, pred_loss_sum))
|
||||
self.assertClose(loss, pred_loss_sum)
|
||||
|
||||
pred_loss_norm[0] *= weights.view(N, 1)
|
||||
pred_loss_norm[1] *= weights.view(N, 1)
|
||||
@@ -247,7 +249,7 @@ class TestChamfer(unittest.TestCase):
|
||||
1
|
||||
) # point sum.
|
||||
pred_loss_norm_sum = pred_loss_norm_sum.sum() # batch sum
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum))
|
||||
self.assertClose(loss_norm, pred_loss_norm_sum)
|
||||
|
||||
# batch_reduction = "mean", point_reduction = "sum".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
@@ -260,10 +262,10 @@ class TestChamfer(unittest.TestCase):
|
||||
point_reduction="sum",
|
||||
)
|
||||
pred_loss_sum /= weights.sum()
|
||||
self.assertTrue(torch.allclose(loss, pred_loss_sum))
|
||||
self.assertClose(loss, pred_loss_sum)
|
||||
|
||||
pred_loss_norm_sum /= weights.sum()
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum))
|
||||
self.assertClose(loss_norm, pred_loss_norm_sum)
|
||||
|
||||
# batch_reduction = "sum", point_reduction = "mean".
|
||||
loss, loss_norm = chamfer_distance(
|
||||
@@ -277,13 +279,13 @@ class TestChamfer(unittest.TestCase):
|
||||
)
|
||||
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
|
||||
pred_loss_mean = pred_loss_mean.sum()
|
||||
self.assertTrue(torch.allclose(loss, pred_loss_mean))
|
||||
self.assertClose(loss, pred_loss_mean)
|
||||
|
||||
pred_loss_norm_mean = (
|
||||
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
|
||||
)
|
||||
pred_loss_norm_mean = pred_loss_norm_mean.sum()
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean))
|
||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||
|
||||
# batch_reduction = "mean", point_reduction = "mean". This is the default.
|
||||
loss, loss_norm = chamfer_distance(
|
||||
@@ -296,10 +298,10 @@ class TestChamfer(unittest.TestCase):
|
||||
point_reduction="mean",
|
||||
)
|
||||
pred_loss_mean /= weights.sum()
|
||||
self.assertTrue(torch.allclose(loss, pred_loss_mean))
|
||||
self.assertClose(loss, pred_loss_mean)
|
||||
|
||||
pred_loss_norm_mean /= weights.sum()
|
||||
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean))
|
||||
self.assertClose(loss_norm, pred_loss_norm_mean)
|
||||
|
||||
def test_incorrect_weights(self):
|
||||
N, P1, P2 = 16, 64, 128
|
||||
@@ -315,17 +317,17 @@ class TestChamfer(unittest.TestCase):
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1, p2, weights=weights, batch_reduction="mean"
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss.cpu(), torch.zeros((1,))))
|
||||
self.assertClose(loss.cpu(), torch.zeros(()))
|
||||
self.assertTrue(loss.requires_grad)
|
||||
self.assertTrue(torch.allclose(loss_norm.cpu(), torch.zeros((1,))))
|
||||
self.assertClose(loss_norm.cpu(), torch.zeros(()))
|
||||
self.assertTrue(loss_norm.requires_grad)
|
||||
|
||||
loss, loss_norm = chamfer_distance(
|
||||
p1, p2, weights=weights, batch_reduction="none"
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss.cpu(), torch.zeros((N,))))
|
||||
self.assertClose(loss.cpu(), torch.zeros((N, N)))
|
||||
self.assertTrue(loss.requires_grad)
|
||||
self.assertTrue(torch.allclose(loss_norm.cpu(), torch.zeros((N,))))
|
||||
self.assertClose(loss_norm.cpu(), torch.zeros((N, N)))
|
||||
self.assertTrue(loss_norm.requires_grad)
|
||||
|
||||
weights = torch.ones((N,), dtype=torch.float32, device=device) * -1
|
||||
|
||||
Reference in New Issue
Block a user