mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
gather_scatter on CPU
Summary: CPU implementation of the graph convolution op. Reviewed By: nikhilaravi, gkioxari Differential Revision: D21384361 fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4872a2c4de
commit
7944d24d48
@@ -101,17 +101,24 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
mesh = ico_sphere()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
verts_cpu = verts.clone()
|
||||
edges_cpu = edges.clone()
|
||||
verts_cuda = verts.clone().to(device)
|
||||
edges_cuda = edges.clone().to(device)
|
||||
verts.requires_grad = True
|
||||
verts_cpu.requires_grad = True
|
||||
verts_cuda.requires_grad = True
|
||||
|
||||
neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
|
||||
neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
|
||||
neighbor_sums = gather_scatter_python(verts, edges, False)
|
||||
neighbor_sums_cuda.sum().backward()
|
||||
neighbor_sums.sum().backward()
|
||||
randoms = torch.rand_like(neighbor_sums)
|
||||
(neighbor_sums_cuda * randoms.cuda()).sum().backward()
|
||||
(neighbor_sums_cpu * randoms).sum().backward()
|
||||
(neighbor_sums * randoms).sum().backward()
|
||||
|
||||
self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
|
||||
self.assertClose(verts.grad, verts_cuda.grad.cpu())
|
||||
self.assertClose(verts.grad, verts_cpu.grad)
|
||||
|
||||
def test_repr(self):
|
||||
conv = GraphConv(32, 64, directed=True)
|
||||
@@ -141,22 +148,24 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
w0 = nn.Linear(3, 1)
|
||||
input = w0(verts)
|
||||
|
||||
# output
|
||||
output_cpu = gather_scatter_python(input, edges, False)
|
||||
# undirected
|
||||
output_python = gather_scatter_python(input, edges, False)
|
||||
output_cuda = _C.gather_scatter(
|
||||
input.to(device=device), edges.to(device=device), False, False
|
||||
)
|
||||
self.assertClose(output_cuda.cpu(), output_cpu)
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.gather_scatter(input.cpu(), edges.cpu(), False, False)
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
self.assertClose(output_cuda.cpu(), output_python)
|
||||
|
||||
output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
|
||||
self.assertClose(output_cpu, output_python)
|
||||
|
||||
# directed
|
||||
output_cpu = gather_scatter_python(input, edges, True)
|
||||
output_python = gather_scatter_python(input, edges, True)
|
||||
output_cuda = _C.gather_scatter(
|
||||
input.to(device=device), edges.to(device=device), True, False
|
||||
)
|
||||
self.assertClose(output_cuda.cpu(), output_cpu)
|
||||
self.assertClose(output_cuda.cpu(), output_python)
|
||||
output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), True, False)
|
||||
self.assertClose(output_cpu, output_python)
|
||||
|
||||
@staticmethod
|
||||
def graph_conv_forward_backward(
|
||||
|
||||
Reference in New Issue
Block a user