Cuda updates

Summary:
Updates to:
- enable cuda kernel launches on any GPU (not just the default)
- cuda and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the cuda kernels
- cuda tests now run on a random device not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
This commit is contained in:
Nikhila Ravi
2020-04-24 09:07:54 -07:00
committed by Facebook GitHub Bot
parent c9267ab7af
commit c3d636dc8c
33 changed files with 979 additions and 240 deletions

View File

@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -136,6 +138,17 @@ at::Tensor alphaCompositeCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "alphaCompositeCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -143,19 +156,24 @@ at::Tensor alphaCompositeCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -164,9 +182,26 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "alphaCompositeCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = alphas.size(0);
const dim3 threadsPerBlock(64);
@@ -174,7 +209,7 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -183,6 +218,6 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}

View File

@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -151,6 +153,17 @@ at::Tensor weightedSumNormCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "weightedSumNormCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -158,19 +171,25 @@ at::Tensor weightedSumNormCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
// clang-format off
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -179,9 +198,26 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "weightedSumNormCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = points_idx.size(0);
const dim3 threadsPerBlock(64);
@@ -189,7 +225,7 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -198,6 +234,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}

View File

@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -110,6 +112,17 @@ at::Tensor weightedSumCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "weightedSumCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -117,19 +130,24 @@ at::Tensor weightedSumCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -138,9 +156,26 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "weightedSumCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = points_idx.size(0);
const dim3 threadsPerBlock(64);
@@ -148,7 +183,7 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -157,6 +192,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}

View File

@@ -23,7 +23,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def("gather_scatter", &gather_scatter);
m.def("gather_scatter", &GatherScatter);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <tuple>
template <typename scalar_t>
@@ -213,14 +215,30 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
const auto V = verts.size(0);
const auto F = faces.size(0);
// Check inputs are on the same device
at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2};
at::CheckedFrom c = "FaceAreasNormalsForwardCuda";
at::checkAllSameGPU(c, {verts_t, faces_t});
at::checkAllSameType(c, {verts_t, faces_t});
// Set the device for the kernel launch based on the device of verts
at::cuda::CUDAGuard device_guard(verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::Tensor areas = at::empty({F}, verts.options());
at::Tensor normals = at::empty({F, 3}, verts.options());
if (areas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(areas, normals);
}
const int blocks = 64;
const int threads = 512;
AT_DISPATCH_FLOATING_TYPES(
verts.scalar_type(), "face_areas_normals_forward_cuda", ([&] {
FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads>>>(
FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
verts.data_ptr<scalar_t>(),
faces.data_ptr<int64_t>(),
areas.data_ptr<scalar_t>(),
@@ -228,7 +246,7 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
V,
F);
}));
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(areas, normals);
}
@@ -237,16 +255,33 @@ at::Tensor FaceAreasNormalsBackwardCuda(
const at::Tensor grad_normals,
const at::Tensor verts,
const at::Tensor faces) {
// Check inputs are on the same device
at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2},
grad_areas_t{verts, "grad_areas", 3},
grad_normals_t{verts, "grad_normals", 4};
at::CheckedFrom c = "FaceAreasNormalsBackwardCuda";
at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
at::checkAllSameType(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
// Set the device for the kernel launch based on the device of verts
at::cuda::CUDAGuard device_guard(verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto V = verts.size(0);
const auto F = faces.size(0);
at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
if (grad_verts.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_verts;
}
const int blocks = 64;
const int threads = 512;
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
FaceAreasNormalsBackwardKernel<<<blocks, threads>>>(
FaceAreasNormalsBackwardKernel<<<blocks, threads, 0, stream>>>(
grad_areas.data_ptr<float>(),
grad_normals.data_ptr<float>(),
verts.data_ptr<float>(),
@@ -255,5 +290,6 @@ at::Tensor FaceAreasNormalsBackwardCuda(
V,
F);
AT_CUDA_CHECK(cudaGetLastError());
return grad_verts;
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Compute areas of mesh faces using packed representation.
//
@@ -46,6 +47,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
const at::Tensor faces) {
if (verts.is_cuda() && faces.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(verts);
CHECK_CONTIGUOUS_CUDA(faces);
return FaceAreasNormalsForwardCuda(verts, faces);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -62,6 +65,10 @@ at::Tensor FaceAreasNormalsBackward(
const at::Tensor faces) {
if (verts.is_cuda() && faces.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(verts);
CHECK_CONTIGUOUS_CUDA(faces);
CHECK_CONTIGUOUS_CUDA(grad_areas);
CHECK_CONTIGUOUS_CUDA(grad_normals);
return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces);
#else
AT_ERROR("Not compiled with GPU support.");

View File

@@ -1,9 +1,11 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// TODO(T47953967) to make this cuda kernel support all datatypes.
__global__ void gather_scatter_kernel(
__global__ void GatherScatterCudaKernel(
const float* __restrict__ input,
const int64_t* __restrict__ edges,
float* __restrict__ output,
@@ -41,11 +43,20 @@ __global__ void gather_scatter_kernel(
}
}
at::Tensor gather_scatter_cuda(
at::Tensor GatherScatterCuda(
const at::Tensor input,
const at::Tensor edges,
bool directed,
bool backward) {
// Check inputs are on the same device
at::TensorArg input_t{input, "input", 1}, edges_t{edges, "edges", 2};
at::CheckedFrom c = "GatherScatterCuda";
at::checkAllSameGPU(c, {input_t, edges_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto num_vertices = input.size(0);
const auto input_feature_dim = input.size(1);
const auto num_edges = edges.size(0);
@@ -55,7 +66,12 @@ at::Tensor gather_scatter_cuda(
const size_t max_blocks = 1920;
const size_t blocks = num_edges < max_blocks ? num_edges : max_blocks;
gather_scatter_kernel<<<blocks, threads>>>(
if (output.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return output;
}
GatherScatterCudaKernel<<<blocks, threads, 0, stream>>>(
input.data_ptr<float>(),
edges.data_ptr<int64_t>(),
output.data_ptr<float>(),
@@ -64,6 +80,6 @@ at::Tensor gather_scatter_cuda(
num_vertices,
input_feature_dim,
num_edges);
AT_CUDA_CHECK(cudaGetLastError());
return output;
}

View File

@@ -2,6 +2,7 @@
#pragma once
#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"
// Fused gather scatter operation for aggregating features of neighbor nodes
// in a graph. This gather scatter operation is specific to graphs as edge
@@ -20,21 +21,23 @@
// output: float32 Tensor of same shape as input.
// Cuda implementation.
at::Tensor gather_scatter_cuda(
at::Tensor GatherScatterCuda(
const at::Tensor input,
const at::Tensor edges,
bool directed,
bool backward);
// Exposed implementation.
at::Tensor gather_scatter(
at::Tensor GatherScatter(
const at::Tensor input,
const at::Tensor edges,
bool directed,
bool backward) {
if (input.is_cuda() && edges.is_cuda()) {
#ifdef WITH_CUDA
return gather_scatter_cuda(input, edges, directed, backward);
CHECK_CONTIGUOUS_CUDA(input);
CHECK_CONTIGUOUS_CUDA(edges);
return GatherScatterCuda(input, edges, directed, backward);
#else
AT_ERROR("Not compiled with GPU support.");
#endif

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <iostream>
#include <tuple>
@@ -114,7 +116,8 @@ struct KNearestNeighborV1Functor {
const size_t P1,
const size_t P2,
const size_t K) {
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
}
};
@@ -178,7 +181,8 @@ struct KNearestNeighborKernelV2Functor {
const int64_t N,
const int64_t P1,
const int64_t P2) {
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};
@@ -245,7 +249,8 @@ struct KNearestNeighborKernelV3Functor {
const size_t N,
const size_t P1,
const size_t P2) {
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};
@@ -296,17 +301,33 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& lengths2,
int K,
int version) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
at::CheckedFrom c = "KNearestNeighborIdxCuda";
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
at::checkAllSameType(c, {p1_t, p2_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const int64_t K_64 = K;
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = p1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());
if (idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
if (version < 0) {
version = ChooseVersion(D, K);
} else if (!KnnCheckVersion(version, D, K)) {
@@ -328,7 +349,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
if (version == 0) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t>
<<<blocks, threads>>>(
<<<blocks, threads, 0, stream>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
@@ -409,7 +430,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
P2);
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
@@ -465,27 +486,45 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
at::CheckedFrom c = "KNearestNeighborBackwardCuda";
at::checkAllSameGPU(
c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const auto K = idxs.size(2);
AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
AT_ASSERTM(
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
TORCH_CHECK(
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
AT_ASSERTM(grad_dists.size(0) == N);
AT_ASSERTM(grad_dists.size(1) == P1);
AT_ASSERTM(grad_dists.size(2) == K);
TORCH_CHECK(grad_dists.size(0) == N);
TORCH_CHECK(grad_dists.size(1) == P1);
TORCH_CHECK(grad_dists.size(2) == K);
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}
const int blocks = 64;
const int threads = 512;
KNearestNeighborBackwardKernel<<<blocks, threads>>>(
KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
p1.data_ptr<float>(),
p2.data_ptr<float>(),
lengths1.data_ptr<int64_t>(),
@@ -500,5 +539,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
K,
D);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// Kernel for inputs_packed of shape (F, D), where D > 1
template <typename scalar_t>
@@ -114,21 +116,36 @@ at::Tensor PackedToPaddedCuda(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size) {
// Check inputs are on the same device
at::TensorArg inputs_packed_t{inputs_packed, "inputs_packed", 1},
first_idxs_t{first_idxs, "first_idxs", 2};
at::CheckedFrom c = "PackedToPaddedCuda";
at::checkAllSameGPU(c, {inputs_packed_t, first_idxs_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(inputs_packed.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t num_inputs = inputs_packed.size(0);
const int64_t batch_size = first_idxs.size(0);
AT_ASSERTM(
TORCH_CHECK(
inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
const int64_t D = inputs_packed.size(1);
at::Tensor inputs_padded =
at::zeros({batch_size, max_size, D}, inputs_packed.options());
if (inputs_padded.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return inputs_padded;
}
const int threads = 512;
const int blocks = batch_size;
if (D == 1) {
AT_DISPATCH_FLOATING_TYPES(
inputs_packed.scalar_type(), "packed_to_padded_d1_kernel", ([&] {
PackedToPaddedKernelD1<scalar_t><<<blocks, threads>>>(
PackedToPaddedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_packed.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_padded.data_ptr<scalar_t>(),
@@ -139,7 +156,7 @@ at::Tensor PackedToPaddedCuda(
} else {
AT_DISPATCH_FLOATING_TYPES(
inputs_packed.scalar_type(), "packed_to_padded_kernel", ([&] {
PackedToPaddedKernel<scalar_t><<<blocks, threads>>>(
PackedToPaddedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_packed.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_padded.data_ptr<scalar_t>(),
@@ -150,6 +167,7 @@ at::Tensor PackedToPaddedCuda(
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return inputs_padded;
}
@@ -157,11 +175,21 @@ at::Tensor PaddedToPackedCuda(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs) {
// Check inputs are on the same device
at::TensorArg inputs_padded_t{inputs_padded, "inputs_padded", 1},
first_idxs_t{first_idxs, "first_idxs", 2};
at::CheckedFrom c = "PaddedToPackedCuda";
at::checkAllSameGPU(c, {inputs_padded_t, first_idxs_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(inputs_padded.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = inputs_padded.size(0);
const int64_t max_size = inputs_padded.size(1);
AT_ASSERTM(batch_size == first_idxs.size(0), "sizes mismatch");
AT_ASSERTM(
TORCH_CHECK(batch_size == first_idxs.size(0), "sizes mismatch");
TORCH_CHECK(
inputs_padded.dim() == 3,
"inputs_padded must be a 3-dimensional tensor");
const int64_t D = inputs_padded.size(2);
@@ -169,13 +197,18 @@ at::Tensor PaddedToPackedCuda(
at::Tensor inputs_packed =
at::zeros({num_inputs, D}, inputs_padded.options());
if (inputs_packed.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return inputs_packed;
}
const int threads = 512;
const int blocks = batch_size;
if (D == 1) {
AT_DISPATCH_FLOATING_TYPES(
inputs_padded.scalar_type(), "padded_to_packed_d1_kernel", ([&] {
PaddedToPackedKernelD1<scalar_t><<<blocks, threads>>>(
PaddedToPackedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_padded.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_packed.data_ptr<scalar_t>(),
@@ -186,7 +219,7 @@ at::Tensor PaddedToPackedCuda(
} else {
AT_DISPATCH_FLOATING_TYPES(
inputs_padded.scalar_type(), "padded_to_packed_kernel", ([&] {
PaddedToPackedKernel<scalar_t><<<blocks, threads>>>(
PaddedToPackedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_padded.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_packed.data_ptr<scalar_t>(),
@@ -197,5 +230,6 @@ at::Tensor PaddedToPackedCuda(
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return inputs_packed;
}

View File

@@ -2,6 +2,7 @@
#pragma once
#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"
// PackedToPadded
// Converts a packed tensor into a padded tensor, restoring the batch dimension.
@@ -74,6 +75,8 @@ at::Tensor PackedToPadded(
const int64_t max_size) {
if (inputs_packed.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(inputs_packed);
CHECK_CONTIGUOUS_CUDA(first_idxs);
return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -89,6 +92,8 @@ at::Tensor PaddedToPacked(
const int64_t num_inputs) {
if (inputs_padded.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(inputs_padded);
CHECK_CONTIGUOUS_CUDA(first_idxs);
return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
#else
AT_ERROR("Not compiled with GPU support.");

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <list>
#include <queue>
@@ -103,26 +105,45 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
const at::Tensor& segms,
const at::Tensor& segms_first_idx,
const int64_t max_points) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
segms_t{segms, "segms", 3},
segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
at::CheckedFrom c = "PointEdgeDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
at::checkAllSameType(c, {points_t, segms_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(segms_first_idx.size(0) == B);
TORCH_CHECK(segms_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({P,}, points.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_points, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
PointEdgeForwardKernel<<<blocks, threads, shared_size>>>(
PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
segms.data_ptr<float>(),
@@ -132,7 +153,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
B,
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
@@ -183,25 +204,42 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
const at::Tensor& segms,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(idx_points.size(0) == P);
AT_ASSERTM(grad_dists.size(0) == P);
TORCH_CHECK(idx_points.size(0) == P);
TORCH_CHECK(grad_dists.size(0) == P);
// clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
// clang-format on
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}
const int blocks = 64;
const int threads = 512;
PointEdgeBackwardKernel<<<blocks, threads>>>(
PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
idx_points.data_ptr<int64_t>(),
@@ -210,6 +248,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
grad_segms.data_ptr<float>(),
P);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}
@@ -308,26 +347,45 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
const at::Tensor& segms,
const at::Tensor& segms_first_idx,
const int64_t max_segms) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
segms_t{segms, "segms", 3},
segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
at::CheckedFrom c = "EdgePointDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
at::checkAllSameType(c, {points_t, segms_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(segms_first_idx.size(0) == B);
TORCH_CHECK(segms_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({S,}, segms.options());
at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_segms, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
EdgePointForwardKernel<<<blocks, threads, shared_size>>>(
EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
segms.data_ptr<float>(),
@@ -337,7 +395,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
B,
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
@@ -389,15 +447,27 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
const at::Tensor& segms,
const at::Tensor& idx_segms,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM(idx_segms.size(0) == S);
AT_ASSERTM(grad_dists.size(0) == S);
TORCH_CHECK(idx_segms.size(0) == S);
TORCH_CHECK(grad_dists.size(0) == S);
// clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options());
@@ -407,7 +477,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
const int blocks = 64;
const int threads = 512;
EdgePointBackwardKernel<<<blocks, threads>>>(
EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
idx_segms.data_ptr<int64_t>(),
@@ -451,26 +521,42 @@ __global__ void PointEdgeArrayForwardKernel(
at::Tensor PointEdgeArrayDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& segms) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
at::checkAllSameGPU(c, {points_t, segms_t});
at::checkAllSameType(c, {points_t, segms_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
at::Tensor dists = at::zeros({P, S}, points.options());
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
const size_t blocks = 1024;
const size_t threads = 64;
PointEdgeArrayForwardKernel<<<blocks, threads>>>(
PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
dists.data_ptr<float>(),
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
@@ -520,22 +606,38 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
grad_dists_t{grad_dists, "grad_dists", 3};
at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}
const size_t blocks = 1024;
const size_t threads = 64;
PointEdgeArrayBackwardKernel<<<blocks, threads>>>(
PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
grad_dists.data_ptr<float>(),
@@ -543,6 +645,6 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
grad_segms.data_ptr<float>(),
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}

View File

@@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// ****************************************************************************
// * PointEdgeDistance *
@@ -53,6 +54,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
const int64_t max_points) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(segms_first_idx);
return PointEdgeDistanceForwardCuda(
points, points_first_idx, segms, segms_first_idx, max_points);
#else
@@ -93,6 +98,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(idx_points);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -149,6 +158,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
const int64_t max_segms) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(segms_first_idx);
return EdgePointDistanceForwardCuda(
points, points_first_idx, segms, segms_first_idx, max_segms);
#else
@@ -189,6 +202,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(idx_segms);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -220,7 +237,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
// will require for the forward pass 5.8G of memory to store dists.
#ifdef WITH_CUDA
torch::Tensor PointEdgeArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& segms);
@@ -231,6 +247,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
const torch::Tensor& segms) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
return PointEdgeArrayDistanceForwardCuda(points, segms);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -265,6 +283,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <list>
#include <queue>
@@ -104,26 +106,45 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_points) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
tris_t{tris, "tris", 3},
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
at::CheckedFrom c = "PointFaceDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
AT_ASSERTM(tris_first_idx.size(0) == B);
TORCH_CHECK(tris_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({P,}, points.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_points, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
PointFaceForwardKernel<<<blocks, threads, shared_size>>>(
PointFaceForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
tris.data_ptr<float>(),
@@ -134,6 +155,7 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
P,
T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
@@ -191,25 +213,42 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
const at::Tensor& tris,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_points_t{idx_points, "idx_points", 2}, tris_t{tris, "tris", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "PointFaceDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_points_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
AT_ASSERTM(idx_points.size(0) == P);
AT_ASSERTM(grad_dists.size(0) == P);
TORCH_CHECK(idx_points.size(0) == P);
TORCH_CHECK(grad_dists.size(0) == P);
// clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
// clang-format on
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
const int blocks = 64;
const int threads = 512;
PointFaceBackwardKernel<<<blocks, threads>>>(
PointFaceBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
tris.data_ptr<float>(),
idx_points.data_ptr<int64_t>(),
@@ -218,6 +257,7 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
grad_tris.data_ptr<float>(),
P);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
@@ -317,26 +357,45 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_tris) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
tris_t{tris, "tris", 3},
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
at::CheckedFrom c = "FacePointDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
AT_ASSERTM(tris_first_idx.size(0) == B);
TORCH_CHECK(tris_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({T,}, tris.options());
at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_tris, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
FacePointForwardKernel<<<blocks, threads, shared_size>>>(
FacePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
tris.data_ptr<float>(),
@@ -347,6 +406,7 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
P,
T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
@@ -405,25 +465,42 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
const at::Tensor& tris,
const at::Tensor& idx_tris,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_tris_t{idx_tris, "idx_tris", 2}, tris_t{tris, "tris", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "FacePointDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_tris_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
AT_ASSERTM(idx_tris.size(0) == T);
AT_ASSERTM(grad_dists.size(0) == T);
TORCH_CHECK(idx_tris.size(0) == T);
TORCH_CHECK(grad_dists.size(0) == T);
// clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
// clang-format on
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
const int blocks = 64;
const int threads = 512;
FacePointBackwardKernel<<<blocks, threads>>>(
FacePointBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
tris.data_ptr<float>(),
idx_tris.data_ptr<int64_t>(),
@@ -432,6 +509,7 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
grad_tris.data_ptr<float>(),
T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
@@ -468,26 +546,42 @@ __global__ void PointFaceArrayForwardKernel(
at::Tensor PointFaceArrayDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& tris) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2};
at::CheckedFrom c = "PointFaceArrayDistanceForwardCuda";
at::checkAllSameGPU(c, {points_t, tris_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
at::Tensor dists = at::zeros({P, T}, points.options());
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
const size_t blocks = 1024;
const size_t threads = 64;
PointFaceArrayForwardKernel<<<blocks, threads>>>(
PointFaceArrayForwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
tris.data_ptr<float>(),
dists.data_ptr<float>(),
P,
T);
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
@@ -546,22 +640,38 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2},
grad_dists_t{grad_dists, "grad_dists", 3};
at::CheckedFrom c = "PointFaceArrayDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM(
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
const size_t blocks = 1024;
const size_t threads = 64;
PointFaceArrayBackwardKernel<<<blocks, threads>>>(
PointFaceArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
tris.data_ptr<float>(),
grad_dists.data_ptr<float>(),
@@ -570,5 +680,6 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
P,
T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}

View File

@@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// ****************************************************************************
// * PointFaceDistance *
@@ -55,6 +56,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
const int64_t max_points) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(tris_first_idx);
return PointFaceDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_points);
#else
@@ -95,6 +100,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(idx_points);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -151,6 +160,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
const int64_t max_tris) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(tris_first_idx);
return FacePointDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_tris);
#else
@@ -191,6 +204,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(idx_tris);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -233,6 +250,8 @@ torch::Tensor PointFaceArrayDistanceForward(
const torch::Tensor& tris) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
return PointFaceArrayDistanceForwardCuda(points, tris);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -254,7 +273,6 @@ torch::Tensor PointFaceArrayDistanceForward(
//
#ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
@@ -267,6 +285,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <math.h>
#include <thrust/tuple.h>
@@ -285,14 +287,14 @@ RasterizeMeshesNaiveCuda(
const int num_closest,
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) {
AT_ERROR(
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
}
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");
TORCH_CHECK(
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
if (num_closest > kMaxPointsPerPixel) {
std::stringstream ss;
@@ -300,6 +302,20 @@ RasterizeMeshesNaiveCuda(
AT_ERROR(ss.str());
}
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_faces_packed_first_idx_t{
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
at::checkAllSameGPU(
c,
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = num_faces_per_mesh.size(0); // batch size.
const int H = image_size; // Assume square images.
const int W = image_size;
@@ -313,10 +329,15 @@ RasterizeMeshesNaiveCuda(
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
if (face_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
const size_t blocks = 1024;
const size_t threads = 64;
RasterizeMeshesNaiveCudaKernel<<<blocks, threads>>>(
RasterizeMeshesNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
@@ -332,6 +353,7 @@ RasterizeMeshesNaiveCuda(
pix_dists.contiguous().data_ptr<float>(),
bary.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
@@ -465,6 +487,22 @@ at::Tensor RasterizeMeshesBackwardCuda(
const at::Tensor& grad_bary, // (N, H, W, K, 3)
const at::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
pix_to_face_t{pix_to_face, "pix_to_face", 2},
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
grad_bary_t{grad_bary, "grad_bary", 4},
grad_dists_t{grad_dists, "grad_dists", 5};
at::CheckedFrom c = "RasterizeMeshesBackwardCuda";
at::checkAllSameGPU(
c, {face_verts_t, pix_to_face_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
at::checkAllSameType(
c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@@ -472,10 +510,16 @@ at::Tensor RasterizeMeshesBackwardCuda(
const int K = pix_to_face.size(3);
at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());
if (grad_face_verts.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts;
}
const size_t blocks = 1024;
const size_t threads = 64;
RasterizeMeshesBackwardCudaKernel<<<blocks, threads>>>(
RasterizeMeshesBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
pix_to_face.contiguous().data_ptr<int64_t>(),
perspective_correct,
@@ -488,6 +532,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
grad_dists.contiguous().data_ptr<float>(),
grad_face_verts.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts;
}
@@ -626,10 +671,24 @@ at::Tensor RasterizeMeshesCoarseCuda(
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_face_first_idx_t{
mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
at::checkAllSameGPU(
c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int W = image_size;
const int H = image_size;
const int F = face_verts.size(0);
@@ -645,12 +704,18 @@ at::Tensor RasterizeMeshesCoarseCuda(
auto opts = face_verts.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}
const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size>>>(
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
face_verts.contiguous().data_ptr<float>(),
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
@@ -664,6 +729,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
M,
faces_per_bin.contiguous().data_ptr<int32_t>(),
bin_faces.contiguous().data_ptr<int32_t>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}
@@ -775,13 +842,22 @@ RasterizeMeshesFineCuda(
const int faces_per_pixel,
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (bin_faces.ndimension() != 4) {
AT_ERROR("bin_faces must have 4 dimensions");
}
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
bin_faces_t{bin_faces, "bin_faces", 2};
at::CheckedFrom c = "RasterizeMeshesFineCuda";
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = bin_faces.size(0);
const int B = bin_faces.size(1);
const int M = bin_faces.size(3);
@@ -790,7 +866,7 @@ RasterizeMeshesFineCuda(
const int W = image_size;
if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8");
AT_ERROR("Must have num_closest <= 150");
}
auto long_opts = face_verts.options().dtype(at::kLong);
auto float_opts = face_verts.options().dtype(at::kFloat);
@@ -800,10 +876,15 @@ RasterizeMeshesFineCuda(
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
if (face_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
const size_t blocks = 1024;
const size_t threads = 64;
RasterizeMeshesFineCudaKernel<<<blocks, threads>>>(
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
bin_faces.contiguous().data_ptr<int32_t>(),
blur_radius,

View File

@@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// ****************************************************************************
// * FORWARD PASS *
@@ -95,6 +96,9 @@ RasterizeMeshesNaive(
// TODO: Better type checking.
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
return RasterizeMeshesNaiveCuda(
face_verts,
mesh_to_face_first_idx,
@@ -175,6 +179,11 @@ torch::Tensor RasterizeMeshesBackward(
const bool perspective_correct) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(pix_to_face);
CHECK_CONTIGUOUS_CUDA(grad_zbuf);
CHECK_CONTIGUOUS_CUDA(grad_bary);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return RasterizeMeshesBackwardCuda(
face_verts,
pix_to_face,
@@ -251,6 +260,9 @@ torch::Tensor RasterizeMeshesCoarse(
const int max_faces_per_bin) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
return RasterizeMeshesCoarseCuda(
face_verts,
mesh_to_face_first_idx,
@@ -347,6 +359,8 @@ RasterizeMeshesFine(
const bool cull_backfaces) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(bin_faces);
return RasterizeMeshesFineCuda(
face_verts,
bin_faces,

View File

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <cstdio>
#include <sstream>
@@ -145,13 +147,25 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
const int image_size,
const float radius,
const int points_per_pixel) {
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) {
AT_ERROR(
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
}
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
cloud_to_packed_first_idx_t{
cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
at::CheckedFrom c = "RasterizePointsNaiveCuda";
at::checkAllSameGPU(
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(
points.ndimension() == 2 && points.size(1) == 3,
"points must have dimensions (num_points, 3)");
TORCH_CHECK(
num_points_per_cloud.size(0) == cloud_to_packed_first_idx.size(0),
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
const int N = num_points_per_cloud.size(0); // batch size.
const int S = image_size;
@@ -169,9 +183,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}
const size_t blocks = 1024;
const size_t threads = 64;
RasterizePointsNaiveCudaKernel<<<blocks, threads>>>(
RasterizePointsNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@@ -182,6 +201,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
point_idxs.contiguous().data_ptr<int32_t>(),
zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}
@@ -323,14 +344,28 @@ at::Tensor RasterizePointsCoarseCuda(
const float radius,
const int bin_size,
const int max_points_per_bin) {
TORCH_CHECK(
points.ndimension() == 2 && points.size(1) == 3,
"points must have dimensions (num_points, 3)");
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
cloud_to_packed_first_idx_t{
cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
at::CheckedFrom c = "RasterizePointsCoarseCuda";
at::checkAllSameGPU(
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int P = points.size(0);
const int N = num_points_per_cloud.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
const int M = max_points_per_bin;
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_bins >= 22) {
// Make sure we do not use too much shared memory.
std::stringstream ss;
@@ -340,12 +375,18 @@ at::Tensor RasterizePointsCoarseCuda(
auto opts = points.options().dtype(at::kInt);
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
if (bin_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_points;
}
const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@@ -358,6 +399,8 @@ at::Tensor RasterizePointsCoarseCuda(
M,
points_per_bin.contiguous().data_ptr<int32_t>(),
bin_points.contiguous().data_ptr<int32_t>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_points;
}
@@ -448,13 +491,23 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
const float radius,
const int bin_size,
const int points_per_pixel) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
bin_points_t{bin_points, "bin_points", 2};
at::CheckedFrom c = "RasterizePointsFineCuda";
at::checkAllSameGPU(c, {points_t, bin_points_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = bin_points.size(0);
const int B = bin_points.size(1); // num_bins
const int M = bin_points.size(3);
const int S = image_size;
const int K = points_per_pixel;
if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8");
AT_ERROR("Must have num_closest <= 150");
}
auto int_opts = points.options().dtype(at::kInt);
auto float_opts = points.options().dtype(at::kFloat);
@@ -462,9 +515,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}
const size_t blocks = 1024;
const size_t threads = 64;
RasterizePointsFineCudaKernel<<<blocks, threads>>>(
RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
bin_points.contiguous().data_ptr<int32_t>(),
radius,
@@ -478,6 +536,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}
@@ -537,6 +596,19 @@ at::Tensor RasterizePointsBackwardCuda(
const at::Tensor& idxs, // (N, H, W, K)
const at::Tensor& grad_zbuf, // (N, H, W, K)
const at::Tensor& grad_dists) { // (N, H, W, K)
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, idxs_t{idxs, "idxs", 2},
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "RasterizePointsBackwardCuda";
at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t});
at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int P = points.size(0);
const int N = idxs.size(0);
const int H = idxs.size(1);
@@ -544,10 +616,16 @@ at::Tensor RasterizePointsBackwardCuda(
const int K = idxs.size(3);
at::Tensor grad_points = at::zeros({P, 3}, points.options());
if (grad_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_points;
}
const size_t blocks = 1024;
const size_t threads = 64;
RasterizePointsBackwardCudaKernel<<<blocks, threads>>>(
RasterizePointsBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
idxs.contiguous().data_ptr<int32_t>(),
N,
@@ -559,5 +637,6 @@ at::Tensor RasterizePointsBackwardCuda(
grad_dists.contiguous().data_ptr<float>(),
grad_points.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return grad_points;
}

View File

@@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// ****************************************************************************
// * NAIVE RASTERIZATION *
@@ -66,6 +67,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
return RasterizePointsNaiveCuda(
points,
cloud_to_packed_first_idx,
@@ -140,6 +144,9 @@ torch::Tensor RasterizePointsCoarse(
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
return RasterizePointsCoarseCuda(
points,
cloud_to_packed_first_idx,
@@ -208,6 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
const int points_per_pixel) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(bin_points);
return RasterizePointsFineCuda(
points, bin_points, image_size, radius, bin_size, points_per_pixel);
#else
@@ -257,6 +266,10 @@ torch::Tensor RasterizePointsBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(idxs);
CHECK_CONTIGUOUS_CUDA(grad_zbuf);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists);
#else
AT_ERROR("Not compiled with GPU support");

View File

@@ -3,9 +3,9 @@
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x "must be a CUDA tensor.")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x "must be contiguous.")
TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)