avoid using torch/extension.h in cuda

Summary:
Use aten instead of torch interface in all cuda code. This allows the cuda build to work with pytorch 1.5 with GCC 5 (e.g. the compiler of ubuntu 16.04LTS). This wasn't working. It has been failing with errors like the below, perhaps due to a bug in nvcc.

```
torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >
```

Reviewed By: nikhilaravi

Differential Revision: D21204029

fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
This commit is contained in:
Jeremy Reizenstein
2020-04-23 10:22:57 -07:00
committed by Facebook GitHub Bot
parent 54b482bd66
commit 85c396f822
9 changed files with 245 additions and 245 deletions

View File

@@ -1,6 +1,7 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -12,10 +13,10 @@
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
@@ -61,12 +62,12 @@ __global__ void alphaCompositeCudaForwardKernel(
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
@@ -131,16 +132,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
}
}
torch::Tensor alphaCompositeCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
at::Tensor alphaCompositeCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
auto result = torch::zeros({batch_size, C, H, W}, features.options());
auto result = at::zeros({batch_size, C, H, W}, features.options());
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -149,22 +150,22 @@ torch::Tensor alphaCompositeCudaForward(
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
return result;
}
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
auto grad_features = torch::zeros_like(features);
auto grad_alphas = torch::zeros_like(alphas);
std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
const at::Tensor& grad_outputs,
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
const int64_t bs = alphas.size(0);
@@ -175,12 +176,12 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);