Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2026-03-12 15:35:58 +08:00)
torch C API warnings
Summary: This mostly replaces the old PackedTensorAccessor with the new PackedTensorAccessor64.

Reviewed By: gkioxari

Differential Revision: D21088773

fbshipit-source-id: 5973e5a29d934eafb7c70ec5ec154ca076b64d27
Committed by: Facebook GitHub Bot
Parent: f25af96959
Commit: 9397cd872d
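For context, here is a minimal before/after sketch of the accessor API this commit migrates. The tensor t is an illustrative stand-in (assume a 2-D float CUDA tensor); only the template arguments differ between the two forms.

    // Deprecated form: the index type (size_t) is a fourth template argument
    // and is what triggers the C API deprecation warnings.
    auto acc_old = t.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>();

    // New form: the 64-bit index type is baked into the name, so the template
    // takes only dtype, rank, and pointer traits.
    auto acc_new = t.packed_accessor64<float, 2, torch::RestrictPtrTraits>();

PyTorch also provides packed_accessor32 for tensors known to fit 32-bit indexing, which is cheaper on GPU; this commit consistently picks the 64-bit variant.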
@@ -12,10 +12,10 @@
 // Currently, support is for floats only.
 __global__ void alphaCompositeCudaForwardKernel(
     // clang-format off
-    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
-    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
-    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
-    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
+    torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
+    const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
+    const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
+    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
     // clang-format on
   const int64_t batch_size = result.size(0);
   const int64_t C = features.size(0);
@@ -61,12 +61,12 @@ __global__ void alphaCompositeCudaForwardKernel(
 // Currently, support is for floats only.
 __global__ void alphaCompositeCudaBackwardKernel(
     // clang-format off
-    torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
-    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
-    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
-    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
-    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
-    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
+    torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
+    torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
+    const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
+    const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
+    const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
+    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
     // clang-format on
   const int64_t batch_size = points_idx.size(0);
   const int64_t C = features.size(0);
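To show the new kernel-parameter style in a self-contained way, here is a toy kernel. It is not PyTorch3D's compositing logic; the name scaleKernel and the shapes are invented for illustration, assuming a CUDA build with #include <torch/extension.h>.

    // Scales a 2-D float tensor in place. The accessor is passed by value;
    // it carries the data pointer plus 64-bit sizes and strides.
    __global__ void scaleKernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> data,
        const float factor) {
      const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < data.size(0)) {
        for (int64_t j = 0; j < data.size(1); ++j) {
          // operator[] does the stride arithmetic; RestrictPtrTraits marks
          // the underlying pointer __restrict__ for the compiler.
          data[i][j] *= factor;
        }
      }
    }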
@@ -149,10 +149,10 @@ torch::Tensor alphaCompositeCudaForward(
   // doubles. Currently, support is for floats only.
   alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
       // clang-format off
-      result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
-      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
-      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
-      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
+      result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
+      features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
+      alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
+      points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
       // clang-format on

   return result;
@@ -175,12 +175,12 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
   // doubles. Currently, support is for floats only.
   alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
       // clang-format off
-      grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
-      grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
-      grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
-      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
-      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
-      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
+      grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
+      grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
+      grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
+      features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
+      alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
+      points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
       // clang-format on

   return std::make_tuple(grad_features, grad_alphas);
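The two launch-site hunks mirror the kernel-side change: each .packed_accessor<...>() call gains the 64 suffix and drops the size_t argument. A hedged host-side sketch pairing with the toy kernel above (grid and block sizes are illustrative):

    // Allocate an input on the GPU and launch with a 64-bit packed accessor.
    torch::Tensor data = torch::rand(
        {1024, 8}, torch::device(torch::kCUDA).dtype(torch::kFloat32));
    const int threads = 256;
    const int blocks = static_cast<int>((data.size(0) + threads - 1) / threads);
    scaleKernel<<<blocks, threads>>>(
        data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(), 2.0f);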