mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
use assertClose
Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph. Reviewed By: nikhilaravi Differential Revision: D20556912 fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
744ef0c2c8
commit
595aca27ea
@@ -13,8 +13,10 @@ from pytorch3d.ops.graph_conv import (
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.utils import ico_sphere
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
class TestGraphConv(unittest.TestCase):
|
||||
|
||||
class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
def test_undirected(self):
|
||||
dtype = torch.float32
|
||||
device = torch.device("cuda:0")
|
||||
@@ -42,7 +44,7 @@ class TestGraphConv(unittest.TestCase):
|
||||
conv.w1.bias.data.zero_()
|
||||
|
||||
y = conv(verts, edges)
|
||||
self.assertTrue(torch.allclose(y, expected_y))
|
||||
self.assertClose(y, expected_y)
|
||||
|
||||
def test_no_edges(self):
|
||||
dtype = torch.float32
|
||||
@@ -57,19 +59,26 @@ class TestGraphConv(unittest.TestCase):
|
||||
conv.w0.bias.data.zero_()
|
||||
|
||||
y = conv(verts, edges)
|
||||
self.assertTrue(torch.allclose(y, expected_y))
|
||||
self.assertClose(y, expected_y)
|
||||
|
||||
def test_no_verts_and_edges(self):
|
||||
dtype = torch.float32
|
||||
verts = torch.tensor([], dtype=dtype, requires_grad=True)
|
||||
edges = torch.tensor([], dtype=dtype)
|
||||
w0 = torch.tensor([[1, -1, -2]], dtype=dtype)
|
||||
|
||||
conv = GraphConv(3, 1).to(dtype)
|
||||
conv.w0.weight.data.copy_(w0)
|
||||
conv.w0.bias.data.zero_()
|
||||
|
||||
y = conv(verts, edges)
|
||||
self.assertTrue(torch.allclose(y, torch.tensor([])))
|
||||
self.assertClose(y, torch.zeros((0, 1)))
|
||||
self.assertTrue(y.requires_grad)
|
||||
|
||||
conv2 = GraphConv(3, 2).to(dtype)
|
||||
conv2.w0.weight.data.copy_(w0.repeat(2, 1))
|
||||
conv2.w0.bias.data.zero_()
|
||||
y = conv2(verts, edges)
|
||||
self.assertClose(y, torch.zeros((0, 2)))
|
||||
self.assertTrue(y.requires_grad)
|
||||
|
||||
def test_directed(self):
|
||||
@@ -91,7 +100,7 @@ class TestGraphConv(unittest.TestCase):
|
||||
conv.w1.bias.data.zero_()
|
||||
|
||||
y = conv(verts, edges)
|
||||
self.assertTrue(torch.allclose(y, expected_y))
|
||||
self.assertClose(y, expected_y)
|
||||
|
||||
def test_backward(self):
|
||||
device = torch.device("cuda:0")
|
||||
@@ -108,7 +117,7 @@ class TestGraphConv(unittest.TestCase):
|
||||
neighbor_sums_cuda.sum().backward()
|
||||
neighbor_sums.sum().backward()
|
||||
|
||||
self.assertTrue(torch.allclose(verts.grad.cpu(), verts_cuda.grad.cpu()))
|
||||
self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
|
||||
|
||||
def test_repr(self):
|
||||
conv = GraphConv(32, 64, directed=True)
|
||||
@@ -147,7 +156,7 @@ class TestGraphConv(unittest.TestCase):
|
||||
output_cuda = _C.gather_scatter(
|
||||
input.to(device=device), edges.to(device=device), False, False
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_cuda.cpu(), output_cpu))
|
||||
self.assertClose(output_cuda.cpu(), output_cpu)
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.gather_scatter(input.cpu(), edges.cpu(), False, False)
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
@@ -157,7 +166,7 @@ class TestGraphConv(unittest.TestCase):
|
||||
output_cuda = _C.gather_scatter(
|
||||
input.to(device=device), edges.to(device=device), True, False
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_cuda.cpu(), output_cpu))
|
||||
self.assertClose(output_cuda.cpu(), output_cpu)
|
||||
|
||||
@staticmethod
|
||||
def graph_conv_forward_backward(
|
||||
|
||||
Reference in New Issue
Block a user