Cuda updates

Summary:
Updates to:
- enable cuda kernel launches on any GPU (not just the default)
- cuda and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the cuda kernels
- cuda tests now run on a random device not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
Nikhila Ravi
2020-04-24 09:07:54 -07:00
committed by Facebook GitHub Bot
parent c9267ab7af
commit c3d636dc8c
33 changed files with 979 additions and 240 deletions

View File

@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -136,6 +138,17 @@ at::Tensor alphaCompositeCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "alphaCompositeCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -143,19 +156,24 @@ at::Tensor alphaCompositeCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -164,9 +182,26 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "alphaCompositeCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = alphas.size(0);
const dim3 threadsPerBlock(64);
@@ -174,7 +209,7 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -183,6 +218,6 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}