Avoid torch/extension.h in cuda

Summary: Unlike the other .cu files, sigmoid_alpha_blend.cu includes torch/extension.h. Avoid it for a possible build-speed win, and because of a reported problem (#843) on Windows with CUDA 11.4.

Reviewed By: nikhilaravi

Differential Revision: D31054121

fbshipit-source-id: 53a1f985a1695a044dfd2ee1a5b0adabdf280595
Author: Jeremy Reizenstein
Date: 2021-09-22 15:52:32 -07:00 (committed by Facebook GitHub Bot)
Commit: cb170ac024
Parent: fe5bfa5994
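The pattern the diff moves to is a .cu translation unit that only needs ATen/c10 headers and at:: types, with no torch/extension.h. Below is a minimal sketch of that pattern using a hypothetical kernel (ScaleKernel / ScaleCuda are illustrative names, not part of this commit); it mirrors the include set, accessor, dispatch, and stream usage seen in the real file.

// example_scale.cu -- hypothetical standalone example (not part of this
// commit) showing the same pattern: ATen headers only, at:: types,
// and no torch/extension.h in the .cu file.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

template <typename scalar_t>
__global__ void ScaleKernel(
    const at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> in,
    at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> out,
    const scalar_t factor,
    const int64_t N) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    out[i] = in[i] * factor;
  }
}

at::Tensor ScaleCuda(const at::Tensor& in, const float factor) {
  // For simplicity, assume a 1-D floating point tensor.
  // Launch on the device that owns the input tensor.
  at::cuda::CUDAGuard device_guard(in.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  at::Tensor out = at::empty_like(in);
  const int64_t N = in.numel();
  if (N == 0) {
    return out;
  }
  const int threads = 256;
  const int blocks = static_cast<int>((N + threads - 1) / threads);

  AT_DISPATCH_FLOATING_TYPES(in.scalar_type(), "scale_kernel", ([&] {
    ScaleKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        in.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(),
        out.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(),
        static_cast<scalar_t>(factor),
        N);
  }));
  AT_CUDA_CHECK(cudaGetLastError());
  return out;
}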

@@ -6,18 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <cmath>
 #include <vector>
 
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendForwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -67,7 +67,7 @@ __global__ void SigmoidAlphaBlendForwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendForwardCuda(
+at::Tensor SigmoidAlphaBlendForwardCuda(
     const at::Tensor& distances, // (N, H, W, K)
     const at::Tensor& pix_to_face, // (N, H, W, K)
     const float sigma) {
@@ -99,9 +99,9 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
       distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
         // clang-format off
         SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-          distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-          pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-          alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+          distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+          pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+          alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
           sigma,
           N,
           H,
@@ -117,11 +117,11 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendBackwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> grad_distances, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> grad_alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> grad_distances, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -162,7 +162,7 @@ __global__ void SigmoidAlphaBlendBackwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendBackwardCuda(
+at::Tensor SigmoidAlphaBlendBackwardCuda(
    const at::Tensor& grad_alphas, // (N, H, W)
    const at::Tensor& alphas, // (N, H, W)
    const at::Tensor& distances, // (N, H, W, K)
@@ -195,20 +195,20 @@ torch::Tensor SigmoidAlphaBlendBackwardCuda(
   AT_DISPATCH_FLOATING_TYPES(
       distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
-        SigmoidAlphaBlendBackwardKernel<scalar_t>
-            <<<blocks, threads, 0, stream>>>(
+        SigmoidAlphaBlendBackwardKernel<
+            scalar_t><<<blocks, threads, 0, stream>>>(
             // clang-format off
-            grad_alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-            grad_distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
+            grad_alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+            grad_distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
             // clang-format on
             sigma,
             N,
             H,
             W,
             K);
       }));
   AT_CUDA_CHECK(cudaGetLastError());
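
Dropping torch/extension.h from the .cu file works because nothing in it needs pybind11; the Python-binding side lives in a separate host-compiled .cpp, which is where torch/extension.h belongs. The sketch below illustrates that split under stated assumptions: the file name, module registration, and exposed op name are illustrative and not taken from this commit, and the real extension's registration may differ.

// bindings.cpp -- hypothetical host-side translation unit, compiled by the
// C++ compiler rather than nvcc; torch/extension.h stays here.
#include <torch/extension.h>

// Forward-declare the CUDA entry point defined in the .cu file
// (signature as in the diff above); the backward op would be declared
// and registered the same way.
at::Tensor SigmoidAlphaBlendForwardCuda(
    const at::Tensor& distances, // (N, H, W, K)
    const at::Tensor& pix_to_face, // (N, H, W, K)
    const float sigma);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Illustrative op name only.
  m.def("sigmoid_alpha_blend_forward_cuda", &SigmoidAlphaBlendForwardCuda);
}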