mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
gather_scatter on CPU
Summary: CPU implementation of the graph convolution op. Reviewed By: nikhilaravi, gkioxari Differential Revision: D21384361 fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4872a2c4de
commit
7944d24d48
@@ -44,8 +44,8 @@ __global__ void GatherScatterCudaKernel(
|
||||
}
|
||||
|
||||
at::Tensor GatherScatterCuda(
|
||||
const at::Tensor input,
|
||||
const at::Tensor edges,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& edges,
|
||||
bool directed,
|
||||
bool backward) {
|
||||
// Check inputs are on the same device
|
||||
|
||||
@@ -20,17 +20,22 @@
|
||||
// Returns:
|
||||
// output: float32 Tensor of same shape as input.
|
||||
|
||||
// Cuda implementation.
|
||||
at::Tensor GatherScatterCuda(
|
||||
const at::Tensor input,
|
||||
const at::Tensor edges,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& edges,
|
||||
bool directed,
|
||||
bool backward);
|
||||
|
||||
at::Tensor GatherScatterCpu(
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& edges,
|
||||
bool directed,
|
||||
bool backward);
|
||||
|
||||
// Exposed implementation.
|
||||
at::Tensor GatherScatter(
|
||||
const at::Tensor input,
|
||||
const at::Tensor edges,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& edges,
|
||||
bool directed,
|
||||
bool backward) {
|
||||
if (input.is_cuda() && edges.is_cuda()) {
|
||||
@@ -42,5 +47,5 @@ at::Tensor GatherScatter(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU");
|
||||
return GatherScatterCpu(input, edges, directed, backward);
|
||||
}
|
||||
|
||||
35
pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp
Normal file
35
pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
at::Tensor GatherScatterCpu(
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& edges,
|
||||
bool directed,
|
||||
bool backward) {
|
||||
const auto num_vertices = input.size(0);
|
||||
const auto input_feature_dim = input.size(1);
|
||||
const auto num_edges = edges.size(0);
|
||||
|
||||
auto output = at::zeros({num_vertices, input_feature_dim}, input.options());
|
||||
|
||||
auto input_a = input.accessor<float, 2>();
|
||||
auto edges_a = edges.accessor<int64_t, 2>();
|
||||
auto output_a = output.accessor<float, 2>();
|
||||
const int v0_idx = backward ? 1 : 0;
|
||||
const int v1_idx = backward ? 0 : 1;
|
||||
|
||||
for (int e = 0; e < num_edges; ++e) {
|
||||
// Get indices of vertices which form the edge.
|
||||
const int64_t v0 = edges_a[e][v0_idx];
|
||||
const int64_t v1 = edges_a[e][v1_idx];
|
||||
|
||||
for (int d = 0; d < input_feature_dim; ++d) {
|
||||
output_a[v0][d] += input_a[v1][d];
|
||||
if (!directed) {
|
||||
output_a[v1][d] += input_a[v0][d];
|
||||
}
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
Reference in New Issue
Block a user