diff --git a/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu b/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu
index 11044424..e10e3d67 100644
--- a/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu
+++ b/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu
@@ -6,18 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */

+#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <cmath>
 #include <vector>

 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendForwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -67,7 +67,7 @@ __global__ void SigmoidAlphaBlendForwardKernel(
   }
 }

-torch::Tensor SigmoidAlphaBlendForwardCuda(
+at::Tensor SigmoidAlphaBlendForwardCuda(
     const at::Tensor& distances, // (N, H, W, K)
     const at::Tensor& pix_to_face, // (N, H, W, K)
     const float sigma) {
@@ -99,9 +99,9 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
       distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
         // clang-format off
       SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-        distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-        pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-        alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+        distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+        pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+        alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
         sigma,
         N,
         H,
@@ -117,11 +117,11 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendBackwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> grad_distances, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> grad_alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> grad_distances, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -162,7 +162,7 @@ __global__ void SigmoidAlphaBlendBackwardKernel(
   }
 }

-torch::Tensor SigmoidAlphaBlendBackwardCuda(
+at::Tensor SigmoidAlphaBlendBackwardCuda(
     const at::Tensor& grad_alphas, // (N, H, W)
     const at::Tensor& alphas, // (N, H, W)
     const at::Tensor& distances, // (N, H, W, K)
@@ -195,20 +195,20 @@ torch::Tensor SigmoidAlphaBlendBackwardCuda(
   AT_DISPATCH_FLOATING_TYPES(
       distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
-        SigmoidAlphaBlendBackwardKernel<scalar_t>
-            <<<blocks, threads, 0, stream>>>(
-                // clang-format off
-                grad_alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-                alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-                distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-                pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-                grad_distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-                // clang-format on
-                sigma,
-                N,
-                H,
-                W,
-                K);
+        SigmoidAlphaBlendBackwardKernel<
+            scalar_t><<<blocks, threads, 0, stream>>>(
+            // clang-format off
+            grad_alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+            grad_distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            // clang-format on
+            sigma,
+            N,
+            H,
+            W,
+            K);
       }));

   AT_CUDA_CHECK(cudaGetLastError());
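
Reviewer note: this change is a mechanical migration from the `torch::` spellings to their `at::` equivalents (the `torch::` names resolve to the same ATen types, so behavior is unchanged) and drops `#include <torch/extension.h>` from the .cu file, which keeps the Python-binding headers out of nvcc's translation unit. For reference, below is a minimal, self-contained sketch of the `at::PackedTensorAccessor64` + `AT_DISPATCH_FLOATING_TYPES` launch pattern the file now uses uniformly. The kernel and wrapper (`ScaleKernel`/`ScaleCuda`) and their shapes are hypothetical illustrations, not part of this diff.

```cuda
// Minimal sketch of the accessor + dispatch pattern used in the diff above.
// ScaleKernel/ScaleCuda are illustrative names, not part of this change.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

template <typename scalar_t>
__global__ void ScaleKernel(
    const at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> in, // (R, C)
    at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> out, // (R, C)
    const scalar_t factor) {
  // Grid-stride loop over all R * C elements.
  const int64_t numel = in.size(0) * in.size(1);
  const int64_t stride = (int64_t)gridDim.x * blockDim.x;
  for (int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += stride) {
    const int64_t r = i / in.size(1);
    const int64_t c = i % in.size(1);
    out[r][c] = in[r][c] * factor;
  }
}

at::Tensor ScaleCuda(const at::Tensor& input, const float factor) {
  // Same setup as the kernels in this file: guard the device, grab the
  // current stream, and dispatch on the floating-point dtype.
  const at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::Tensor output = at::empty_like(input);
  const int threads = 128;
  const int blocks = 64;
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scale_kernel", ([&] {
    ScaleKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        input.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(),
        output.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(),
        static_cast<scalar_t>(factor));
  }));
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}
```

`packed_accessor64` with `at::RestrictPtrTraits` marks the underlying data pointers `__restrict__`, and `torch::RestrictPtrTraits` is the same trait re-exported, so the accessor lines in this diff change namespace only, not the generated device code.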