mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-03 18:55:59 +08:00
avoid using torch/extension.h in cuda
Summary: Use the ATen interface instead of the torch interface in all CUDA code. This allows the CUDA build to work with PyTorch 1.5 and GCC 5 (e.g. the compiler shipped with Ubuntu 16.04 LTS). Previously this combination did not work: the build failed with errors like the one below, perhaps due to a bug in nvcc. ``` torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ ``` Reviewed By: nikhilaravi Differential Revision: D21204029 fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
54b482bd66
commit
85c396f822
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <math.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
@@ -138,11 +138,10 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizePointsNaiveCuda(
|
||||
const torch::Tensor& points, // (P. 3)
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
||||
const at::Tensor& points, // (P, 3)
|
||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const at::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int points_per_pixel) {
|
||||
@@ -164,11 +163,11 @@ RasterizePointsNaiveCuda(
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
|
||||
auto int_opts = points.options().dtype(torch::kInt32);
|
||||
auto float_opts = points.options().dtype(torch::kFloat32);
|
||||
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
|
||||
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
|
||||
torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
|
||||
auto int_opts = points.options().dtype(at::kInt);
|
||||
auto float_opts = points.options().dtype(at::kFloat);
|
||||
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
|
||||
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
|
||||
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
@@ -316,10 +315,10 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor RasterizePointsCoarseCuda(
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
at::Tensor RasterizePointsCoarseCuda(
|
||||
const at::Tensor& points, // (P, 3)
|
||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const at::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
@@ -338,9 +337,9 @@ torch::Tensor RasterizePointsCoarseCuda(
|
||||
ss << "Got " << num_bins << "; that's too many!";
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
auto opts = points.options().dtype(torch::kInt32);
|
||||
torch::Tensor points_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
|
||||
torch::Tensor bin_points = torch::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
auto opts = points.options().dtype(at::kInt);
|
||||
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
|
||||
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
const int chunk_size = 512;
|
||||
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
|
||||
const size_t blocks = 64;
|
||||
@@ -442,9 +441,9 @@ __global__ void RasterizePointsFineCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
const torch::Tensor& points, // (P, 3)
|
||||
const torch::Tensor& bin_points,
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
||||
const at::Tensor& points, // (P, 3)
|
||||
const at::Tensor& bin_points,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const int bin_size,
|
||||
@@ -457,11 +456,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
if (K > kMaxPointsPerPixel) {
|
||||
AT_ERROR("Must have num_closest <= 8");
|
||||
}
|
||||
auto int_opts = points.options().dtype(torch::kInt32);
|
||||
auto float_opts = points.options().dtype(torch::kFloat32);
|
||||
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
|
||||
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
|
||||
torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
|
||||
auto int_opts = points.options().dtype(at::kInt);
|
||||
auto float_opts = points.options().dtype(at::kFloat);
|
||||
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
|
||||
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
|
||||
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
@@ -533,18 +532,18 @@ __global__ void RasterizePointsBackwardCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor RasterizePointsBackwardCuda(
|
||||
const torch::Tensor& points, // (N, P, 3)
|
||||
const torch::Tensor& idxs, // (N, H, W, K)
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_dists) { // (N, H, W, K)
|
||||
at::Tensor RasterizePointsBackwardCuda(
|
||||
const at::Tensor& points, // (P, 3)
|
||||
const at::Tensor& idxs, // (N, H, W, K)
|
||||
const at::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const at::Tensor& grad_dists) { // (N, H, W, K)
|
||||
const int P = points.size(0);
|
||||
const int N = idxs.size(0);
|
||||
const int H = idxs.size(1);
|
||||
const int W = idxs.size(2);
|
||||
const int K = idxs.size(3);
|
||||
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user