diff --git a/tests/test_graph_conv.py b/tests/test_graph_conv.py index 75ff4f90..a3ac605d 100644 --- a/tests/test_graph_conv.py +++ b/tests/test_graph_conv.py @@ -113,7 +113,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase): neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False) neighbor_sums = gather_scatter_python(verts, edges, False) randoms = torch.rand_like(neighbor_sums) - (neighbor_sums_cuda * randoms.cuda()).sum().backward() + (neighbor_sums_cuda * randoms.to(device)).sum().backward() (neighbor_sums_cpu * randoms).sum().backward() (neighbor_sums * randoms).sum().backward()