Mirror of https://github.com/facebookresearch/pytorch3d.git, last synced 2025-08-02 11:52:50 +08:00.
gather_scatter on CPU
Summary: CPU implementation of the graph convolution op.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D21384361

fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b
This commit is contained in:
parent
4872a2c4de
commit
7944d24d48
@ -44,8 +44,8 @@ __global__ void GatherScatterCudaKernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor GatherScatterCuda(
|
at::Tensor GatherScatterCuda(
|
||||||
const at::Tensor input,
|
const at::Tensor& input,
|
||||||
const at::Tensor edges,
|
const at::Tensor& edges,
|
||||||
bool directed,
|
bool directed,
|
||||||
bool backward) {
|
bool backward) {
|
||||||
// Check inputs are on the same device
|
// Check inputs are on the same device
|
||||||
|
@ -20,17 +20,22 @@
|
|||||||
// Returns:
|
// Returns:
|
||||||
// output: float32 Tensor of same shape as input.
|
// output: float32 Tensor of same shape as input.
|
||||||
|
|
||||||
// Cuda implementation.
|
|
||||||
at::Tensor GatherScatterCuda(
|
at::Tensor GatherScatterCuda(
|
||||||
const at::Tensor input,
|
const at::Tensor& input,
|
||||||
const at::Tensor edges,
|
const at::Tensor& edges,
|
||||||
|
bool directed,
|
||||||
|
bool backward);
|
||||||
|
|
||||||
|
at::Tensor GatherScatterCpu(
|
||||||
|
const at::Tensor& input,
|
||||||
|
const at::Tensor& edges,
|
||||||
bool directed,
|
bool directed,
|
||||||
bool backward);
|
bool backward);
|
||||||
|
|
||||||
// Exposed implementation.
|
// Exposed implementation.
|
||||||
at::Tensor GatherScatter(
|
at::Tensor GatherScatter(
|
||||||
const at::Tensor input,
|
const at::Tensor& input,
|
||||||
const at::Tensor edges,
|
const at::Tensor& edges,
|
||||||
bool directed,
|
bool directed,
|
||||||
bool backward) {
|
bool backward) {
|
||||||
if (input.is_cuda() && edges.is_cuda()) {
|
if (input.is_cuda() && edges.is_cuda()) {
|
||||||
@ -42,5 +47,5 @@ at::Tensor GatherScatter(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
AT_ERROR("Not implemented on the CPU");
|
return GatherScatterCpu(input, edges, directed, backward);
|
||||||
}
|
}
|
||||||
|
35
pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp
Normal file
35
pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
// CPU implementation of the gather_scatter op used by graph convolution.
// For every edge (v0, v1), the features of v1 are accumulated into the
// output row of v0; when the graph is undirected, the accumulation is
// also done in the opposite direction.
//
// Args:
//   input: float32 Tensor of shape (V, D) holding per-vertex features
//          (float dtype is required by the accessor below).
//   edges: int64 Tensor of shape (E, 2); each row holds the indices of
//          the two vertices forming an edge.
//   directed: if false, accumulate along each edge in both directions.
//   backward: if true, swap the roles of the two edge endpoints — the
//             gradient flows in the opposite direction along each edge.
//
// Returns:
//   output: float32 Tensor of the same shape as input containing the
//           summed neighbor features.
at::Tensor GatherScatterCpu(
    const at::Tensor& input,
    const at::Tensor& edges,
    bool directed,
    bool backward) {
  const auto num_vertices = input.size(0);
  const auto input_feature_dim = input.size(1);
  const auto num_edges = edges.size(0);

  auto output = at::zeros({num_vertices, input_feature_dim}, input.options());

  auto input_a = input.accessor<float, 2>();
  auto edges_a = edges.accessor<int64_t, 2>();
  auto output_a = output.accessor<float, 2>();

  // In the backward pass the endpoint roles are swapped so the gradient
  // is scattered back along the reverse of each edge.
  const int v0_idx = backward ? 1 : 0;
  const int v1_idx = backward ? 0 : 1;

  // Loop counters are int64_t: tensor sizes are int64_t, so a plain int
  // counter could overflow for very large graphs and triggers
  // signed-width comparison warnings.
  for (int64_t e = 0; e < num_edges; ++e) {
    // Get indices of the vertices which form this edge.
    const int64_t v0 = edges_a[e][v0_idx];
    const int64_t v1 = edges_a[e][v1_idx];

    for (int64_t d = 0; d < input_feature_dim; ++d) {
      output_a[v0][d] += input_a[v1][d];
      if (!directed) {
        output_a[v1][d] += input_a[v0][d];
      }
    }
  }
  return output;
}
|
@ -101,17 +101,24 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
|||||||
mesh = ico_sphere()
|
mesh = ico_sphere()
|
||||||
verts = mesh.verts_packed()
|
verts = mesh.verts_packed()
|
||||||
edges = mesh.edges_packed()
|
edges = mesh.edges_packed()
|
||||||
|
verts_cpu = verts.clone()
|
||||||
|
edges_cpu = edges.clone()
|
||||||
verts_cuda = verts.clone().to(device)
|
verts_cuda = verts.clone().to(device)
|
||||||
edges_cuda = edges.clone().to(device)
|
edges_cuda = edges.clone().to(device)
|
||||||
verts.requires_grad = True
|
verts.requires_grad = True
|
||||||
|
verts_cpu.requires_grad = True
|
||||||
verts_cuda.requires_grad = True
|
verts_cuda.requires_grad = True
|
||||||
|
|
||||||
neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
|
neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
|
||||||
|
neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
|
||||||
neighbor_sums = gather_scatter_python(verts, edges, False)
|
neighbor_sums = gather_scatter_python(verts, edges, False)
|
||||||
neighbor_sums_cuda.sum().backward()
|
randoms = torch.rand_like(neighbor_sums)
|
||||||
neighbor_sums.sum().backward()
|
(neighbor_sums_cuda * randoms.cuda()).sum().backward()
|
||||||
|
(neighbor_sums_cpu * randoms).sum().backward()
|
||||||
|
(neighbor_sums * randoms).sum().backward()
|
||||||
|
|
||||||
self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
|
self.assertClose(verts.grad, verts_cuda.grad.cpu())
|
||||||
|
self.assertClose(verts.grad, verts_cpu.grad)
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
conv = GraphConv(32, 64, directed=True)
|
conv = GraphConv(32, 64, directed=True)
|
||||||
@ -141,22 +148,24 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
|||||||
w0 = nn.Linear(3, 1)
|
w0 = nn.Linear(3, 1)
|
||||||
input = w0(verts)
|
input = w0(verts)
|
||||||
|
|
||||||
# output
|
# undirected
|
||||||
output_cpu = gather_scatter_python(input, edges, False)
|
output_python = gather_scatter_python(input, edges, False)
|
||||||
output_cuda = _C.gather_scatter(
|
output_cuda = _C.gather_scatter(
|
||||||
input.to(device=device), edges.to(device=device), False, False
|
input.to(device=device), edges.to(device=device), False, False
|
||||||
)
|
)
|
||||||
self.assertClose(output_cuda.cpu(), output_cpu)
|
self.assertClose(output_cuda.cpu(), output_python)
|
||||||
with self.assertRaises(Exception) as err:
|
|
||||||
_C.gather_scatter(input.cpu(), edges.cpu(), False, False)
|
output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
|
||||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
self.assertClose(output_cpu, output_python)
|
||||||
|
|
||||||
# directed
|
# directed
|
||||||
output_cpu = gather_scatter_python(input, edges, True)
|
output_python = gather_scatter_python(input, edges, True)
|
||||||
output_cuda = _C.gather_scatter(
|
output_cuda = _C.gather_scatter(
|
||||||
input.to(device=device), edges.to(device=device), True, False
|
input.to(device=device), edges.to(device=device), True, False
|
||||||
)
|
)
|
||||||
self.assertClose(output_cuda.cpu(), output_cpu)
|
self.assertClose(output_cuda.cpu(), output_python)
|
||||||
|
output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), True, False)
|
||||||
|
self.assertClose(output_cpu, output_python)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def graph_conv_forward_backward(
|
def graph_conv_forward_backward(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user