avoid using torch/extension.h in cuda

Summary:
Use the ATen interface instead of the torch interface in all CUDA code. This allows the CUDA build to work with PyTorch 1.5 under GCC 5 (the default compiler of Ubuntu 16.04 LTS). Previously the build was failing with errors like the one below, perhaps due to a bug in nvcc.

```
torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’
```

Reviewed By: nikhilaravi

Differential Revision: D21204029

fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
This commit is contained in:
Jeremy Reizenstein
2020-04-23 10:22:57 -07:00
committed by Facebook GitHub Bot
parent 54b482bd66
commit 85c396f822
9 changed files with 245 additions and 245 deletions

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <algorithm>
#include <list>
#include <queue>
@@ -97,11 +97,11 @@ __global__ void PointEdgeForwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& points_first_idx,
const at::Tensor& segms,
const at::Tensor& segms_first_idx,
const int64_t max_points) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
@@ -114,8 +114,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
AT_ASSERTM(segms_first_idx.size(0) == B);
// clang-format off
torch::Tensor dists = torch::zeros({P,}, points.options());
torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
at::Tensor dists = at::zeros({P,}, points.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
// clang-format on
const int threads = 128;
@@ -178,11 +178,11 @@ __global__ void PointEdgeBackwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
@@ -194,8 +194,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
AT_ASSERTM(grad_dists.size(0) == P);
// clang-format off
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
// clang-format on
const int blocks = 64;
@@ -302,11 +302,11 @@ __global__ void EdgePointForwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& points_first_idx,
const at::Tensor& segms,
const at::Tensor& segms_first_idx,
const int64_t max_segms) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
@@ -319,8 +319,8 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
AT_ASSERTM(segms_first_idx.size(0) == B);
// clang-format off
torch::Tensor dists = torch::zeros({S,}, segms.options());
torch::Tensor idxs = torch::zeros({S,}, segms_first_idx.options());
at::Tensor dists = at::zeros({S,}, segms.options());
at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
// clang-format on
const int threads = 128;
@@ -384,11 +384,11 @@ __global__ void EdgePointBackwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists) {
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& idx_segms,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
@@ -400,8 +400,8 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
AT_ASSERTM(grad_dists.size(0) == S);
// clang-format off
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
// clang-format on
const int blocks = 64;
@@ -448,9 +448,9 @@ __global__ void PointEdgeArrayForwardKernel(
}
}
torch::Tensor PointEdgeArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms) {
at::Tensor PointEdgeArrayDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& segms) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
@@ -459,7 +459,7 @@ torch::Tensor PointEdgeArrayDistanceForwardCuda(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
torch::Tensor dists = torch::zeros({P, S}, points.options());
at::Tensor dists = at::zeros({P, S}, points.options());
const size_t blocks = 1024;
const size_t threads = 64;
@@ -516,10 +516,10 @@ __global__ void PointEdgeArrayBackwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& grad_dists) {
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
@@ -529,8 +529,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
"segms must be of shape Sx2x3");
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
const size_t blocks = 1024;
const size_t threads = 64;

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <algorithm>
#include <list>
#include <queue>
@@ -98,11 +98,11 @@ __global__ void PointFaceForwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& points_first_idx,
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_points) {
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
@@ -115,8 +115,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
AT_ASSERTM(tris_first_idx.size(0) == B);
// clang-format off
torch::Tensor dists = torch::zeros({P,}, points.options());
torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
at::Tensor dists = at::zeros({P,}, points.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
// clang-format on
const int threads = 128;
@@ -186,11 +186,11 @@ __global__ void PointFaceBackwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
@@ -202,8 +202,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
AT_ASSERTM(grad_dists.size(0) == P);
// clang-format off
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
// clang-format on
const int blocks = 64;
@@ -311,11 +311,11 @@ __global__ void FacePointForwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& points_first_idx,
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_tris) {
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
@@ -328,8 +328,8 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
AT_ASSERTM(tris_first_idx.size(0) == B);
// clang-format off
torch::Tensor dists = torch::zeros({T,}, tris.options());
torch::Tensor idxs = torch::zeros({T,}, tris_first_idx.options());
at::Tensor dists = at::zeros({T,}, tris.options());
at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
// clang-format on
const int threads = 128;
@@ -400,11 +400,11 @@ __global__ void FacePointBackwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists) {
std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& idx_tris,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
@@ -416,8 +416,8 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
AT_ASSERTM(grad_dists.size(0) == T);
// clang-format off
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
// clang-format on
const int blocks = 64;
@@ -465,9 +465,9 @@ __global__ void PointFaceArrayForwardKernel(
}
}
torch::Tensor PointFaceArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris) {
at::Tensor PointFaceArrayDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& tris) {
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
@@ -476,7 +476,7 @@ torch::Tensor PointFaceArrayDistanceForwardCuda(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
torch::Tensor dists = torch::zeros({P, T}, points.options());
at::Tensor dists = at::zeros({P, T}, points.options());
const size_t blocks = 1024;
const size_t threads = 64;
@@ -542,10 +542,10 @@ __global__ void PointFaceArrayBackwardKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists) {
std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
@@ -555,8 +555,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
"tris must be of shape Tx3x3");
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
const size_t blocks = 1024;
const size_t threads = 64;