mirror of https://github.com/facebookresearch/pytorch3d.git
Avoid torch/extension.h in cuda
Summary: Unlike other .cu files, sigmoid_alpha_blend uses torch/extension.h. Avoid it for a possible build-speed win, and because of a reported problem (#843) on Windows with CUDA 11.4.

Reviewed By: nikhilaravi

Differential Revision: D31054121

fbshipit-source-id: 53a1f985a1695a044dfd2ee1a5b0adabdf280595
This commit is contained in:
parent fe5bfa5994
commit cb170ac024
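Why this helps: torch/extension.h pulls in the full C++ frontend plus pybind11 and the Python headers, all of which nvcc must parse for every .cu file; the ATen/c10 headers the commit switches to are much lighter and are all a kernel file actually needs. Below is a minimal sketch of the resulting pattern; ScaleKernel and ScaleCuda are hypothetical names for illustration, not part of PyTorch3D.

// Sketch only: a .cu translation unit that, like the file in this commit,
// builds against ATen/c10 headers alone, with no <torch/extension.h>.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

__global__ void ScaleKernel(const float* in, float* out, float s, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = s * in[i];
  }
}

// Assumes a contiguous float CUDA tensor, to keep the sketch short.
at::Tensor ScaleCuda(const at::Tensor& input, float s) {
  // Run on the input's device and current stream, as the real file does.
  const c10::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  at::Tensor output = at::empty_like(input);
  const int n = static_cast<int>(input.numel());
  const int threads = 128;
  const int blocks = (n + threads - 1) / threads;
  ScaleKernel<<<blocks, threads, 0, stream>>>(
      input.data_ptr<float>(), output.data_ptr<float>(), s, n);
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}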
@@ -6,18 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <cmath>
 #include <vector>
 
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendForwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
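The rename above is behavior-preserving: the torch:: spellings are aliases of the at:: originals, so only the header dependency changes. For reference, a self-contained sketch of the same packed-accessor pattern; RowSumKernel is a hypothetical name, not code from this repository.

// Sketch: the accessor gives indexed, strided access on device, and
// at::RestrictPtrTraits marks the underlying data pointer __restrict__.
#include <ATen/ATen.h>

template <typename scalar_t>
__global__ void RowSumKernel(
    const at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> in, // (N, K)
    at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> out) { // (N)
  const int64_t n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < in.size(0)) {
    scalar_t sum = 0;
    for (int64_t k = 0; k < in.size(1); ++k) {
      sum += in[n][k]; // strided indexing handled by the accessor
    }
    out[n] = sum;
  }
}

On the host side the accessor is produced exactly as in the launch code further down: tensor.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>().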
@@ -67,7 +67,7 @@ __global__ void SigmoidAlphaBlendForwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendForwardCuda(
+at::Tensor SigmoidAlphaBlendForwardCuda(
     const at::Tensor& distances, // (N, H, W, K)
     const at::Tensor& pix_to_face, // (N, H, W, K)
     const float sigma) {
@@ -99,9 +99,9 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
       distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
         // clang-format off
         SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-            distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-            alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+            distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+            alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
             sigma,
             N,
             H,
@@ -117,11 +117,11 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendBackwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> grad_distances, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> grad_alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> grad_distances, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -162,7 +162,7 @@ __global__ void SigmoidAlphaBlendBackwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendBackwardCuda(
+at::Tensor SigmoidAlphaBlendBackwardCuda(
    const at::Tensor& grad_alphas, // (N, H, W)
    const at::Tensor& alphas, // (N, H, W)
    const at::Tensor& distances, // (N, H, W, K)
@@ -195,20 +195,20 @@ torch::Tensor SigmoidAlphaBlendBackwardCuda(
 
   AT_DISPATCH_FLOATING_TYPES(
       distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
-        SigmoidAlphaBlendBackwardKernel<scalar_t>
-            <<<blocks, threads, 0, stream>>>(
+        SigmoidAlphaBlendBackwardKernel<
+            scalar_t><<<blocks, threads, 0, stream>>>(
             // clang-format off
-            grad_alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-            grad_distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
+            grad_alphas.packed_accessor64<scalar_t, 3,at::RestrictPtrTraits>(),
+            alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+            grad_distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
             // clang-format on
             sigma,
             N,
             H,
             W,
             K);
       }));
 
   AT_CUDA_CHECK(cudaGetLastError());
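torch/extension.h itself remains fine in host-compiled code; the usual split keeps pybind11 out of nvcc's path entirely. The following is a sketch of that general pattern with a hypothetical binding file and exported name, not PyTorch3D's actual binding code.

// bindings.cpp (hypothetical): compiled by the host C++ compiler, not nvcc,
// so <torch/extension.h> never reaches the CUDA toolchain.
#include <torch/extension.h>

// Forward declaration of the function defined in the .cu file above.
at::Tensor SigmoidAlphaBlendForwardCuda(
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_alpha_blend_forward", &SigmoidAlphaBlendForwardCuda);
}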