torch C API warnings

Summary: This is mostly replacing the old PackedTensorAccessor with the new PackedTensorAccessor64.

Reviewed By: gkioxari

Differential Revision: D21088773

fbshipit-source-id: 5973e5a29d934eafb7c70ec5ec154ca076b64d27
This commit is contained in:
Jeremy Reizenstein 2020-04-17 10:37:10 -07:00 committed by Facebook GitHub Bot
parent f25af96959
commit 9397cd872d
6 changed files with 60 additions and 63 deletions

View File

@ -12,10 +12,10 @@
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
@ -61,12 +61,12 @@ __global__ void alphaCompositeCudaForwardKernel(
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
@ -149,10 +149,10 @@ torch::Tensor alphaCompositeCudaForward(
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
// clang-format on
return result;
@ -175,12 +175,12 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);

View File

@ -14,10 +14,10 @@ __constant__ const float kEpsilon = 1e-4;
// Currently, support is for floats only.
__global__ void weightedSumNormCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
@ -76,12 +76,12 @@ __global__ void weightedSumNormCudaForwardKernel(
// Currently, support is for floats only.
__global__ void weightedSumNormCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
@ -164,10 +164,10 @@ torch::Tensor weightedSumNormCudaForward(
// doubles. Currently, support is for floats only.
// clang-format off
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
// clang-format on
return result;
@ -190,12 +190,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
// doubles. Currently, support is for floats only.
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);

View File

@ -12,10 +12,10 @@
// Currently, support is for floats only.
__global__ void weightedSumCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
@ -58,12 +58,12 @@ __global__ void weightedSumCudaForwardKernel(
// Currently, support is for floats only.
__global__ void weightedSumCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
@ -123,10 +123,10 @@ torch::Tensor weightedSumCudaForward(
// doubles. Currently, support is for floats only.
weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
// clang-format on
return result;
@ -149,12 +149,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
// doubles. Currently, support is for floats only.
weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);

View File

@ -6,7 +6,6 @@
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCpu(
const at::Tensor verts,
const at::Tensor faces) {
const int V = verts.size(0);
const int F = faces.size(0);
at::Tensor areas = at::empty({F}, verts.options());

View File

@ -38,7 +38,6 @@ at::Tensor PaddedToPackedCpu(
const at::Tensor first_idxs,
const int64_t num_inputs) {
const int64_t batch_size = inputs_padded.size(0);
const int64_t max_size = inputs_padded.size(1);
AT_ASSERTM(
inputs_padded.dim() == 3, "inputs_padded must be a 3-dimensional tensor");

View File

@ -415,7 +415,6 @@ torch::Tensor RasterizeMeshesCoarseCpu(
auto opts = face_verts.options().dtype(torch::kInt32);
torch::Tensor faces_per_bin = torch::zeros({N, BH, BW}, opts);
torch::Tensor bin_faces = torch::full({N, BH, BW, M}, -1, opts);
auto faces_per_bin_a = faces_per_bin.accessor<int32_t, 3>();
auto bin_faces_a = bin_faces.accessor<int32_t, 4>();
// Precompute all face bounding boxes.