mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-12 07:15:58 +08:00
CUDA updates
Summary:
Updates to:
- enable CUDA kernel launches on any GPU (not just the default)
- CUDA and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the CUDA kernels
- CUDA tests now run on a random device, not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c9267ab7af
commit
c3d636dc8c
@@ -1,6 +1,8 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <thrust/tuple.h>
|
||||
@@ -285,14 +287,14 @@ RasterizeMeshesNaiveCuda(
|
||||
const int num_closest,
|
||||
const bool perspective_correct,
|
||||
const bool cull_backfaces) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
}
|
||||
if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) {
|
||||
AT_ERROR(
|
||||
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
face_verts.size(2) == 3,
|
||||
"face_verts must have dimensions (num_faces, 3, 3)");
|
||||
|
||||
TORCH_CHECK(
|
||||
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
|
||||
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
|
||||
|
||||
if (num_closest > kMaxPointsPerPixel) {
|
||||
std::stringstream ss;
|
||||
@@ -300,6 +302,20 @@ RasterizeMeshesNaiveCuda(
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
mesh_to_faces_packed_first_idx_t{
|
||||
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
|
||||
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
|
||||
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
|
||||
at::checkAllSameGPU(
|
||||
c,
|
||||
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int N = num_faces_per_mesh.size(0); // batch size.
|
||||
const int H = image_size; // Assume square images.
|
||||
const int W = image_size;
|
||||
@@ -313,10 +329,15 @@ RasterizeMeshesNaiveCuda(
|
||||
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
|
||||
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
|
||||
|
||||
if (face_idxs.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
RasterizeMeshesNaiveCudaKernel<<<blocks, threads>>>(
|
||||
RasterizeMeshesNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
|
||||
@@ -332,6 +353,7 @@ RasterizeMeshesNaiveCuda(
|
||||
pix_dists.contiguous().data_ptr<float>(),
|
||||
bary.contiguous().data_ptr<float>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
|
||||
}
|
||||
|
||||
@@ -465,6 +487,22 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
const at::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const at::Tensor& grad_dists, // (N, H, W, K)
|
||||
const bool perspective_correct) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
pix_to_face_t{pix_to_face, "pix_to_face", 2},
|
||||
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
|
||||
grad_bary_t{grad_bary, "grad_bary", 4},
|
||||
grad_dists_t{grad_dists, "grad_dists", 5};
|
||||
at::CheckedFrom c = "RasterizeMeshesBackwardCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {face_verts_t, pix_to_face_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
|
||||
at::checkAllSameType(
|
||||
c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int F = face_verts.size(0);
|
||||
const int N = pix_to_face.size(0);
|
||||
const int H = pix_to_face.size(1);
|
||||
@@ -472,10 +510,16 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
const int K = pix_to_face.size(3);
|
||||
|
||||
at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());
|
||||
|
||||
if (grad_face_verts.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return grad_face_verts;
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
RasterizeMeshesBackwardCudaKernel<<<blocks, threads>>>(
|
||||
RasterizeMeshesBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
pix_to_face.contiguous().data_ptr<int64_t>(),
|
||||
perspective_correct,
|
||||
@@ -488,6 +532,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
grad_dists.contiguous().data_ptr<float>(),
|
||||
grad_face_verts.contiguous().data_ptr<float>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return grad_face_verts;
|
||||
}
|
||||
|
||||
@@ -626,10 +671,24 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
face_verts.size(2) == 3,
|
||||
"face_verts must have dimensions (num_faces, 3, 3)");
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
mesh_to_face_first_idx_t{
|
||||
mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
|
||||
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
|
||||
at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int W = image_size;
|
||||
const int H = image_size;
|
||||
const int F = face_verts.size(0);
|
||||
@@ -645,12 +704,18 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
auto opts = face_verts.options().dtype(at::kInt);
|
||||
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
|
||||
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
|
||||
if (bin_faces.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return bin_faces;
|
||||
}
|
||||
|
||||
const int chunk_size = 512;
|
||||
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
|
||||
const size_t blocks = 64;
|
||||
const size_t threads = 512;
|
||||
|
||||
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size>>>(
|
||||
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
|
||||
@@ -664,6 +729,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
M,
|
||||
faces_per_bin.contiguous().data_ptr<int32_t>(),
|
||||
bin_faces.contiguous().data_ptr<int32_t>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return bin_faces;
|
||||
}
|
||||
|
||||
@@ -775,13 +842,22 @@ RasterizeMeshesFineCuda(
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct,
|
||||
const bool cull_backfaces) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
}
|
||||
if (bin_faces.ndimension() != 4) {
|
||||
AT_ERROR("bin_faces must have 4 dimensions");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
|
||||
face_verts.size(2) == 3,
|
||||
"face_verts must have dimensions (num_faces, 3, 3)");
|
||||
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
bin_faces_t{bin_faces, "bin_faces", 2};
|
||||
at::CheckedFrom c = "RasterizeMeshesFineCuda";
|
||||
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int N = bin_faces.size(0);
|
||||
const int B = bin_faces.size(1);
|
||||
const int M = bin_faces.size(3);
|
||||
@@ -790,7 +866,7 @@ RasterizeMeshesFineCuda(
|
||||
const int W = image_size;
|
||||
|
||||
if (K > kMaxPointsPerPixel) {
|
||||
AT_ERROR("Must have num_closest <= 8");
|
||||
AT_ERROR("Must have num_closest <= 150");
|
||||
}
|
||||
auto long_opts = face_verts.options().dtype(at::kLong);
|
||||
auto float_opts = face_verts.options().dtype(at::kFloat);
|
||||
@@ -800,10 +876,15 @@ RasterizeMeshesFineCuda(
|
||||
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
|
||||
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
|
||||
|
||||
if (face_idxs.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
RasterizeMeshesFineCudaKernel<<<blocks, threads>>>(
|
||||
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
bin_faces.contiguous().data_ptr<int32_t>(),
|
||||
blur_radius,
|
||||
|
||||
Reference in New Issue
Block a user