use assertClose

Summary: Use assertClose in some tests; unlike a bare torch.allclose check, it also enforces shape equality. This fixes some small problems, including the output of graph_conv on an empty graph.

Reviewed By: nikhilaravi

Differential Revision: D20556912

fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
Author: Jeremy Reizenstein
Date: 2020-03-23 11:33:10 -07:00
Committed by: Facebook GitHub Bot
Parent: 744ef0c2c8
Commit: 595aca27ea

13 changed files with 216 additions and 241 deletions
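
Throughout the diff, self.assertTrue(torch.allclose(...)) becomes
self.assertClose(...), with assertClose coming from the TestCaseMixin
imported from common_testing in the first hunk. The mixin's actual
implementation is not part of this excerpt; a minimal sketch of the
semantics the summary describes (elementwise closeness plus an explicit
shape check) might look like the following, where ExampleTest is purely
illustrative:

    import unittest

    import torch


    class TestCaseMixin:
        # Hypothetical sketch of assertClose; the real helper lives in
        # common_testing.py and is not shown in this diff.
        def assertClose(self, input, other, rtol=1e-05, atol=1e-08):
            # Check shapes first: torch.allclose broadcasts, so e.g. a
            # (0,) tensor compares "close" to a (0, 1) tensor.
            self.assertEqual(input.shape, other.shape)
            self.assertTrue(torch.allclose(input, other, rtol=rtol, atol=atol))


    class ExampleTest(TestCaseMixin, unittest.TestCase):
        def test_shape_mismatch_fails(self):
            with self.assertRaises(AssertionError):
                self.assertClose(torch.tensor([]), torch.zeros((0, 1)))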


@@ -13,8 +13,10 @@ from pytorch3d.ops.graph_conv import (
 from pytorch3d.structures.meshes import Meshes
 from pytorch3d.utils import ico_sphere
 
+from common_testing import TestCaseMixin
+
 
-class TestGraphConv(unittest.TestCase):
+class TestGraphConv(TestCaseMixin, unittest.TestCase):
     def test_undirected(self):
         dtype = torch.float32
         device = torch.device("cuda:0")
@@ -42,7 +44,7 @@ class TestGraphConv(unittest.TestCase):
         conv.w1.bias.data.zero_()
 
         y = conv(verts, edges)
-        self.assertTrue(torch.allclose(y, expected_y))
+        self.assertClose(y, expected_y)
 
     def test_no_edges(self):
         dtype = torch.float32
@@ -57,19 +59,26 @@ class TestGraphConv(unittest.TestCase):
         conv.w0.bias.data.zero_()
 
         y = conv(verts, edges)
-        self.assertTrue(torch.allclose(y, expected_y))
+        self.assertClose(y, expected_y)
 
     def test_no_verts_and_edges(self):
         dtype = torch.float32
         verts = torch.tensor([], dtype=dtype, requires_grad=True)
         edges = torch.tensor([], dtype=dtype)
         w0 = torch.tensor([[1, -1, -2]], dtype=dtype)
 
         conv = GraphConv(3, 1).to(dtype)
         conv.w0.weight.data.copy_(w0)
         conv.w0.bias.data.zero_()
 
         y = conv(verts, edges)
-        self.assertTrue(torch.allclose(y, torch.tensor([])))
+        self.assertClose(y, torch.zeros((0, 1)))
+        self.assertTrue(y.requires_grad)
+        conv2 = GraphConv(3, 2).to(dtype)
+        conv2.w0.weight.data.copy_(w0.repeat(2, 1))
+        conv2.w0.bias.data.zero_()
+        y = conv2(verts, edges)
+        self.assertClose(y, torch.zeros((0, 2)))
+        self.assertTrue(y.requires_grad)
 
     def test_directed(self):
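
The test_no_verts_and_edges hunk above is the empty-graph fix mentioned
in the summary: the old assertion compared against torch.tensor([]),
which has shape (0,), while the convolution of an empty graph should
yield one row per vertex and one column per output channel, i.e.
(0, 1) or (0, 2) here. Because torch.allclose broadcasts, the old check
passed for either shape. A small illustration (not part of the diff):

    import torch

    # Empty tensors of different shapes broadcast to an empty result,
    # so allclose is vacuously True and cannot pin down output shape:
    print(torch.allclose(torch.tensor([]), torch.zeros((0, 1))))  # True

    # Comparing shapes first, as assertClose does, catches this:
    print(torch.tensor([]).shape == torch.zeros((0, 1)).shape)  # False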
@@ -91,7 +100,7 @@ class TestGraphConv(unittest.TestCase):
         conv.w1.bias.data.zero_()
 
         y = conv(verts, edges)
-        self.assertTrue(torch.allclose(y, expected_y))
+        self.assertClose(y, expected_y)
 
     def test_backward(self):
         device = torch.device("cuda:0")
@@ -108,7 +117,7 @@ class TestGraphConv(unittest.TestCase):
         neighbor_sums_cuda.sum().backward()
         neighbor_sums.sum().backward()
 
-        self.assertTrue(torch.allclose(verts.grad.cpu(), verts_cuda.grad.cpu()))
+        self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
 
     def test_repr(self):
         conv = GraphConv(32, 64, directed=True)
@@ -147,7 +156,7 @@ class TestGraphConv(unittest.TestCase):
         output_cuda = _C.gather_scatter(
             input.to(device=device), edges.to(device=device), False, False
         )
-        self.assertTrue(torch.allclose(output_cuda.cpu(), output_cpu))
+        self.assertClose(output_cuda.cpu(), output_cpu)
         with self.assertRaises(Exception) as err:
             _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
         self.assertTrue("Not implemented on the CPU" in str(err.exception))
@@ -157,7 +166,7 @@ class TestGraphConv(unittest.TestCase):
         output_cuda = _C.gather_scatter(
             input.to(device=device), edges.to(device=device), True, False
         )
-        self.assertTrue(torch.allclose(output_cuda.cpu(), output_cpu))
+        self.assertClose(output_cuda.cpu(), output_cpu)
 
     @staticmethod
     def graph_conv_forward_backward(