avoid using torch/extension.h in cuda

Summary:
Use aten instead of torch interface in all cuda code. This allows the cuda build to work with pytorch 1.5 with GCC 5 (e.g. the compiler of ubuntu 16.04LTS). This wasn't working. It has been failing with errors like the below, perhaps due to a bug in nvcc.

```
torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >
```

Reviewed By: nikhilaravi

Differential Revision: D21204029

fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
This commit is contained in:
Jeremy Reizenstein
2020-04-23 10:22:57 -07:00
committed by Facebook GitHub Bot
parent 54b482bd66
commit 85c396f822
9 changed files with 245 additions and 245 deletions

View File

@@ -1,9 +1,9 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <float.h>
#include <math.h>
#include <thrust/tuple.h>
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "rasterize_points/bitmask.cuh"
@@ -275,11 +275,11 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesNaiveCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_faces_packed_first_idx,
const torch::Tensor& num_faces_per_mesh,
const at::Tensor& face_verts,
const at::Tensor& mesh_to_faces_packed_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const float blur_radius,
const int num_closest,
@@ -305,13 +305,13 @@ RasterizeMeshesNaiveCuda(
const int W = image_size;
const int K = num_closest;
auto long_opts = face_verts.options().dtype(torch::kInt64);
auto float_opts = face_verts.options().dtype(torch::kFloat32);
auto long_opts = face_verts.options().dtype(at::kLong);
auto float_opts = face_verts.options().dtype(at::kFloat);
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
const size_t blocks = 1024;
const size_t threads = 64;
@@ -458,12 +458,12 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
}
}
torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& face_verts, // (F, 3, 3)
const torch::Tensor& pix_to_face, // (N, H, W, K)
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
at::Tensor RasterizeMeshesBackwardCuda(
const at::Tensor& face_verts, // (F, 3, 3)
const at::Tensor& pix_to_face, // (N, H, W, K)
const at::Tensor& grad_zbuf, // (N, H, W, K)
const at::Tensor& grad_bary, // (N, H, W, K, 3)
const at::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
@@ -471,7 +471,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const int W = pix_to_face.size(2);
const int K = pix_to_face.size(3);
torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());
const size_t blocks = 1024;
const size_t threads = 64;
@@ -618,10 +618,10 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
}
}
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
at::Tensor RasterizeMeshesCoarseCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const float blur_radius,
const int bin_size,
@@ -642,9 +642,9 @@ torch::Tensor RasterizeMeshesCoarseCuda(
ss << "Got " << num_bins << "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = face_verts.options().dtype(torch::kInt32);
torch::Tensor faces_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
torch::Tensor bin_faces = torch::full({N, num_bins, num_bins, M}, -1, opts);
auto opts = face_verts.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64;
@@ -765,10 +765,10 @@ __global__ void RasterizeMeshesFineCudaKernel(
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const at::Tensor& face_verts,
const at::Tensor& bin_faces,
const int image_size,
const float blur_radius,
const int bin_size,
@@ -792,13 +792,13 @@ RasterizeMeshesFineCuda(
if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8");
}
auto long_opts = face_verts.options().dtype(torch::kInt64);
auto float_opts = face_verts.options().dtype(torch::kFloat32);
auto long_opts = face_verts.options().dtype(at::kLong);
auto float_opts = face_verts.options().dtype(at::kFloat);
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
const size_t blocks = 1024;
const size_t threads = 64;