Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
avoid using torch/extension.h in cuda
Summary: Use the ATen interface instead of the torch interface in all CUDA code. This allows the CUDA build to work with PyTorch 1.5 under GCC 5 (e.g. the compiler of Ubuntu 16.04 LTS), which had not been working: it failed with errors like the one below, perhaps due to a bug in nvcc.

```
torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >
```

Reviewed By: nikhilaravi

Differential Revision: D21204029

fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
parent 54b482bd66
commit 85c396f822
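The change itself is mechanical: each `.cu` file drops `#include <torch/extension.h>` in favour of `#include <ATen/ATen.h>` (plus `<ATen/core/TensorAccessor.h>` where packed accessors are used), and the `torch::` spellings of tensor types, factory functions and accessor traits are replaced by their `at::` equivalents, so that nvcc never has to parse the `torch/extension.h` header. Below is a minimal sketch of the resulting pattern; the kernel, function and variable names are illustrative only and do not appear in this commit.

```
// Illustrative only: the ATen-only style this commit moves the .cu files to.
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>

#include <cuda.h>
#include <cuda_runtime.h>

// The kernel sees tensors through at:: packed accessors instead of torch:: ones;
// at::RestrictPtrTraits marks the underlying data pointer as __restrict__.
__global__ void ScaleCudaKernel(
    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> out,
    const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> in,
    const float scale) {
  const int64_t N = in.size(0);
  const int64_t D = in.size(1);
  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid-stride loop over all N * D elements.
  for (int64_t i = tid; i < N * D; i += gridDim.x * blockDim.x) {
    out[i / D][i % D] = scale * in[i / D][i % D];
  }
}

// Host wrapper uses at::Tensor / at::zeros rather than torch::Tensor / torch::zeros.
at::Tensor ScaleCuda(const at::Tensor& in, const float scale) {
  at::Tensor out = at::zeros({in.size(0), in.size(1)}, in.options());
  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(256);
  ScaleCudaKernel<<<numBlocks, threadsPerBlock>>>(
      out.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      in.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      scale);
  return out;
}
```

The Python bindings, which do require `torch/extension.h`, can stay in ordinary `.cpp` files compiled by the host compiler.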
@@ -1,6 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/core/TensorAccessor.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -12,10 +13,10 @@
 // Currently, support is for floats only.
 __global__ void alphaCompositeCudaForwardKernel(
 // clang-format off
-torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
-const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
-const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
+at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
+const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
+const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
 // clang-format on
 const int64_t batch_size = result.size(0);
 const int64_t C = features.size(0);
@@ -61,12 +62,12 @@ __global__ void alphaCompositeCudaForwardKernel(
 // Currently, support is for floats only.
 __global__ void alphaCompositeCudaBackwardKernel(
 // clang-format off
-torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
-torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
-const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
-const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
+at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
+at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
+const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
+const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
 // clang-format on
 const int64_t batch_size = points_idx.size(0);
 const int64_t C = features.size(0);
@@ -131,16 +132,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
 }
 }
 
-torch::Tensor alphaCompositeCudaForward(
-const torch::Tensor& features,
-const torch::Tensor& alphas,
-const torch::Tensor& points_idx) {
+at::Tensor alphaCompositeCudaForward(
+const at::Tensor& features,
+const at::Tensor& alphas,
+const at::Tensor& points_idx) {
 const int64_t batch_size = points_idx.size(0);
 const int64_t C = features.size(0);
 const int64_t H = points_idx.size(2);
 const int64_t W = points_idx.size(3);
 
-auto result = torch::zeros({batch_size, C, H, W}, features.options());
+auto result = at::zeros({batch_size, C, H, W}, features.options());
 
 const dim3 threadsPerBlock(64);
 const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -149,22 +150,22 @@ torch::Tensor alphaCompositeCudaForward(
 // doubles. Currently, support is for floats only.
 alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
 // clang-format off
-result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
+result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
 // clang-format on
 
 return result;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
-const torch::Tensor& grad_outputs,
-const torch::Tensor& features,
-const torch::Tensor& alphas,
-const torch::Tensor& points_idx) {
-auto grad_features = torch::zeros_like(features);
-auto grad_alphas = torch::zeros_like(alphas);
+std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
+const at::Tensor& grad_outputs,
+const at::Tensor& features,
+const at::Tensor& alphas,
+const at::Tensor& points_idx) {
+auto grad_features = at::zeros_like(features);
+auto grad_alphas = at::zeros_like(alphas);
 
 const int64_t bs = alphas.size(0);
 
@@ -175,12 +176,12 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
 // doubles. Currently, support is for floats only.
 alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
 // clang-format off
-grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
+grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
 // clang-format on
 
 return std::make_tuple(grad_features, grad_alphas);

@@ -1,6 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/core/TensorAccessor.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -14,10 +15,10 @@ __constant__ const float kEpsilon = 1e-4;
 // Currently, support is for floats only.
 __global__ void weightedSumNormCudaForwardKernel(
 // clang-format off
-torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
-const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
-const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
+at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
+const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
+const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
 // clang-format on
 const int64_t batch_size = result.size(0);
 const int64_t C = features.size(0);
@@ -76,12 +77,12 @@ __global__ void weightedSumNormCudaForwardKernel(
 // Currently, support is for floats only.
 __global__ void weightedSumNormCudaBackwardKernel(
 // clang-format off
-torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
-torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
-const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
-const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
+at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
+at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
+const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
+const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
 // clang-format on
 const int64_t batch_size = points_idx.size(0);
 const int64_t C = features.size(0);
@@ -146,16 +147,16 @@ __global__ void weightedSumNormCudaBackwardKernel(
 }
 }
 
-torch::Tensor weightedSumNormCudaForward(
-const torch::Tensor& features,
-const torch::Tensor& alphas,
-const torch::Tensor& points_idx) {
+at::Tensor weightedSumNormCudaForward(
+const at::Tensor& features,
+const at::Tensor& alphas,
+const at::Tensor& points_idx) {
 const int64_t batch_size = points_idx.size(0);
 const int64_t C = features.size(0);
 const int64_t H = points_idx.size(2);
 const int64_t W = points_idx.size(3);
 
-auto result = torch::zeros({batch_size, C, H, W}, features.options());
+auto result = at::zeros({batch_size, C, H, W}, features.options());
 
 const dim3 threadsPerBlock(64);
 const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -164,22 +165,22 @@ torch::Tensor weightedSumNormCudaForward(
 // doubles. Currently, support is for floats only.
 // clang-format off
 weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
-result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
+result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
 // clang-format on
 
 return result;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
-const torch::Tensor& grad_outputs,
-const torch::Tensor& features,
-const torch::Tensor& alphas,
-const torch::Tensor& points_idx) {
-auto grad_features = torch::zeros_like(features);
-auto grad_alphas = torch::zeros_like(alphas);
+std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
+const at::Tensor& grad_outputs,
+const at::Tensor& features,
+const at::Tensor& alphas,
+const at::Tensor& points_idx) {
+auto grad_features = at::zeros_like(features);
+auto grad_alphas = at::zeros_like(alphas);
 
 const int64_t bs = points_idx.size(0);
 
@@ -190,12 +191,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
 // doubles. Currently, support is for floats only.
 weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
 // clang-format off
-grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
+grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
 // clang-format on
 
 return std::make_tuple(grad_features, grad_alphas);

@@ -1,6 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/core/TensorAccessor.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -12,10 +13,10 @@
 // Currently, support is for floats only.
 __global__ void weightedSumCudaForwardKernel(
 // clang-format off
-torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
-const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
-const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
+at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
+const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
+const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
 // clang-format on
 const int64_t batch_size = result.size(0);
 const int64_t C = features.size(0);
@@ -58,12 +59,12 @@ __global__ void weightedSumCudaForwardKernel(
 // Currently, support is for floats only.
 __global__ void weightedSumCudaBackwardKernel(
 // clang-format off
-torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
-torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
-const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
-const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
-const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
+at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
+at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
+const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
+const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
+const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
 // clang-format on
 const int64_t batch_size = points_idx.size(0);
 const int64_t C = features.size(0);
@@ -105,16 +106,16 @@ __global__ void weightedSumCudaBackwardKernel(
 }
 }
 
-torch::Tensor weightedSumCudaForward(
-const torch::Tensor& features,
-const torch::Tensor& alphas,
-const torch::Tensor& points_idx) {
+at::Tensor weightedSumCudaForward(
+const at::Tensor& features,
+const at::Tensor& alphas,
+const at::Tensor& points_idx) {
 const int64_t batch_size = points_idx.size(0);
 const int64_t C = features.size(0);
 const int64_t H = points_idx.size(2);
 const int64_t W = points_idx.size(3);
 
-auto result = torch::zeros({batch_size, C, H, W}, features.options());
+auto result = at::zeros({batch_size, C, H, W}, features.options());
 
 const dim3 threadsPerBlock(64);
 const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -123,22 +124,22 @@ torch::Tensor weightedSumCudaForward(
 // doubles. Currently, support is for floats only.
 weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
 // clang-format off
-result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
+result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
 // clang-format on
 
 return result;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
-const torch::Tensor& grad_outputs,
-const torch::Tensor& features,
-const torch::Tensor& alphas,
-const torch::Tensor& points_idx) {
-auto grad_features = torch::zeros_like(features);
-auto grad_alphas = torch::zeros_like(alphas);
+std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
+const at::Tensor& grad_outputs,
+const at::Tensor& features,
+const at::Tensor& alphas,
+const at::Tensor& points_idx) {
+auto grad_features = at::zeros_like(features);
+auto grad_alphas = at::zeros_like(alphas);
 
 const int64_t bs = points_idx.size(0);
 
@@ -149,12 +150,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
 // doubles. Currently, support is for floats only.
 weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
 // clang-format off
-grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
-alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
-points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
+grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
+alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
+points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
 // clang-format on
 
 return std::make_tuple(grad_features, grad_alphas);

@@ -1,7 +1,6 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 #include <ATen/ATen.h>
-#include <torch/extension.h>
 
 // Kernel for inputs_packed of shape (F, D), where D > 1
 template <typename scalar_t>

@@ -1,6 +1,6 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-#include <torch/extension.h>
+#include <ATen/ATen.h>
 #include <algorithm>
 #include <list>
 #include <queue>
@@ -97,11 +97,11 @@ __global__ void PointEdgeForwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& points_first_idx,
-const torch::Tensor& segms,
-const torch::Tensor& segms_first_idx,
+std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
+const at::Tensor& points,
+const at::Tensor& points_first_idx,
+const at::Tensor& segms,
+const at::Tensor& segms_first_idx,
 const int64_t max_points) {
 const int64_t P = points.size(0);
 const int64_t S = segms.size(0);
@@ -114,8 +114,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
 AT_ASSERTM(segms_first_idx.size(0) == B);
 
 // clang-format off
-torch::Tensor dists = torch::zeros({P,}, points.options());
-torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
+at::Tensor dists = at::zeros({P,}, points.options());
+at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
 // clang-format on
 
 const int threads = 128;
@@ -178,11 +178,11 @@ __global__ void PointEdgeBackwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& segms,
-const torch::Tensor& idx_points,
-const torch::Tensor& grad_dists) {
+std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
+const at::Tensor& points,
+const at::Tensor& segms,
+const at::Tensor& idx_points,
+const at::Tensor& grad_dists) {
 const int64_t P = points.size(0);
 const int64_t S = segms.size(0);
 
@@ -194,8 +194,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
 AT_ASSERTM(grad_dists.size(0) == P);
 
 // clang-format off
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
-torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
+at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
 // clang-format on
 
 const int blocks = 64;
@@ -302,11 +302,11 @@ __global__ void EdgePointForwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& points_first_idx,
-const torch::Tensor& segms,
-const torch::Tensor& segms_first_idx,
+std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
+const at::Tensor& points,
+const at::Tensor& points_first_idx,
+const at::Tensor& segms,
+const at::Tensor& segms_first_idx,
 const int64_t max_segms) {
 const int64_t P = points.size(0);
 const int64_t S = segms.size(0);
@@ -319,8 +319,8 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
 AT_ASSERTM(segms_first_idx.size(0) == B);
 
 // clang-format off
-torch::Tensor dists = torch::zeros({S,}, segms.options());
-torch::Tensor idxs = torch::zeros({S,}, segms_first_idx.options());
+at::Tensor dists = at::zeros({S,}, segms.options());
+at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
 // clang-format on
 
 const int threads = 128;
@@ -384,11 +384,11 @@ __global__ void EdgePointBackwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& segms,
-const torch::Tensor& idx_segms,
-const torch::Tensor& grad_dists) {
+std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
+const at::Tensor& points,
+const at::Tensor& segms,
+const at::Tensor& idx_segms,
+const at::Tensor& grad_dists) {
 const int64_t P = points.size(0);
 const int64_t S = segms.size(0);
 
@@ -400,8 +400,8 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
 AT_ASSERTM(grad_dists.size(0) == S);
 
 // clang-format off
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
-torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
+at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
 // clang-format on
 
 const int blocks = 64;
@@ -448,9 +448,9 @@ __global__ void PointEdgeArrayForwardKernel(
 }
 }
 
-torch::Tensor PointEdgeArrayDistanceForwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& segms) {
+at::Tensor PointEdgeArrayDistanceForwardCuda(
+const at::Tensor& points,
+const at::Tensor& segms) {
 const int64_t P = points.size(0);
 const int64_t S = segms.size(0);
 
@@ -459,7 +459,7 @@ torch::Tensor PointEdgeArrayDistanceForwardCuda(
 (segms.size(1) == 2) && (segms.size(2) == 3),
 "segms must be of shape Sx2x3");
 
-torch::Tensor dists = torch::zeros({P, S}, points.options());
+at::Tensor dists = at::zeros({P, S}, points.options());
 
 const size_t blocks = 1024;
 const size_t threads = 64;
@@ -516,10 +516,10 @@ __global__ void PointEdgeArrayBackwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& segms,
-const torch::Tensor& grad_dists) {
+std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
+const at::Tensor& points,
+const at::Tensor& segms,
+const at::Tensor& grad_dists) {
 const int64_t P = points.size(0);
 const int64_t S = segms.size(0);
 
@@ -529,8 +529,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
 "segms must be of shape Sx2x3");
 AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
 
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
-torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
+at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
 
 const size_t blocks = 1024;
 const size_t threads = 64;

@@ -1,6 +1,6 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-#include <torch/extension.h>
+#include <ATen/ATen.h>
 #include <algorithm>
 #include <list>
 #include <queue>
@@ -98,11 +98,11 @@ __global__ void PointFaceForwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& points_first_idx,
-const torch::Tensor& tris,
-const torch::Tensor& tris_first_idx,
+std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
+const at::Tensor& points,
+const at::Tensor& points_first_idx,
+const at::Tensor& tris,
+const at::Tensor& tris_first_idx,
 const int64_t max_points) {
 const int64_t P = points.size(0);
 const int64_t T = tris.size(0);
@@ -115,8 +115,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
 AT_ASSERTM(tris_first_idx.size(0) == B);
 
 // clang-format off
-torch::Tensor dists = torch::zeros({P,}, points.options());
-torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
+at::Tensor dists = at::zeros({P,}, points.options());
+at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
 // clang-format on
 
 const int threads = 128;
@@ -186,11 +186,11 @@ __global__ void PointFaceBackwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& tris,
-const torch::Tensor& idx_points,
-const torch::Tensor& grad_dists) {
+std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
+const at::Tensor& points,
+const at::Tensor& tris,
+const at::Tensor& idx_points,
+const at::Tensor& grad_dists) {
 const int64_t P = points.size(0);
 const int64_t T = tris.size(0);
 
@@ -202,8 +202,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
 AT_ASSERTM(grad_dists.size(0) == P);
 
 // clang-format off
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
-torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
+at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
 // clang-format on
 
 const int blocks = 64;
@@ -311,11 +311,11 @@ __global__ void FacePointForwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& points_first_idx,
-const torch::Tensor& tris,
-const torch::Tensor& tris_first_idx,
+std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
+const at::Tensor& points,
+const at::Tensor& points_first_idx,
+const at::Tensor& tris,
+const at::Tensor& tris_first_idx,
 const int64_t max_tris) {
 const int64_t P = points.size(0);
 const int64_t T = tris.size(0);
@@ -328,8 +328,8 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
 AT_ASSERTM(tris_first_idx.size(0) == B);
 
 // clang-format off
-torch::Tensor dists = torch::zeros({T,}, tris.options());
-torch::Tensor idxs = torch::zeros({T,}, tris_first_idx.options());
+at::Tensor dists = at::zeros({T,}, tris.options());
+at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
 // clang-format on
 
 const int threads = 128;
@@ -400,11 +400,11 @@ __global__ void FacePointBackwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& tris,
-const torch::Tensor& idx_tris,
-const torch::Tensor& grad_dists) {
+std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
+const at::Tensor& points,
+const at::Tensor& tris,
+const at::Tensor& idx_tris,
+const at::Tensor& grad_dists) {
 const int64_t P = points.size(0);
 const int64_t T = tris.size(0);
 
@@ -416,8 +416,8 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
 AT_ASSERTM(grad_dists.size(0) == T);
 
 // clang-format off
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
-torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
+at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
 // clang-format on
 
 const int blocks = 64;
@@ -465,9 +465,9 @@ __global__ void PointFaceArrayForwardKernel(
 }
 }
 
-torch::Tensor PointFaceArrayDistanceForwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& tris) {
+at::Tensor PointFaceArrayDistanceForwardCuda(
+const at::Tensor& points,
+const at::Tensor& tris) {
 const int64_t P = points.size(0);
 const int64_t T = tris.size(0);
 
@@ -476,7 +476,7 @@ torch::Tensor PointFaceArrayDistanceForwardCuda(
 (tris.size(1) == 3) && (tris.size(2) == 3),
 "tris must be of shape Tx3x3");
 
-torch::Tensor dists = torch::zeros({P, T}, points.options());
+at::Tensor dists = at::zeros({P, T}, points.options());
 
 const size_t blocks = 1024;
 const size_t threads = 64;
@@ -542,10 +542,10 @@ __global__ void PointFaceArrayBackwardKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
-const torch::Tensor& points,
-const torch::Tensor& tris,
-const torch::Tensor& grad_dists) {
+std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
+const at::Tensor& points,
+const at::Tensor& tris,
+const at::Tensor& grad_dists) {
 const int64_t P = points.size(0);
 const int64_t T = tris.size(0);
 
@@ -555,8 +555,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
 "tris must be of shape Tx3x3");
 AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
 
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
-torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
+at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
 
 const size_t blocks = 1024;
 const size_t threads = 64;

@@ -1,9 +1,9 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+#include <ATen/ATen.h>
 #include <float.h>
 #include <math.h>
 #include <thrust/tuple.h>
-#include <torch/extension.h>
 #include <cstdio>
 #include <tuple>
 #include "rasterize_points/bitmask.cuh"
@@ -275,11 +275,11 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 RasterizeMeshesNaiveCuda(
-const torch::Tensor& face_verts,
-const torch::Tensor& mesh_to_faces_packed_first_idx,
-const torch::Tensor& num_faces_per_mesh,
+const at::Tensor& face_verts,
+const at::Tensor& mesh_to_faces_packed_first_idx,
+const at::Tensor& num_faces_per_mesh,
 const int image_size,
 const float blur_radius,
 const int num_closest,
@@ -305,13 +305,13 @@ RasterizeMeshesNaiveCuda(
 const int W = image_size;
 const int K = num_closest;
 
-auto long_opts = face_verts.options().dtype(torch::kInt64);
-auto float_opts = face_verts.options().dtype(torch::kFloat32);
+auto long_opts = face_verts.options().dtype(at::kLong);
+auto float_opts = face_verts.options().dtype(at::kFloat);
 
-torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
-torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
-torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
-torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
+at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
+at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
+at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
+at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
 
 const size_t blocks = 1024;
 const size_t threads = 64;
@@ -458,12 +458,12 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
 }
 }
 
-torch::Tensor RasterizeMeshesBackwardCuda(
-const torch::Tensor& face_verts, // (F, 3, 3)
-const torch::Tensor& pix_to_face, // (N, H, W, K)
-const torch::Tensor& grad_zbuf, // (N, H, W, K)
-const torch::Tensor& grad_bary, // (N, H, W, K, 3)
-const torch::Tensor& grad_dists, // (N, H, W, K)
+at::Tensor RasterizeMeshesBackwardCuda(
+const at::Tensor& face_verts, // (F, 3, 3)
+const at::Tensor& pix_to_face, // (N, H, W, K)
+const at::Tensor& grad_zbuf, // (N, H, W, K)
+const at::Tensor& grad_bary, // (N, H, W, K, 3)
+const at::Tensor& grad_dists, // (N, H, W, K)
 const bool perspective_correct) {
 const int F = face_verts.size(0);
 const int N = pix_to_face.size(0);
@@ -471,7 +471,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
 const int W = pix_to_face.size(2);
 const int K = pix_to_face.size(3);
 
-torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
+at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());
 const size_t blocks = 1024;
 const size_t threads = 64;
 
@@ -618,10 +618,10 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
 }
 }
 
-torch::Tensor RasterizeMeshesCoarseCuda(
-const torch::Tensor& face_verts,
-const torch::Tensor& mesh_to_face_first_idx,
-const torch::Tensor& num_faces_per_mesh,
+at::Tensor RasterizeMeshesCoarseCuda(
+const at::Tensor& face_verts,
+const at::Tensor& mesh_to_face_first_idx,
+const at::Tensor& num_faces_per_mesh,
 const int image_size,
 const float blur_radius,
 const int bin_size,
@@ -642,9 +642,9 @@ torch::Tensor RasterizeMeshesCoarseCuda(
 ss << "Got " << num_bins << "; that's too many!";
 AT_ERROR(ss.str());
 }
-auto opts = face_verts.options().dtype(torch::kInt32);
-torch::Tensor faces_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
-torch::Tensor bin_faces = torch::full({N, num_bins, num_bins, M}, -1, opts);
+auto opts = face_verts.options().dtype(at::kInt);
+at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
+at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
 const int chunk_size = 512;
 const size_t shared_size = num_bins * num_bins * chunk_size / 8;
 const size_t blocks = 64;
@@ -765,10 +765,10 @@ __global__ void RasterizeMeshesFineCudaKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 RasterizeMeshesFineCuda(
-const torch::Tensor& face_verts,
-const torch::Tensor& bin_faces,
+const at::Tensor& face_verts,
+const at::Tensor& bin_faces,
 const int image_size,
 const float blur_radius,
 const int bin_size,
@@ -792,13 +792,13 @@ RasterizeMeshesFineCuda(
 if (K > kMaxPointsPerPixel) {
 AT_ERROR("Must have num_closest <= 8");
 }
-auto long_opts = face_verts.options().dtype(torch::kInt64);
-auto float_opts = face_verts.options().dtype(torch::kFloat32);
+auto long_opts = face_verts.options().dtype(at::kLong);
+auto float_opts = face_verts.options().dtype(at::kFloat);
 
-torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
-torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
-torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
-torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
+at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
+at::Tensor zbuf = at::full({N, H, W, K}, -1, float_opts);
+at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
+at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
 
 const size_t blocks = 1024;
 const size_t threads = 64;

@@ -1,7 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+#include <ATen/ATen.h>
 #include <math.h>
-#include <torch/extension.h>
 #include <cstdio>
 #include <sstream>
 #include <tuple>
@@ -138,11 +138,10 @@ __global__ void RasterizePointsNaiveCudaKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
-RasterizePointsNaiveCuda(
-const torch::Tensor& points, // (P. 3)
-const torch::Tensor& cloud_to_packed_first_idx, // (N)
-const torch::Tensor& num_points_per_cloud, // (N)
+std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
+const at::Tensor& points, // (P. 3)
+const at::Tensor& cloud_to_packed_first_idx, // (N)
+const at::Tensor& num_points_per_cloud, // (N)
 const int image_size,
 const float radius,
 const int points_per_pixel) {
@@ -164,11 +163,11 @@ RasterizePointsNaiveCuda(
 AT_ERROR(ss.str());
 }
 
-auto int_opts = points.options().dtype(torch::kInt32);
-auto float_opts = points.options().dtype(torch::kFloat32);
-torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
-torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
-torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
+auto int_opts = points.options().dtype(at::kInt);
+auto float_opts = points.options().dtype(at::kFloat);
+at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
+at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
+at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
 
 const size_t blocks = 1024;
 const size_t threads = 64;
@@ -316,10 +315,10 @@ __global__ void RasterizePointsCoarseCudaKernel(
 }
 }
 
-torch::Tensor RasterizePointsCoarseCuda(
-const torch::Tensor& points, // (P, 3)
-const torch::Tensor& cloud_to_packed_first_idx, // (N)
-const torch::Tensor& num_points_per_cloud, // (N)
+at::Tensor RasterizePointsCoarseCuda(
+const at::Tensor& points, // (P, 3)
+const at::Tensor& cloud_to_packed_first_idx, // (N)
+const at::Tensor& num_points_per_cloud, // (N)
 const int image_size,
 const float radius,
 const int bin_size,
@@ -338,9 +337,9 @@ torch::Tensor RasterizePointsCoarseCuda(
 ss << "Got " << num_bins << "; that's too many!";
 AT_ERROR(ss.str());
 }
-auto opts = points.options().dtype(torch::kInt32);
-torch::Tensor points_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
-torch::Tensor bin_points = torch::full({N, num_bins, num_bins, M}, -1, opts);
+auto opts = points.options().dtype(at::kInt);
+at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
+at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
 const int chunk_size = 512;
 const size_t shared_size = num_bins * num_bins * chunk_size / 8;
 const size_t blocks = 64;
@@ -442,9 +441,9 @@ __global__ void RasterizePointsFineCudaKernel(
 }
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
-const torch::Tensor& points, // (P, 3)
-const torch::Tensor& bin_points,
+std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
+const at::Tensor& points, // (P, 3)
+const at::Tensor& bin_points,
 const int image_size,
 const float radius,
 const int bin_size,
@@ -457,11 +456,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
 if (K > kMaxPointsPerPixel) {
 AT_ERROR("Must have num_closest <= 8");
 }
-auto int_opts = points.options().dtype(torch::kInt32);
-auto float_opts = points.options().dtype(torch::kFloat32);
-torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
-torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
-torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
+auto int_opts = points.options().dtype(at::kInt);
+auto float_opts = points.options().dtype(at::kFloat);
+at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
+at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
+at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
 
 const size_t blocks = 1024;
 const size_t threads = 64;
@@ -533,18 +532,18 @@ __global__ void RasterizePointsBackwardCudaKernel(
 }
 }
 
-torch::Tensor RasterizePointsBackwardCuda(
-const torch::Tensor& points, // (N, P, 3)
-const torch::Tensor& idxs, // (N, H, W, K)
-const torch::Tensor& grad_zbuf, // (N, H, W, K)
-const torch::Tensor& grad_dists) { // (N, H, W, K)
+at::Tensor RasterizePointsBackwardCuda(
+const at::Tensor& points, // (N, P, 3)
+const at::Tensor& idxs, // (N, H, W, K)
+const at::Tensor& grad_zbuf, // (N, H, W, K)
+const at::Tensor& grad_dists) { // (N, H, W, K)
 const int P = points.size(0);
 const int N = idxs.size(0);
 const int H = idxs.size(1);
 const int W = idxs.size(2);
 const int K = idxs.size(3);
 
-torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
+at::Tensor grad_points = at::zeros({P, 3}, points.options());
 const size_t blocks = 1024;
 const size_t threads = 64;
 

@@ -2,7 +2,6 @@
 
 #include <float.h>
 #include <math.h>
-#include <torch/extension.h>
 #include <cstdio>
 #include "float_math.cuh"
 
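One further spelling change worth calling out from the rasterization files above: the dtype constants passed to `.options().dtype(...)` also move from the torch:: aliases to the ATen ones (`torch::kInt64` → `at::kLong`, `torch::kFloat32` → `at::kFloat`, `torch::kInt32` → `at::kInt`); they name the same ScalarType values, so behaviour is unchanged. A small hedged sketch of that mapping follows; the function and variable names are illustrative and not from this commit.

```
// Illustrative only: ATen spellings of the dtype constants used in the diff above.
#include <ATen/ATen.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor> MakeOutputBuffers(
    const at::Tensor& verts, int64_t H, int64_t W) {
  // at::kLong == torch::kInt64, at::kFloat == torch::kFloat32, at::kInt == torch::kInt32.
  auto long_opts = verts.options().dtype(at::kLong);
  auto float_opts = verts.options().dtype(at::kFloat);
  auto int_opts = verts.options().dtype(at::kInt);
  at::Tensor idxs = at::full({H, W}, -1, long_opts);   // index buffer, -1 marks "empty"
  at::Tensor zbuf = at::full({H, W}, -1, float_opts);  // depth buffer
  at::Tensor counts = at::zeros({H, W}, int_opts);     // per-pixel counters
  return std::make_tuple(idxs, zbuf, counts);
}
```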