mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 14:55:59 +08:00
avoid using torch/extension.h in cuda
Summary: Use aten instead of torch interface in all cuda code. This allows the cuda build to work with pytorch 1.5 with GCC 5 (e.g. the compiler of ubuntu 16.04LTS). This wasn't working. It has been failing with errors like the below, perhaps due to a bug in nvcc. ``` torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> > ``` Reviewed By: nikhilaravi Differential Revision: D21204029 fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
54b482bd66
commit
85c396f822
@@ -1,9 +1,9 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <thrust/tuple.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "rasterize_points/bitmask.cuh"
|
||||
@@ -275,11 +275,11 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
RasterizeMeshesNaiveCuda(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_faces_packed_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_faces_packed_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
@@ -305,13 +305,13 @@ RasterizeMeshesNaiveCuda(
|
||||
const int W = image_size;
|
||||
const int K = num_closest;
|
||||
|
||||
auto long_opts = face_verts.options().dtype(torch::kInt64);
|
||||
auto float_opts = face_verts.options().dtype(torch::kFloat32);
|
||||
auto long_opts = face_verts.options().dtype(at::kLong);
|
||||
auto float_opts = face_verts.options().dtype(at::kFloat);
|
||||
|
||||
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
|
||||
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
|
||||
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
|
||||
torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
|
||||
at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
|
||||
at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
|
||||
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
|
||||
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
@@ -458,12 +458,12 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
const torch::Tensor& face_verts, // (F, 3, 3)
|
||||
const torch::Tensor& pix_to_face, // (N, H, W, K)
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const torch::Tensor& grad_dists, // (N, H, W, K)
|
||||
at::Tensor RasterizeMeshesBackwardCuda(
|
||||
const at::Tensor& face_verts, // (F, 3, 3)
|
||||
const at::Tensor& pix_to_face, // (N, H, W, K)
|
||||
const at::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const at::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const at::Tensor& grad_dists, // (N, H, W, K)
|
||||
const bool perspective_correct) {
|
||||
const int F = face_verts.size(0);
|
||||
const int N = pix_to_face.size(0);
|
||||
@@ -471,7 +471,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
const int W = pix_to_face.size(2);
|
||||
const int K = pix_to_face.size(3);
|
||||
|
||||
torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
|
||||
at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
@@ -618,10 +618,10 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor RasterizeMeshesCoarseCuda(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
at::Tensor RasterizeMeshesCoarseCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_face_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
@@ -642,9 +642,9 @@ torch::Tensor RasterizeMeshesCoarseCuda(
|
||||
ss << "Got " << num_bins << "; that's too many!";
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
auto opts = face_verts.options().dtype(torch::kInt32);
|
||||
torch::Tensor faces_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
|
||||
torch::Tensor bin_faces = torch::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
auto opts = face_verts.options().dtype(at::kInt);
|
||||
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
|
||||
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
const int chunk_size = 512;
|
||||
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
|
||||
const size_t blocks = 64;
|
||||
@@ -765,10 +765,10 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
RasterizeMeshesFineCuda(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& bin_faces,
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& bin_faces,
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
@@ -792,13 +792,13 @@ RasterizeMeshesFineCuda(
|
||||
if (K > kMaxPointsPerPixel) {
|
||||
AT_ERROR("Must have num_closest <= 8");
|
||||
}
|
||||
auto long_opts = face_verts.options().dtype(torch::kInt64);
|
||||
auto float_opts = face_verts.options().dtype(torch::kFloat32);
|
||||
auto long_opts = face_verts.options().dtype(at::kLong);
|
||||
auto float_opts = face_verts.options().dtype(at::kFloat);
|
||||
|
||||
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
|
||||
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
|
||||
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
|
||||
torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
|
||||
at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
|
||||
at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
|
||||
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
|
||||
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
Reference in New Issue
Block a user