Avoid torch/extension.h in CUDA files

Summary: Unlike the other .cu files, sigmoid_alpha_blend.cu includes torch/extension.h. Avoid it for a possible build-speed win and because of a reported problem (#843) on Windows with CUDA 11.4.
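
For context, the pattern this change moves to looks roughly like the sketch below: a .cu file that includes only ATen/c10 headers and names only at:: symbols, so torch/extension.h (which pulls in pybind11 and the full C++ frontend) never enters nvcc's translation unit. ScaleKernel and ScaleCuda are hypothetical names used purely for illustration; they are not part of this commit.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical kernel for illustration: scales a 1-D tensor elementwise.
// Only at:: types appear, so torch/extension.h is never needed here.
template <typename scalar_t>
__global__ void ScaleKernel(
    const at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> in,
    at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> out,
    const scalar_t factor,
    const int64_t n) {
  const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * factor;
  }
}

// Hypothetical host wrapper using the same ingredients as the functions in
// this diff: a device guard, the current stream, and a dtype dispatch.
// Assumes a 1-D floating-point CUDA tensor.
at::Tensor ScaleCuda(const at::Tensor& input, const float factor) {
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::Tensor output = at::empty_like(input);
  const int64_t n = input.numel();
  const dim3 threads(256);
  const dim3 blocks((n + threads.x - 1) / threads.x);
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scale_kernel", ([&] {
    ScaleKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        input.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(),
        output.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(),
        static_cast<scalar_t>(factor),
        n);
  }));
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}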

Reviewed By: nikhilaravi

Differential Revision: D31054121

fbshipit-source-id: 53a1f985a1695a044dfd2ee1a5b0adabdf280595
Author: Jeremy Reizenstein, 2021-09-22 15:52:32 -07:00 (committed by Facebook GitHub Bot)
parent fe5bfa5994
commit cb170ac024

@@ -6,18 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <cmath>
 #include <vector>
 
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendForwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -67,7 +67,7 @@ __global__ void SigmoidAlphaBlendForwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendForwardCuda(
+at::Tensor SigmoidAlphaBlendForwardCuda(
     const at::Tensor& distances, // (N, H, W, K)
     const at::Tensor& pix_to_face, // (N, H, W, K)
     const float sigma) {
@@ -99,9 +99,9 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
       distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
         // clang-format off
         SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-          distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-          pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-          alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+          distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+          pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+          alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
           sigma,
           N,
           H,
@@ -117,11 +117,11 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendBackwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> grad_distances, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> grad_alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> grad_distances, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -162,7 +162,7 @@ __global__ void SigmoidAlphaBlendBackwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendBackwardCuda(
+at::Tensor SigmoidAlphaBlendBackwardCuda(
    const at::Tensor& grad_alphas, // (N, H, W)
    const at::Tensor& alphas, // (N, H, W)
    const at::Tensor& distances, // (N, H, W, K)
@@ -195,20 +195,20 @@ torch::Tensor SigmoidAlphaBlendBackwardCuda(
   AT_DISPATCH_FLOATING_TYPES(
       distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
-        SigmoidAlphaBlendBackwardKernel<scalar_t>
-            <<<blocks, threads, 0, stream>>>(
-            // clang-format off
-          grad_alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-          alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-          distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-          pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-          grad_distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            // clang-format on
-            sigma,
-            N,
-            H,
-            W,
-            K);
+        SigmoidAlphaBlendBackwardKernel<
+            scalar_t><<<blocks, threads, 0, stream>>>(
+            // clang-format off
+          grad_alphas.packed_accessor64<scalar_t, 3,at::RestrictPtrTraits>(),
+          alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+          distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+          pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+          grad_distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            // clang-format on
+            sigma,
+            N,
+            H,
+            W,
+            K);
       }));
   AT_CUDA_CHECK(cudaGetLastError());
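
The one place torch/extension.h is still needed is the host-side binding translation unit, which an ordinary C++ compiler builds rather than nvcc. Roughly, a sketch under that assumption (the file name bindings.cpp and the exposed Python name are illustrative, not the project's actual binding code):

// bindings.cpp -- illustrative host-side binding unit. This is the only
// translation unit that includes torch/extension.h; nvcc never sees it.
#include <torch/extension.h>

// Declare the CUDA entry point with ATen types only, matching the
// definition in the .cu file above.
at::Tensor SigmoidAlphaBlendForwardCuda(
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Hypothetical exposed name, for illustration only.
  m.def("sigmoid_alpha_blend_forward", &SigmoidAlphaBlendForwardCuda);
}

Because nvcc no longer parses the heavy extension headers, builds can get faster, and the Windows + CUDA 11.4 build failure reported in #843 is sidestepped.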