CUDA updates

Summary:
Updates to:
- enable CUDA kernel launches on any GPU (not just the default device); the common per-op pattern is sketched below
- add CUDA and contiguity checks for the inputs to all kernels
- add checks that all input tensors are on the same device
- improve error reporting in the CUDA kernels
- run the CUDA tests on a randomly chosen device, not just the default
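The per-op pattern these updates apply is sketched below for reference (illustrative only: MyOpCuda, MyOpKernel, and the tensor names are placeholders, not code from this commit):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

__global__ void MyOpKernel(const float* in, const float* w, float* out, int64_t n) {
  // Grid-stride loop so any launch configuration covers all n elements.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    out[i] = in[i] * w[i];
  }
}

at::Tensor MyOpCuda(const at::Tensor& input, const at::Tensor& weights) {
  // Check that all inputs live on the same GPU and have matching dtypes.
  at::TensorArg input_t{input, "input", 1}, weights_t{weights, "weights", 2};
  at::CheckedFrom c = "MyOpCuda";
  at::checkAllSameGPU(c, {input_t, weights_t});
  at::checkAllSameType(c, {input_t, weights_t});

  // Launch on the device that holds the input, not on the current default device.
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto result = at::zeros_like(input);
  if (result.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return result;
  }

  const int threads = 64;
  const int blocks = 64;
  MyOpKernel<<<blocks, threads, 0, stream>>>(
      input.data_ptr<float>(),
      weights.data_ptr<float>(),
      result.data_ptr<float>(),
      input.numel());
  // Surface launch errors as readable exceptions instead of failing silently later.
  AT_CUDA_CHECK(cudaGetLastError());
  return result;
}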

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
Nikhila Ravi 2020-04-24 09:07:54 -07:00 committed by Facebook GitHub Bot
parent c9267ab7af
commit c3d636dc8c
33 changed files with 979 additions and 240 deletions


@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -136,6 +138,17 @@ at::Tensor alphaCompositeCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "alphaCompositeCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -143,19 +156,24 @@ at::Tensor alphaCompositeCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -164,9 +182,26 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "alphaCompositeCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = alphas.size(0);
const dim3 threadsPerBlock(64);
@@ -174,7 +209,7 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -183,6 +218,6 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}


@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -151,6 +153,17 @@ at::Tensor weightedSumNormCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "weightedSumNormCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -158,19 +171,25 @@ at::Tensor weightedSumNormCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
// clang-format off
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -179,9 +198,26 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "weightedSumNormCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = points_idx.size(0);
const dim3 threadsPerBlock(64);
@@ -189,7 +225,7 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -198,6 +234,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}


@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -110,6 +112,17 @@ at::Tensor weightedSumCudaForward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg features_t{features, "features", 1},
alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
at::CheckedFrom c = "weightedSumCudaForward";
at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
@@ -117,19 +130,24 @@ at::Tensor weightedSumCudaForward(
auto result = at::zeros({batch_size, C, H, W}, features.options());
if (result.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return result;
}
@@ -138,9 +156,26 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
const at::Tensor& features,
const at::Tensor& alphas,
const at::Tensor& points_idx) {
// Check inputs are on the same device
at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
points_idx_t{points_idx, "points_idx", 4};
at::CheckedFrom c = "weightedSumCudaBackward";
at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(features.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_features = at::zeros_like(features);
auto grad_alphas = at::zeros_like(alphas);
if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}
const int64_t bs = points_idx.size(0);
const dim3 threadsPerBlock(64);
@@ -148,7 +183,7 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
// clang-format off
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -157,6 +192,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
// clang-format on
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_features, grad_alphas);
}


@@ -23,7 +23,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def("gather_scatter", &GatherScatter);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);


@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <tuple>
template <typename scalar_t>
@@ -213,14 +215,30 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
const auto V = verts.size(0);
const auto F = faces.size(0);
// Check inputs are on the same device
at::TensorArg verts_t{verts, "verts", 1}, faces_t{faces, "faces", 2};
at::CheckedFrom c = "FaceAreasNormalsForwardCuda";
at::checkAllSameGPU(c, {verts_t, faces_t});
// Set the device for the kernel launch based on the device of verts
at::cuda::CUDAGuard device_guard(verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::Tensor areas = at::empty({F}, verts.options());
at::Tensor normals = at::empty({F, 3}, verts.options());
if (areas.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(areas, normals);
}
const int blocks = 64;
const int threads = 512;
AT_DISPATCH_FLOATING_TYPES(
verts.scalar_type(), "face_areas_normals_forward_cuda", ([&] {
FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
verts.data_ptr<scalar_t>(),
faces.data_ptr<int64_t>(),
areas.data_ptr<scalar_t>(),
@@ -228,7 +246,7 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
V,
F);
}));
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(areas, normals);
}
@@ -237,16 +255,33 @@ at::Tensor FaceAreasNormalsBackwardCuda(
const at::Tensor grad_normals,
const at::Tensor verts,
const at::Tensor faces) {
// Check inputs are on the same device
at::TensorArg verts_t{verts, "verts", 1}, faces_t{faces, "faces", 2},
grad_areas_t{grad_areas, "grad_areas", 3},
grad_normals_t{grad_normals, "grad_normals", 4};
at::CheckedFrom c = "FaceAreasNormalsBackwardCuda";
at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
at::checkAllSameType(c, {verts_t, grad_areas_t, grad_normals_t});
// Set the device for the kernel launch based on the device of verts
at::cuda::CUDAGuard device_guard(verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto V = verts.size(0);
const auto F = faces.size(0);
at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
if (grad_verts.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_verts;
}
const int blocks = 64;
const int threads = 512;
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
FaceAreasNormalsBackwardKernel<<<blocks, threads, 0, stream>>>(
grad_areas.data_ptr<float>(),
grad_normals.data_ptr<float>(),
verts.data_ptr<float>(),
@@ -255,5 +290,6 @@ at::Tensor FaceAreasNormalsBackwardCuda(
V,
F);
AT_CUDA_CHECK(cudaGetLastError());
return grad_verts;
}


@@ -3,6 +3,7 @@
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Compute areas of mesh faces using packed representation.
//
@@ -46,6 +47,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
const at::Tensor faces) {
if (verts.is_cuda() && faces.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(verts);
CHECK_CONTIGUOUS_CUDA(faces);
return FaceAreasNormalsForwardCuda(verts, faces);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -62,6 +65,10 @@ at::Tensor FaceAreasNormalsBackward(
const at::Tensor faces) {
if (verts.is_cuda() && faces.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(verts);
CHECK_CONTIGUOUS_CUDA(faces);
CHECK_CONTIGUOUS_CUDA(grad_areas);
CHECK_CONTIGUOUS_CUDA(grad_normals);
return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces);
#else
AT_ERROR("Not compiled with GPU support.");
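CHECK_CONTIGUOUS_CUDA comes from utils/pytorch3d_cutils.h, which is not shown in this diff. A plausible minimal definition of such a helper looks like the following (the macro bodies and messages are assumptions, not the actual header):

#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
  CHECK_CUDA(x);                 \
  CHECK_CONTIGUOUS(x)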


@@ -1,9 +1,11 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// TODO(T47953967) to make this cuda kernel support all datatypes.
__global__ void GatherScatterCudaKernel(
const float* __restrict__ input,
const int64_t* __restrict__ edges,
float* __restrict__ output,
@@ -41,11 +43,20 @@ __global__ void gather_scatter_kernel(
}
}
at::Tensor GatherScatterCuda(
const at::Tensor input,
const at::Tensor edges,
bool directed,
bool backward) {
// Check inputs are on the same device
at::TensorArg input_t{input, "input", 1}, edges_t{edges, "edges", 2};
at::CheckedFrom c = "GatherScatterCuda";
at::checkAllSameGPU(c, {input_t, edges_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto num_vertices = input.size(0);
const auto input_feature_dim = input.size(1);
const auto num_edges = edges.size(0);
@@ -55,7 +66,12 @@ at::Tensor gather_scatter_cuda(
const size_t max_blocks = 1920;
const size_t blocks = num_edges < max_blocks ? num_edges : max_blocks;
if (output.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return output;
}
GatherScatterCudaKernel<<<blocks, threads, 0, stream>>>(
input.data_ptr<float>(),
edges.data_ptr<int64_t>(),
output.data_ptr<float>(),
@@ -64,6 +80,6 @@ at::Tensor gather_scatter_cuda(
num_vertices,
input_feature_dim,
num_edges);
AT_CUDA_CHECK(cudaGetLastError());
return output;
}


@@ -2,6 +2,7 @@
#pragma once
#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"
// Fused gather scatter operation for aggregating features of neighbor nodes
// in a graph. This gather scatter operation is specific to graphs as edge
@@ -20,21 +21,23 @@
// output: float32 Tensor of same shape as input.
// Cuda implementation.
at::Tensor GatherScatterCuda(
const at::Tensor input,
const at::Tensor edges,
bool directed,
bool backward);
// Exposed implementation.
at::Tensor GatherScatter(
const at::Tensor input,
const at::Tensor edges,
bool directed,
bool backward) {
if (input.is_cuda() && edges.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(input);
CHECK_CONTIGUOUS_CUDA(edges);
return GatherScatterCuda(input, edges, directed, backward);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
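For context on the op itself: based on the comments above, the fused gather-scatter accumulates the features of each edge's neighbor into the output row of the other endpoint. A serial reference of that aggregation might look as follows (an illustrative assumption, not code from this commit; the exact directed/backward semantics may differ):

// Hypothetical serial reference of the neighbor aggregation.
void GatherScatterReference(
    const float* input,   // (V, D) node features
    const int64_t* edges, // (E, 2) vertex index pairs (v0, v1)
    float* output,        // (V, D), assumed zero-initialized
    bool directed,
    int64_t D,
    int64_t E) {
  for (int64_t e = 0; e < E; ++e) {
    const int64_t v0 = edges[2 * e + 0];
    const int64_t v1 = edges[2 * e + 1];
    for (int64_t d = 0; d < D; ++d) {
      output[v0 * D + d] += input[v1 * D + d];
      if (!directed) {
        output[v1 * D + d] += input[v0 * D + d];
      }
    }
  }
}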


@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <iostream>
#include <tuple>
@@ -114,7 +116,8 @@ struct KNearestNeighborV1Functor {
const size_t P1,
const size_t P2,
const size_t K) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
}
};
@@ -178,7 +181,8 @@ struct KNearestNeighborKernelV2Functor {
const int64_t N,
const int64_t P1,
const int64_t P2) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};
@@ -245,7 +249,8 @@ struct KNearestNeighborKernelV3Functor {
const size_t N,
const size_t P1,
const size_t P2) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
}
};
@@ -296,17 +301,33 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& lengths2,
int K,
int version) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
at::CheckedFrom c = "KNearestNeighborIdxCuda";
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
at::checkAllSameType(c, {p1_t, p2_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const int64_t K_64 = K;
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = p1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options());
if (idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
if (version < 0) {
version = ChooseVersion(D, K);
} else if (!KnnCheckVersion(version, D, K)) {
@@ -328,7 +349,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
if (version == 0) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t>
<<<blocks, threads, 0, stream>>>(
p1.data_ptr<scalar_t>(),
p2.data_ptr<scalar_t>(),
lengths1.data_ptr<int64_t>(),
@@ -409,7 +430,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
P2);
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(idxs, dists);
}
@@ -465,27 +486,45 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& lengths2,
const at::Tensor& idxs,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
at::CheckedFrom c = "KNearestNeighborBackwardCuda";
at::checkAllSameGPU(
c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto N = p1.size(0);
const auto P1 = p1.size(1);
const auto P2 = p2.size(1);
const auto D = p2.size(2);
const auto K = idxs.size(2);
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
TORCH_CHECK(
idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
TORCH_CHECK(grad_dists.size(0) == N);
TORCH_CHECK(grad_dists.size(1) == P1);
TORCH_CHECK(grad_dists.size(2) == K);
auto grad_p1 = at::zeros({N, P1, D}, p1.options());
auto grad_p2 = at::zeros({N, P2, D}, p2.options());
if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}
const int blocks = 64;
const int threads = 512;
KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
p1.data_ptr<float>(),
p2.data_ptr<float>(),
lengths1.data_ptr<int64_t>(),
@@ -500,5 +539,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
K,
D);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
}


@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// Kernel for inputs_packed of shape (F, D), where D > 1
template <typename scalar_t>
@@ -114,21 +116,36 @@ at::Tensor PackedToPaddedCuda(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size) {
// Check inputs are on the same device
at::TensorArg inputs_packed_t{inputs_packed, "inputs_packed", 1},
first_idxs_t{first_idxs, "first_idxs", 2};
at::CheckedFrom c = "PackedToPaddedCuda";
at::checkAllSameGPU(c, {inputs_packed_t, first_idxs_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(inputs_packed.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t num_inputs = inputs_packed.size(0);
const int64_t batch_size = first_idxs.size(0);
TORCH_CHECK(
inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
const int64_t D = inputs_packed.size(1);
at::Tensor inputs_padded =
at::zeros({batch_size, max_size, D}, inputs_packed.options());
if (inputs_padded.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return inputs_padded;
}
const int threads = 512;
const int blocks = batch_size;
if (D == 1) {
AT_DISPATCH_FLOATING_TYPES(
inputs_packed.scalar_type(), "packed_to_padded_d1_kernel", ([&] {
PackedToPaddedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_packed.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_padded.data_ptr<scalar_t>(),
@@ -139,7 +156,7 @@ at::Tensor PackedToPaddedCuda(
} else {
AT_DISPATCH_FLOATING_TYPES(
inputs_packed.scalar_type(), "packed_to_padded_kernel", ([&] {
PackedToPaddedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_packed.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_padded.data_ptr<scalar_t>(),
@@ -150,6 +167,7 @@ at::Tensor PackedToPaddedCuda(
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return inputs_padded;
}
@@ -157,11 +175,21 @@ at::Tensor PaddedToPackedCuda(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs) {
// Check inputs are on the same device
at::TensorArg inputs_padded_t{inputs_padded, "inputs_padded", 1},
first_idxs_t{first_idxs, "first_idxs", 2};
at::CheckedFrom c = "PaddedToPackedCuda";
at::checkAllSameGPU(c, {inputs_padded_t, first_idxs_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(inputs_padded.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = inputs_padded.size(0);
const int64_t max_size = inputs_padded.size(1);
TORCH_CHECK(batch_size == first_idxs.size(0), "sizes mismatch");
TORCH_CHECK(
inputs_padded.dim() == 3,
"inputs_padded must be a 3-dimensional tensor");
const int64_t D = inputs_padded.size(2);
@@ -169,13 +197,18 @@ at::Tensor PaddedToPackedCuda(
at::Tensor inputs_packed =
at::zeros({num_inputs, D}, inputs_padded.options());
if (inputs_packed.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return inputs_packed;
}
const int threads = 512;
const int blocks = batch_size;
if (D == 1) {
AT_DISPATCH_FLOATING_TYPES(
inputs_padded.scalar_type(), "padded_to_packed_d1_kernel", ([&] {
PaddedToPackedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_padded.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_packed.data_ptr<scalar_t>(),
@@ -186,7 +219,7 @@ at::Tensor PaddedToPackedCuda(
} else {
AT_DISPATCH_FLOATING_TYPES(
inputs_padded.scalar_type(), "padded_to_packed_kernel", ([&] {
PaddedToPackedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
inputs_padded.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_packed.data_ptr<scalar_t>(),
@@ -197,5 +230,6 @@ at::Tensor PaddedToPackedCuda(
}));
}
AT_CUDA_CHECK(cudaGetLastError());
return inputs_packed;
}


@@ -2,6 +2,7 @@
#pragma once
#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"
// PackedToPadded
// Converts a packed tensor into a padded tensor, restoring the batch dimension.
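A small worked example of the packed and padded layouts these functions convert between (illustrative values, not part of this diff):

// Two meshes with 2 and 3 faces and feature dimension D:
//   inputs_packed: shape (5, D), rows [f0, f1, g0, g1, g2]
//   first_idxs:    [0, 2]  (start row of each mesh within the packed tensor)
//   max_size:      3
// PackedToPadded(inputs_packed, first_idxs, 3) -> shape (2, 3, D):
//   batch 0: [f0, f1, 0]   (zero padded)
//   batch 1: [g0, g1, g2]
// PaddedToPacked inverts this, dropping the padding rows.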
@@ -74,6 +75,8 @@ at::Tensor PackedToPadded(
const int64_t max_size) {
if (inputs_packed.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(inputs_packed);
CHECK_CONTIGUOUS_CUDA(first_idxs);
return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
#else
AT_ERROR("Not compiled with GPU support.");
@@ -89,6 +92,8 @@ at::Tensor PaddedToPacked(
const int64_t num_inputs) {
if (inputs_padded.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(inputs_padded);
CHECK_CONTIGUOUS_CUDA(first_idxs);
return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
#else
AT_ERROR("Not compiled with GPU support.");


@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <list>
#include <queue>
@@ -103,26 +105,45 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
const at::Tensor& segms,
const at::Tensor& segms_first_idx,
const int64_t max_points) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
segms_t{segms, "segms", 3},
segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
at::CheckedFrom c = "PointEdgeDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
at::checkAllSameType(c, {points_t, segms_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
const int64_t B = points_first_idx.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
TORCH_CHECK(segms_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({P,}, points.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_points, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
segms.data_ptr<float>(),
@@ -132,7 +153,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
B,
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
@@ -183,25 +204,42 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
const at::Tensor& segms,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
TORCH_CHECK(idx_points.size(0) == P);
TORCH_CHECK(grad_dists.size(0) == P);
// clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
// clang-format on
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}
const int blocks = 64;
const int threads = 512;
PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
idx_points.data_ptr<int64_t>(),
@@ -210,6 +248,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
grad_segms.data_ptr<float>(),
P);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}
@@ -308,26 +347,45 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
const at::Tensor& segms,
const at::Tensor& segms_first_idx,
const int64_t max_segms) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
segms_t{segms, "segms", 3},
segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
at::CheckedFrom c = "EdgePointDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
at::checkAllSameType(c, {points_t, segms_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
const int64_t B = points_first_idx.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
TORCH_CHECK(segms_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({S,}, segms.options());
at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_segms, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(),
segms.data_ptr<float>(),
@@ -337,7 +395,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
B,
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
@@ -389,15 +447,27 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
const at::Tensor& segms,
const at::Tensor& idx_segms,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
TORCH_CHECK(idx_segms.size(0) == S);
TORCH_CHECK(grad_dists.size(0) == S);
// clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options());
@@ -407,7 +477,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
const int blocks = 64;
const int threads = 512;
EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
idx_segms.data_ptr<int64_t>(),
@@ -451,26 +521,42 @@ __global__ void PointEdgeArrayForwardKernel(
at::Tensor PointEdgeArrayDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& segms) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
at::checkAllSameGPU(c, {points_t, segms_t});
at::checkAllSameType(c, {points_t, segms_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
at::Tensor dists = at::zeros({P, S}, points.options());
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
const size_t blocks = 1024;
const size_t threads = 64;
PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
dists.data_ptr<float>(),
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
@@ -520,22 +606,38 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
grad_dists_t{grad_dists, "grad_dists", 3};
at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t S = segms.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(segms.size(1) == 2) && (segms.size(2) == 3),
"segms must be of shape Sx2x3");
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}
const size_t blocks = 1024;
const size_t threads = 64;
PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(),
segms.data_ptr<float>(),
grad_dists.data_ptr<float>(),
@@ -543,6 +645,6 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
grad_segms.data_ptr<float>(),
P,
S);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_segms);
}


@@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// ****************************************************************************
// * PointEdgeDistance *
@@ -53,6 +54,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
const int64_t max_points) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(segms_first_idx);
return PointEdgeDistanceForwardCuda(
points, points_first_idx, segms, segms_first_idx, max_points);
#else
@@ -93,6 +98,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(idx_points);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists); return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");
@ -149,6 +158,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
const int64_t max_segms) { const int64_t max_segms) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(segms_first_idx);
return EdgePointDistanceForwardCuda( return EdgePointDistanceForwardCuda(
points, points_first_idx, segms, segms_first_idx, max_segms); points, points_first_idx, segms, segms_first_idx, max_segms);
#else #else
@ -189,6 +202,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
const torch::Tensor& grad_dists) { const torch::Tensor& grad_dists) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(idx_segms);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists); return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");
@ -220,7 +237,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
// will require for the forward pass 5.8G of memory to store dists. // will require for the forward pass 5.8G of memory to store dists.
#ifdef WITH_CUDA #ifdef WITH_CUDA
torch::Tensor PointEdgeArrayDistanceForwardCuda( torch::Tensor PointEdgeArrayDistanceForwardCuda(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& segms); const torch::Tensor& segms);
@ -231,6 +247,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
const torch::Tensor& segms) { const torch::Tensor& segms) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
return PointEdgeArrayDistanceForwardCuda(points, segms); return PointEdgeArrayDistanceForwardCuda(points, segms);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");
@ -265,6 +283,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
const torch::Tensor& grad_dists) { const torch::Tensor& grad_dists) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(segms);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists); return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");

View File

@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
#include <list> #include <list>
#include <queue> #include <queue>
@ -104,26 +106,45 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
const at::Tensor& tris, const at::Tensor& tris,
const at::Tensor& tris_first_idx, const at::Tensor& tris_first_idx,
const int64_t max_points) { const int64_t max_points) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
tris_t{tris, "tris", 3},
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
at::CheckedFrom c = "PointFaceDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t P = points.size(0);
const int64_t T = tris.size(0); const int64_t T = tris.size(0);
const int64_t B = points_first_idx.size(0); const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM( TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3), (tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3"); "tris must be of shape Tx3x3");
AT_ASSERTM(tris_first_idx.size(0) == B); TORCH_CHECK(tris_first_idx.size(0) == B);
// clang-format off // clang-format off
at::Tensor dists = at::zeros({P,}, points.options()); at::Tensor dists = at::zeros({P,}, points.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options()); at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
// clang-format on // clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128; const int threads = 128;
const dim3 blocks(max_points, B); const dim3 blocks(max_points, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
PointFaceForwardKernel<<<blocks, threads, shared_size>>>( PointFaceForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(), points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(), points_first_idx.data_ptr<int64_t>(),
tris.data_ptr<float>(), tris.data_ptr<float>(),
@ -134,6 +155,7 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
P, P,
T); T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs); return std::make_tuple(dists, idxs);
} }
@ -191,25 +213,42 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
const at::Tensor& tris, const at::Tensor& tris,
const at::Tensor& idx_points, const at::Tensor& idx_points,
const at::Tensor& grad_dists) { const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_points_t{idx_points, "idx_points", 2}, tris_t{tris, "tris", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "PointFaceDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_points_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t P = points.size(0);
const int64_t T = tris.size(0); const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM( TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3), (tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3"); "tris must be of shape Tx3x3");
AT_ASSERTM(idx_points.size(0) == P); TORCH_CHECK(idx_points.size(0) == P);
AT_ASSERTM(grad_dists.size(0) == P); TORCH_CHECK(grad_dists.size(0) == P);
// clang-format off // clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
// clang-format on // clang-format on
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
const int blocks = 64; const int blocks = 64;
const int threads = 512; const int threads = 512;
PointFaceBackwardKernel<<<blocks, threads>>>( PointFaceBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(), points.data_ptr<float>(),
tris.data_ptr<float>(), tris.data_ptr<float>(),
idx_points.data_ptr<int64_t>(), idx_points.data_ptr<int64_t>(),
@ -218,6 +257,7 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
grad_tris.data_ptr<float>(), grad_tris.data_ptr<float>(),
P); P);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris); return std::make_tuple(grad_points, grad_tris);
} }
@ -317,26 +357,45 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
const at::Tensor& tris, const at::Tensor& tris,
const at::Tensor& tris_first_idx, const at::Tensor& tris_first_idx,
const int64_t max_tris) { const int64_t max_tris) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2},
tris_t{tris, "tris", 3},
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
at::CheckedFrom c = "FacePointDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t P = points.size(0);
const int64_t T = tris.size(0); const int64_t T = tris.size(0);
const int64_t B = points_first_idx.size(0); const int64_t B = points_first_idx.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM( TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3), (tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3"); "tris must be of shape Tx3x3");
AT_ASSERTM(tris_first_idx.size(0) == B); TORCH_CHECK(tris_first_idx.size(0) == B);
// clang-format off // clang-format off
at::Tensor dists = at::zeros({T,}, tris.options()); at::Tensor dists = at::zeros({T,}, tris.options());
at::Tensor idxs = at::zeros({T,}, tris_first_idx.options()); at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
// clang-format on // clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128; const int threads = 128;
const dim3 blocks(max_tris, B); const dim3 blocks(max_tris, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
FacePointForwardKernel<<<blocks, threads, shared_size>>>( FacePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.data_ptr<float>(), points.data_ptr<float>(),
points_first_idx.data_ptr<int64_t>(), points_first_idx.data_ptr<int64_t>(),
tris.data_ptr<float>(), tris.data_ptr<float>(),
@ -347,6 +406,7 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
P, P,
T); T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs); return std::make_tuple(dists, idxs);
} }
@ -405,25 +465,42 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
const at::Tensor& tris, const at::Tensor& tris,
const at::Tensor& idx_tris, const at::Tensor& idx_tris,
const at::Tensor& grad_dists) { const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
idx_tris_t{idx_tris, "idx_tris", 2}, tris_t{tris, "tris", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "FacePointDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, idx_tris_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t P = points.size(0);
const int64_t T = tris.size(0); const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM( TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3), (tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3"); "tris must be of shape Tx3x3");
AT_ASSERTM(idx_tris.size(0) == T); TORCH_CHECK(idx_tris.size(0) == T);
AT_ASSERTM(grad_dists.size(0) == T); TORCH_CHECK(grad_dists.size(0) == T);
// clang-format off // clang-format off
at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
// clang-format on // clang-format on
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
const int blocks = 64; const int blocks = 64;
const int threads = 512; const int threads = 512;
FacePointBackwardKernel<<<blocks, threads>>>( FacePointBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(), points.data_ptr<float>(),
tris.data_ptr<float>(), tris.data_ptr<float>(),
idx_tris.data_ptr<int64_t>(), idx_tris.data_ptr<int64_t>(),
@ -432,6 +509,7 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
grad_tris.data_ptr<float>(), grad_tris.data_ptr<float>(),
T); T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris); return std::make_tuple(grad_points, grad_tris);
} }
@ -468,26 +546,42 @@ __global__ void PointFaceArrayForwardKernel(
at::Tensor PointFaceArrayDistanceForwardCuda( at::Tensor PointFaceArrayDistanceForwardCuda(
const at::Tensor& points, const at::Tensor& points,
const at::Tensor& tris) { const at::Tensor& tris) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2};
at::CheckedFrom c = "PointFaceArrayDistanceForwardCuda";
at::checkAllSameGPU(c, {points_t, tris_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t P = points.size(0);
const int64_t T = tris.size(0); const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM( TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3), (tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3"); "tris must be of shape Tx3x3");
at::Tensor dists = at::zeros({P, T}, points.options()); at::Tensor dists = at::zeros({P, T}, points.options());
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return dists;
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
PointFaceArrayForwardKernel<<<blocks, threads>>>( PointFaceArrayForwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(), points.data_ptr<float>(),
tris.data_ptr<float>(), tris.data_ptr<float>(),
dists.data_ptr<float>(), dists.data_ptr<float>(),
P, P,
T); T);
AT_CUDA_CHECK(cudaGetLastError());
return dists; return dists;
} }
@ -546,22 +640,38 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
const at::Tensor& points, const at::Tensor& points,
const at::Tensor& tris, const at::Tensor& tris,
const at::Tensor& grad_dists) { const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2},
grad_dists_t{grad_dists, "grad_dists", 3};
at::CheckedFrom c = "PointFaceArrayDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t P = points.size(0);
const int64_t T = tris.size(0); const int64_t T = tris.size(0);
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
AT_ASSERTM( TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3), (tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3"); "tris must be of shape Tx3x3");
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T)); TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
PointFaceArrayBackwardKernel<<<blocks, threads>>>( PointFaceArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
points.data_ptr<float>(), points.data_ptr<float>(),
tris.data_ptr<float>(), tris.data_ptr<float>(),
grad_dists.data_ptr<float>(), grad_dists.data_ptr<float>(),
@ -570,5 +680,6 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
P, P,
T); T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris); return std::make_tuple(grad_points, grad_tris);
} }
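Each of the *Cuda functions in this file starts with the same TensorArg preamble; it is what turns a device or dtype mismatch into a readable error that names the bad argument, rather than a crash or silent corruption inside the kernel. A standalone sketch of just that preamble (the function and argument names here are illustrative):

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

void CheckSameDeviceAndType(
    const at::Tensor& points,
    const at::Tensor& tris,
    const at::Tensor& grad_dists) {
  // TensorArg records the tensor together with its name and 1-based argument
  // position, so the failure message can point at the offending argument.
  at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2},
      grad_dists_t{grad_dists, "grad_dists", 3};
  at::CheckedFrom c = "CheckSameDeviceAndType";
  // Fails if any tensor lives on a different GPU than the others.
  at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t});
  // Fails if the dtypes differ; these kernels read through data_ptr<float>(),
  // so a double tensor slipping through would be reinterpreted as floats.
  at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
}

In the real wrappers these lines sit inline at the top of each forward and backward function, followed by the CUDAGuard and stream setup shown earlier.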

View File

@ -4,6 +4,7 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <cstdio> #include <cstdio>
#include <tuple> #include <tuple>
#include "utils/pytorch3d_cutils.h"
// **************************************************************************** // ****************************************************************************
// * PointFaceDistance * // * PointFaceDistance *
@ -55,6 +56,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
const int64_t max_points) { const int64_t max_points) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(tris_first_idx);
return PointFaceDistanceForwardCuda( return PointFaceDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_points); points, points_first_idx, tris, tris_first_idx, max_points);
#else #else
@ -95,6 +100,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
const torch::Tensor& grad_dists) { const torch::Tensor& grad_dists) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(idx_points);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists); return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");
@ -151,6 +160,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
const int64_t max_tris) { const int64_t max_tris) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(tris_first_idx);
return FacePointDistanceForwardCuda( return FacePointDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_tris); points, points_first_idx, tris, tris_first_idx, max_tris);
#else #else
@ -191,6 +204,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
const torch::Tensor& grad_dists) { const torch::Tensor& grad_dists) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(idx_tris);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists); return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");
@ -233,6 +250,8 @@ torch::Tensor PointFaceArrayDistanceForward(
const torch::Tensor& tris) { const torch::Tensor& tris) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
return PointFaceArrayDistanceForwardCuda(points, tris); return PointFaceArrayDistanceForwardCuda(points, tris);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");
@ -254,7 +273,6 @@ torch::Tensor PointFaceArrayDistanceForward(
// //
#ifdef WITH_CUDA #ifdef WITH_CUDA
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda( std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& points, const torch::Tensor& points,
const torch::Tensor& tris, const torch::Tensor& tris,
@ -267,6 +285,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
const torch::Tensor& grad_dists) { const torch::Tensor& grad_dists) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists); return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support."); AT_ERROR("Not compiled with GPU support.");

View File

@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h> #include <float.h>
#include <math.h> #include <math.h>
#include <thrust/tuple.h> #include <thrust/tuple.h>
@ -285,14 +287,14 @@ RasterizeMeshesNaiveCuda(
const int num_closest, const int num_closest,
const bool perspective_correct, const bool perspective_correct,
const bool cull_backfaces) { const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || TORCH_CHECK(
face_verts.size(2) != 3) { face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); face_verts.size(2) == 3,
} "face_verts must have dimensions (num_faces, 3, 3)");
if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) {
AT_ERROR( TORCH_CHECK(
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx"); "num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
}
if (num_closest > kMaxPointsPerPixel) { if (num_closest > kMaxPointsPerPixel) {
std::stringstream ss; std::stringstream ss;
@ -300,6 +302,20 @@ RasterizeMeshesNaiveCuda(
AT_ERROR(ss.str()); AT_ERROR(ss.str());
} }
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_faces_packed_first_idx_t{
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
at::checkAllSameGPU(
c,
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = num_faces_per_mesh.size(0); // batch size. const int N = num_faces_per_mesh.size(0); // batch size.
const int H = image_size; // Assume square images. const int H = image_size; // Assume square images.
const int W = image_size; const int W = image_size;
@ -313,10 +329,15 @@ RasterizeMeshesNaiveCuda(
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts); at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
if (face_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
RasterizeMeshesNaiveCudaKernel<<<blocks, threads>>>( RasterizeMeshesNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(), face_verts.contiguous().data_ptr<float>(),
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(), mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(), num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
@ -332,6 +353,7 @@ RasterizeMeshesNaiveCuda(
pix_dists.contiguous().data_ptr<float>(), pix_dists.contiguous().data_ptr<float>(),
bary.contiguous().data_ptr<float>()); bary.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists); return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
} }
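The shape validation here (and in the coarse, fine and backward rasterization functions below) is rewritten from `if (bad) { AT_ERROR(...); }` into a single TORCH_CHECK, which asserts the condition that must hold and carries the message with it. A minimal sketch, with a hypothetical helper name:

#include <torch/extension.h>

// Hypothetical standalone validator; in the diff the check sits inline in
// RasterizeMeshesNaiveCuda and the other rasterization entry points.
void ValidateFaceVerts(const torch::Tensor& face_verts) {
  // Before: if (face_verts.ndimension() != 3 || ...) { AT_ERROR("..."); }
  // After: one statement with the positive condition and the same message.
  TORCH_CHECK(
      face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
          face_verts.size(2) == 3,
      "face_verts must have dimensions (num_faces, 3, 3)");
}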
@ -465,6 +487,22 @@ at::Tensor RasterizeMeshesBackwardCuda(
const at::Tensor& grad_bary, // (N, H, W, K, 3) const at::Tensor& grad_bary, // (N, H, W, K, 3)
const at::Tensor& grad_dists, // (N, H, W, K) const at::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) { const bool perspective_correct) {
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
pix_to_face_t{pix_to_face, "pix_to_face", 2},
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
grad_bary_t{grad_bary, "grad_bary", 4},
grad_dists_t{grad_dists, "grad_dists", 5};
at::CheckedFrom c = "RasterizeMeshesBackwardCuda";
at::checkAllSameGPU(
c, {face_verts_t, pix_to_face_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
at::checkAllSameType(
c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int F = face_verts.size(0); const int F = face_verts.size(0);
const int N = pix_to_face.size(0); const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1); const int H = pix_to_face.size(1);
@ -472,10 +510,16 @@ at::Tensor RasterizeMeshesBackwardCuda(
const int K = pix_to_face.size(3); const int K = pix_to_face.size(3);
at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options()); at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());
if (grad_face_verts.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts;
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
RasterizeMeshesBackwardCudaKernel<<<blocks, threads>>>( RasterizeMeshesBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(), face_verts.contiguous().data_ptr<float>(),
pix_to_face.contiguous().data_ptr<int64_t>(), pix_to_face.contiguous().data_ptr<int64_t>(),
perspective_correct, perspective_correct,
@ -488,6 +532,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
grad_dists.contiguous().data_ptr<float>(), grad_dists.contiguous().data_ptr<float>(),
grad_face_verts.contiguous().data_ptr<float>()); grad_face_verts.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts; return grad_face_verts;
} }
@ -626,10 +671,24 @@ at::Tensor RasterizeMeshesCoarseCuda(
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int max_faces_per_bin) { const int max_faces_per_bin) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || TORCH_CHECK(
face_verts.size(2) != 3) { face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); face_verts.size(2) == 3,
} "face_verts must have dimensions (num_faces, 3, 3)");
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_face_first_idx_t{
mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
at::checkAllSameGPU(
c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int W = image_size; const int W = image_size;
const int H = image_size; const int H = image_size;
const int F = face_verts.size(0); const int F = face_verts.size(0);
@ -645,12 +704,18 @@ at::Tensor RasterizeMeshesCoarseCuda(
auto opts = face_verts.options().dtype(at::kInt); auto opts = face_verts.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts); at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts); at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}
const int chunk_size = 512; const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8; const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64; const size_t blocks = 64;
const size_t threads = 512; const size_t threads = 512;
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size>>>( RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
face_verts.contiguous().data_ptr<float>(), face_verts.contiguous().data_ptr<float>(),
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(), mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(), num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
@ -664,6 +729,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
M, M,
faces_per_bin.contiguous().data_ptr<int32_t>(), faces_per_bin.contiguous().data_ptr<int32_t>(),
bin_faces.contiguous().data_ptr<int32_t>()); bin_faces.contiguous().data_ptr<int32_t>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces; return bin_faces;
} }
@ -775,13 +842,22 @@ RasterizeMeshesFineCuda(
const int faces_per_pixel, const int faces_per_pixel,
const bool perspective_correct, const bool perspective_correct,
const bool cull_backfaces) { const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || TORCH_CHECK(
face_verts.size(2) != 3) { face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); face_verts.size(2) == 3,
} "face_verts must have dimensions (num_faces, 3, 3)");
if (bin_faces.ndimension() != 4) { TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
AT_ERROR("bin_faces must have 4 dimensions");
} // Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
bin_faces_t{bin_faces, "bin_faces", 2};
at::CheckedFrom c = "RasterizeMeshesFineCuda";
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = bin_faces.size(0); const int N = bin_faces.size(0);
const int B = bin_faces.size(1); const int B = bin_faces.size(1);
const int M = bin_faces.size(3); const int M = bin_faces.size(3);
@ -790,7 +866,7 @@ RasterizeMeshesFineCuda(
const int W = image_size; const int W = image_size;
if (K > kMaxPointsPerPixel) { if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8"); AT_ERROR("Must have num_closest <= 150");
} }
auto long_opts = face_verts.options().dtype(at::kLong); auto long_opts = face_verts.options().dtype(at::kLong);
auto float_opts = face_verts.options().dtype(at::kFloat); auto float_opts = face_verts.options().dtype(at::kFloat);
@ -800,10 +876,15 @@ RasterizeMeshesFineCuda(
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts); at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);
if (face_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
RasterizeMeshesFineCudaKernel<<<blocks, threads>>>( RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(), face_verts.contiguous().data_ptr<float>(),
bin_faces.contiguous().data_ptr<int32_t>(), bin_faces.contiguous().data_ptr<int32_t>(),
blur_radius, blur_radius,

View File

@ -4,6 +4,7 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <cstdio> #include <cstdio>
#include <tuple> #include <tuple>
#include "utils/pytorch3d_cutils.h"
// **************************************************************************** // ****************************************************************************
// * FORWARD PASS * // * FORWARD PASS *
@ -95,6 +96,9 @@ RasterizeMeshesNaive(
// TODO: Better type checking. // TODO: Better type checking.
if (face_verts.is_cuda()) { if (face_verts.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
return RasterizeMeshesNaiveCuda( return RasterizeMeshesNaiveCuda(
face_verts, face_verts,
mesh_to_face_first_idx, mesh_to_face_first_idx,
@ -175,6 +179,11 @@ torch::Tensor RasterizeMeshesBackward(
const bool perspective_correct) { const bool perspective_correct) {
if (face_verts.is_cuda()) { if (face_verts.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(pix_to_face);
CHECK_CONTIGUOUS_CUDA(grad_zbuf);
CHECK_CONTIGUOUS_CUDA(grad_bary);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return RasterizeMeshesBackwardCuda( return RasterizeMeshesBackwardCuda(
face_verts, face_verts,
pix_to_face, pix_to_face,
@ -251,6 +260,9 @@ torch::Tensor RasterizeMeshesCoarse(
const int max_faces_per_bin) { const int max_faces_per_bin) {
if (face_verts.is_cuda()) { if (face_verts.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
return RasterizeMeshesCoarseCuda( return RasterizeMeshesCoarseCuda(
face_verts, face_verts,
mesh_to_face_first_idx, mesh_to_face_first_idx,
@ -347,6 +359,8 @@ RasterizeMeshesFine(
const bool cull_backfaces) { const bool cull_backfaces) {
if (face_verts.is_cuda()) { if (face_verts.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(bin_faces);
return RasterizeMeshesFineCuda( return RasterizeMeshesFineCuda(
face_verts, face_verts,
bin_faces, bin_faces,

View File

@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h> #include <math.h>
#include <cstdio> #include <cstdio>
#include <sstream> #include <sstream>
@ -145,13 +147,25 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
const int image_size, const int image_size,
const float radius, const float radius,
const int points_per_pixel) { const int points_per_pixel) {
if (points.ndimension() != 2 || points.size(1) != 3) { // Check inputs are on the same device
AT_ERROR("points must have dimensions (num_points, 3)"); at::TensorArg points_t{points, "points", 1},
} cloud_to_packed_first_idx_t{
if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) { cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
AT_ERROR( num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
at::CheckedFrom c = "RasterizePointsNaiveCuda";
at::checkAllSameGPU(
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(
points.ndimension() == 2 && points.size(1) == 3,
"points must have dimensions (num_points, 3)");
TORCH_CHECK(
num_points_per_cloud.size(0) == cloud_to_packed_first_idx.size(0),
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
}
const int N = num_points_per_cloud.size(0); // batch size. const int N = num_points_per_cloud.size(0); // batch size.
const int S = image_size; const int S = image_size;
@ -169,9 +183,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
RasterizePointsNaiveCudaKernel<<<blocks, threads>>>( RasterizePointsNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(), cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(), num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@ -182,6 +201,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
point_idxs.contiguous().data_ptr<int32_t>(), point_idxs.contiguous().data_ptr<int32_t>(),
zbuf.contiguous().data_ptr<float>(), zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>()); pix_dists.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists); return std::make_tuple(point_idxs, zbuf, pix_dists);
} }
@ -323,14 +344,28 @@ at::Tensor RasterizePointsCoarseCuda(
const float radius, const float radius,
const int bin_size, const int bin_size,
const int max_points_per_bin) { const int max_points_per_bin) {
TORCH_CHECK(
points.ndimension() == 2 && points.size(1) == 3,
"points must have dimensions (num_points, 3)");
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
cloud_to_packed_first_idx_t{
cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
at::CheckedFrom c = "RasterizePointsCoarseCuda";
at::checkAllSameGPU(
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int P = points.size(0); const int P = points.size(0);
const int N = num_points_per_cloud.size(0); const int N = num_points_per_cloud.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
const int M = max_points_per_bin; const int M = max_points_per_bin;
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_bins >= 22) { if (num_bins >= 22) {
// Make sure we do not use too much shared memory. // Make sure we do not use too much shared memory.
std::stringstream ss; std::stringstream ss;
@ -340,12 +375,18 @@ at::Tensor RasterizePointsCoarseCuda(
auto opts = points.options().dtype(at::kInt); auto opts = points.options().dtype(at::kInt);
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts); at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts); at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
if (bin_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_points;
}
const int chunk_size = 512; const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8; const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64; const size_t blocks = 64;
const size_t threads = 512; const size_t threads = 512;
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>( RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(), cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(), num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@ -358,6 +399,8 @@ at::Tensor RasterizePointsCoarseCuda(
M, M,
points_per_bin.contiguous().data_ptr<int32_t>(), points_per_bin.contiguous().data_ptr<int32_t>(),
bin_points.contiguous().data_ptr<int32_t>()); bin_points.contiguous().data_ptr<int32_t>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_points; return bin_points;
} }
@ -448,13 +491,23 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
const float radius, const float radius,
const int bin_size, const int bin_size,
const int points_per_pixel) { const int points_per_pixel) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
bin_points_t{bin_points, "bin_points", 2};
at::CheckedFrom c = "RasterizePointsFineCuda";
at::checkAllSameGPU(c, {points_t, bin_points_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = bin_points.size(0); const int N = bin_points.size(0);
const int B = bin_points.size(1); // num_bins const int B = bin_points.size(1); // num_bins
const int M = bin_points.size(3); const int M = bin_points.size(3);
const int S = image_size; const int S = image_size;
const int K = points_per_pixel; const int K = points_per_pixel;
if (K > kMaxPointsPerPixel) { if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8"); AT_ERROR("Must have num_closest <= 150");
} }
auto int_opts = points.options().dtype(at::kInt); auto int_opts = points.options().dtype(at::kInt);
auto float_opts = points.options().dtype(at::kFloat); auto float_opts = points.options().dtype(at::kFloat);
@ -462,9 +515,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
RasterizePointsFineCudaKernel<<<blocks, threads>>>( RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
bin_points.contiguous().data_ptr<int32_t>(), bin_points.contiguous().data_ptr<int32_t>(),
radius, radius,
@ -478,6 +536,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
zbuf.contiguous().data_ptr<float>(), zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>()); pix_dists.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists); return std::make_tuple(point_idxs, zbuf, pix_dists);
} }
@ -537,6 +596,19 @@ at::Tensor RasterizePointsBackwardCuda(
const at::Tensor& idxs, // (N, H, W, K) const at::Tensor& idxs, // (N, H, W, K)
const at::Tensor& grad_zbuf, // (N, H, W, K) const at::Tensor& grad_zbuf, // (N, H, W, K)
const at::Tensor& grad_dists) { // (N, H, W, K) const at::Tensor& grad_dists) { // (N, H, W, K)
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, idxs_t{idxs, "idxs", 2},
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "RasterizePointsBackwardCuda";
at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t});
at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int P = points.size(0); const int P = points.size(0);
const int N = idxs.size(0); const int N = idxs.size(0);
const int H = idxs.size(1); const int H = idxs.size(1);
@ -544,10 +616,16 @@ at::Tensor RasterizePointsBackwardCuda(
const int K = idxs.size(3); const int K = idxs.size(3);
at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_points = at::zeros({P, 3}, points.options());
if (grad_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_points;
}
const size_t blocks = 1024; const size_t blocks = 1024;
const size_t threads = 64; const size_t threads = 64;
RasterizePointsBackwardCudaKernel<<<blocks, threads>>>( RasterizePointsBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(), points.contiguous().data_ptr<float>(),
idxs.contiguous().data_ptr<int32_t>(), idxs.contiguous().data_ptr<int32_t>(),
N, N,
@ -559,5 +637,6 @@ at::Tensor RasterizePointsBackwardCuda(
grad_dists.contiguous().data_ptr<float>(), grad_dists.contiguous().data_ptr<float>(),
grad_points.contiguous().data_ptr<float>()); grad_points.contiguous().data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return grad_points; return grad_points;
} }
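Every launch in this file also gains a fourth launch-configuration argument. A bare `<<<blocks, threads>>>` queues the kernel on the legacy default stream, which ignores any non-default stream the caller has made current and so is not ordered with the work PyTorch has already queued for the input tensors; passing `at::cuda::getCurrentCUDAStream()` keeps the kernel on the same stream as that work. A small sketch (ScaleKernel and ScaleInPlaceCuda are hypothetical, assuming a contiguous float tensor):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

__global__ void ScaleKernel(float* data, int64_t n, float s) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= s;
  }
}

void ScaleInPlaceCuda(at::Tensor x, float s) {
  at::cuda::CUDAGuard device_guard(x.device());
  // The stream PyTorch is currently issuing work to on x's device.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (x.numel() == 0) {
    return;  // nothing to compute, and an empty grid would fail to launch
  }
  const int threads = 64;
  const int blocks = (x.numel() + threads - 1) / threads;
  // ScaleKernel<<<blocks, threads>>>(...) would go to the legacy default
  // stream instead and could race with the ops that produced x.
  ScaleKernel<<<blocks, threads, 0, stream>>>(
      x.data_ptr<float>(), x.numel(), s);
  AT_CUDA_CHECK(cudaGetLastError());
}

The empty-input early returns added just before each launch serve a related purpose: they skip the launch entirely when there is nothing to compute, so zero-sized batches cannot produce an invalid launch configuration.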

View File

@ -4,6 +4,7 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <cstdio> #include <cstdio>
#include <tuple> #include <tuple>
#include "utils/pytorch3d_cutils.h"
// **************************************************************************** // ****************************************************************************
// * NAIVE RASTERIZATION * // * NAIVE RASTERIZATION *
@ -66,6 +67,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) { num_points_per_cloud.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
return RasterizePointsNaiveCuda( return RasterizePointsNaiveCuda(
points, points,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
@ -140,6 +144,9 @@ torch::Tensor RasterizePointsCoarse(
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) { num_points_per_cloud.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
return RasterizePointsCoarseCuda( return RasterizePointsCoarseCuda(
points, points,
cloud_to_packed_first_idx, cloud_to_packed_first_idx,
@ -208,6 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
const int points_per_pixel) { const int points_per_pixel) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(bin_points);
return RasterizePointsFineCuda( return RasterizePointsFineCuda(
points, bin_points, image_size, radius, bin_size, points_per_pixel); points, bin_points, image_size, radius, bin_size, points_per_pixel);
#else #else
@ -257,6 +266,10 @@ torch::Tensor RasterizePointsBackward(
const torch::Tensor& grad_dists) { const torch::Tensor& grad_dists) {
if (points.is_cuda()) { if (points.is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(idxs);
CHECK_CONTIGUOUS_CUDA(grad_zbuf);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists); return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists);
#else #else
AT_ERROR("Not compiled with GPU support"); AT_ERROR("Not compiled with GPU support");

View File

@ -3,9 +3,9 @@
#pragma once #pragma once
#include <torch/extension.h> #include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor.") #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \ #define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous.") TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \ #define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \ CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x) CHECK_CONTIGUOUS(x)
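With AT_ASSERTM replaced by TORCH_CHECK, a failed check now raises a regular user-facing error rather than an internal assertion. The macro is used at the top of each CPU/CUDA dispatch entry point, as in the *.h changes above. A sketch of the intended usage; ScalePoints and its body are illustrative only, and the include path is the one used inside the csrc tree:

#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"

torch::Tensor ScalePoints(const torch::Tensor& points) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    // Expands to CHECK_CUDA(points); CHECK_CONTIGUOUS(points): raises unless
    // points is a contiguous CUDA tensor.
    CHECK_CONTIGUOUS_CUDA(points);
    return points * 2.0;  // a real op would call its *Cuda implementation here
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  CHECK_CONTIGUOUS(points);
  return points * 2.0;  // CPU path
}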

View File

@ -8,11 +8,21 @@ from test_chamfer import TestChamfer
def bm_chamfer() -> None: def bm_chamfer() -> None:
kwargs_list_naive = [ devices = ["cpu"]
{"batch_size": 1, "P1": 32, "P2": 64, "return_normals": False}, if torch.cuda.is_available():
{"batch_size": 1, "P1": 32, "P2": 64, "return_normals": True}, devices.append("cuda:0")
{"batch_size": 32, "P1": 32, "P2": 64, "return_normals": False},
] kwargs_list_naive = []
batch_size = [1, 32]
return_normals = [True, False]
test_cases = product(batch_size, return_normals, devices)
for case in test_cases:
b, n, d = case
kwargs_list_naive.append(
{"batch_size": b, "P1": 32, "P2": 64, "return_normals": n, "device": d}
)
benchmark( benchmark(
TestChamfer.chamfer_naive_with_init, TestChamfer.chamfer_naive_with_init,
"CHAMFER_NAIVE", "CHAMFER_NAIVE",
@ -21,6 +31,7 @@ def bm_chamfer() -> None:
) )
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda:0"
kwargs_list = [] kwargs_list = []
batch_size = [1, 32] batch_size = [1, 32]
P1 = [32, 1000, 10000] P1 = [32, 1000, 10000]
@ -38,6 +49,7 @@ def bm_chamfer() -> None:
"P2": p2, "P2": p2,
"return_normals": n, "return_normals": n,
"homogeneous": h, "homogeneous": h,
"device": device,
} }
) )
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1) benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)

View File

@ -20,6 +20,18 @@ def load_rgb_image(filename: str, data_dir: Union[str, Path]):
TensorOrArray = Union[torch.Tensor, np.ndarray] TensorOrArray = Union[torch.Tensor, np.ndarray]
def get_random_cuda_device() -> str:
"""
Function to get a random GPU device from the
available devices. This is useful for testing
that custom cuda kernels can support inputs on
any device without having to set the device explicitly.
"""
num_devices = torch.cuda.device_count()
rand_device_id = torch.randint(high=num_devices, size=(1,)).item()
return "cuda:%d" % rand_device_id
class TestCaseMixin(unittest.TestCase): class TestCaseMixin(unittest.TestCase):
def assertSeparate(self, tensor1, tensor2) -> None: def assertSeparate(self, tensor1, tensor2) -> None:
""" """

View File

@ -6,7 +6,7 @@ from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from common_testing import TestCaseMixin from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.loss import chamfer_distance from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds from pytorch3d.structures.pointclouds import Pointclouds
@ -81,7 +81,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
) )
@staticmethod @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2):
+    def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -97,7 +97,6 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         x_normals = p1.normals_padded()
         y_normals = p2.normals_padded()
-        device = torch.device("cuda:0")
         return_normals = x_normals is not None and y_normals is not None
         # Initialize all distances to + inf
@@ -163,7 +162,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, P1, D = x.shape
         P2 = y.size(1)
-        device = torch.device("cuda:0")
+        device = x.device
         return_normals = x_normals is not None and y_normals is not None
         dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)
@@ -203,7 +202,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         This tests only uses homogeneous pointclouds.
         """
         N, max_P1, max_P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
         points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
         p1 = points_normals.p1
         p2 = points_normals.p2
@@ -237,7 +236,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         which supports heterogeneous pointcloud objects.
         """
         N, max_P1, max_P2 = 3, 70, 70
-        device = "cuda:0"
+        device = get_random_cuda_device()
         points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
         weights = points_normals.weights
         x_lengths = points_normals.p1_lengths
@@ -256,7 +255,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         # Chamfer with pointclouds as input.
         pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
-            points_normals.cloud1, points_normals.cloud2
+            points_normals.cloud1, points_normals.cloud2, device=device
         )
         # Mean reduction point loss.
@@ -299,7 +298,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     def test_chamfer_pointcloud_object_withnormals(self):
         N = 5
         P1, P2 = 100, 100
-        device = "cuda:0"
+        device = get_random_cuda_device()
         reductions = [
             ("sum", "sum"),
@@ -359,7 +358,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     def test_chamfer_pointcloud_object_nonormals(self):
         N = 5
         P1, P2 = 100, 100
-        device = "cuda:0"
+        device = get_random_cuda_device()
         reductions = [
             ("sum", "sum"),
@@ -415,7 +414,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for point_reduction = "mean" and batch_reduction = None.
         """
         N, max_P1, max_P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
         points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
         p1 = points_normals.p1
         p2 = points_normals.p2
@@ -464,7 +463,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for point_reduction = "sum" and batch_reduction = None.
         """
         N, P1, P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
         points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
         p1 = points_normals.p1
         p2 = points_normals.p2
@@ -579,7 +578,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         point_reduction in ["mean", "sum"].
         """
         N, max_P1, max_P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
         points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
         p1 = points_normals.p1
@@ -681,7 +680,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     def test_incorrect_weights(self):
         N, P1, P2 = 16, 64, 128
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         p1 = torch.rand(
             (N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
         )
@@ -716,7 +715,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     def test_incorrect_inputs(self):
         N, P1, P2 = 7, 10, 18
-        device = "cuda:0"
+        device = get_random_cuda_device()
         points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
         p1 = points_normals.p1
         p2 = points_normals.p2
@@ -740,11 +739,16 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     @staticmethod
     def chamfer_with_init(
-        batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
+        batch_size: int,
+        P1: int,
+        P2: int,
+        return_normals: bool,
+        homogeneous: bool,
+        device="cpu",
     ):
-        p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
-            batch_size, P1, P2
-        )
+        points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
+        l1 = points_normals.p1_lengths
+        l2 = points_normals.p2_lengths
         if homogeneous:
             # Set lengths to None so in Chamfer it assumes
             # there is no padding.
@@ -754,13 +758,13 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         def loss():
             loss, loss_normals = chamfer_distance(
-                p1,
-                p2,
+                points_normals.p1,
+                points_normals.p2,
                 x_lengths=l1,
                 y_lengths=l2,
-                x_normals=p1_normals,
-                y_normals=p2_normals,
-                weights=weights,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
+                weights=points_normals.weights,
             )
             torch.cuda.synchronize()
@@ -768,16 +772,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     @staticmethod
     def chamfer_naive_with_init(
-        batch_size: int, P1: int, P2: int, return_normals: bool
+        batch_size: int, P1: int, P2: int, return_normals: bool, device="cpu"
     ):
-        p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
-            batch_size, P1, P2
-        )
+        points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
         torch.cuda.synchronize()

         def loss():
             loss, loss_normals = TestChamfer.chamfer_distance_naive(
-                p1, p2, x_normals=p1_normals, y_normals=p2_normals
+                points_normals.p1,
+                points_normals.p2,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
             )
             torch.cuda.synchronize()

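The changes above thread an explicit device through the chamfer test helpers instead of hard-coding "cuda:0". For reference, a minimal sketch of calling the public chamfer_distance op on an explicitly chosen GPU; the fixed device string below is an assumption, since the tests pick a random visible GPU instead.

    import torch
    from pytorch3d.loss import chamfer_distance

    # All inputs are created directly on the chosen device, so the op runs
    # there without relying on the current default CUDA device.
    device = torch.device("cuda:0")  # stand-in for a randomly chosen GPU
    N, P1, P2 = 4, 16, 12
    x = torch.rand((N, P1, 3), device=device, requires_grad=True)
    y = torch.rand((N, P2, 3), device=device)
    loss, loss_normals = chamfer_distance(x, y)  # loss_normals is None without normals
    loss.backward()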
View File

@@ -3,6 +3,7 @@
 import unittest

 import torch
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d.renderer.compositing import (
     alpha_composite,
     norm_weighted_sum,
@@ -10,7 +11,7 @@ from pytorch3d.renderer.compositing import (
 )

-class TestAccumulatePoints(unittest.TestCase):
+class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):
     # NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
     @staticmethod
@@ -120,7 +121,7 @@ class TestAccumulatePoints(unittest.TestCase):
         self._simple_wsumnorm(norm_weighted_sum, device)

     def test_cuda(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         self._simple_alphacomposite(alpha_composite, device)
         self._simple_wsum(weighted_sum, device)
         self._simple_wsumnorm(norm_weighted_sum, device)
@@ -142,7 +143,7 @@ class TestAccumulatePoints(unittest.TestCase):
         C = 3
         P = 32
-        for d in ["cpu", "cuda"]:
+        for d in ["cpu", get_random_cuda_device()]:
             # TODO(gkioxari) add torch.float64 to types after double precision
             # support is added to atomicAdd
             for t in [torch.float32]:
@@ -181,7 +182,7 @@ class TestAccumulatePoints(unittest.TestCase):
         res1 = fn1(*args1)
         res2 = fn2(*args2)
-        self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))
+        self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)
         if not compare_grads:
             return
@@ -200,7 +201,7 @@ class TestAccumulatePoints(unittest.TestCase):
         grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]
         for i in range(0, len(grads1)):
-            self.assertTrue(torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6))
+            self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)

     def _simple_wsum(self, accum_func, device):
         # Initialise variables
@@ -273,7 +274,7 @@ class TestAccumulatePoints(unittest.TestCase):
             ]
         ).to(device)
-        self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
+        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

     def _simple_wsumnorm(self, accum_func, device):
         # Initialise variables
@@ -346,7 +347,7 @@ class TestAccumulatePoints(unittest.TestCase):
             ]
         ).to(device)
-        self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
+        self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)

     def _simple_alphacomposite(self, accum_func, device):
         # Initialise variables

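These tests now draw their GPU from get_random_cuda_device() in common_testing. Its implementation is not part of this diff, so the following is only a hedged sketch of what such a helper can look like; the function body is an assumption, not the library's code.

    import random

    import torch

    def get_random_cuda_device() -> str:
        # Sketch only: pick a random visible GPU so kernels are also exercised
        # on non-default devices; fall back to cuda:0 on single-GPU machines.
        num_devices = torch.cuda.device_count()
        device_id = random.randint(0, num_devices - 1) if num_devices > 1 else 0
        return "cuda:%d" % device_id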
View File

@@ -33,7 +33,9 @@ class TestCubify(unittest.TestCase):
         # 1st-check
         verts, faces = meshes.get_mesh_verts_faces(0)
-        self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
+        self.assertTrue(
+            torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
+        )
         self.assertTrue(
             torch.allclose(
                 verts,
@@ -78,7 +80,9 @@ class TestCubify(unittest.TestCase):
         )
         # 2nd-check
         verts, faces = meshes.get_mesh_verts_faces(1)
-        self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
+        self.assertTrue(
+            torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
+        )
         self.assertTrue(
             torch.allclose(
                 verts,

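The added .cpu() matters because torch.allclose expects both tensors on the same device: torch.tensor([...]) builds its operand on the CPU, while faces now live on a randomly chosen GPU. A small sketch of the failure mode being avoided, assuming a CUDA-capable machine:

    import torch

    if torch.cuda.is_available():
        faces_max = torch.arange(42, device="cuda").max()  # lives on the GPU
        expected = torch.tensor([41])                       # CPU by default
        # Comparing tensors on different devices raises an error, so the GPU
        # value is moved to the CPU before the comparison.
        assert torch.allclose(faces_max.cpu(), expected)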
View File

@@ -4,7 +4,7 @@
 import unittest

 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d.ops import mesh_face_areas_normals
 from pytorch3d.structures.meshes import Meshes
@@ -94,13 +94,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
         self._test_face_areas_normals_helper("cpu")

     def test_face_areas_normals_cuda(self):
-        self._test_face_areas_normals_helper("cuda:0")
+        device = get_random_cuda_device()
+        self._test_face_areas_normals_helper(device)

     def test_nonfloats_cpu(self):
         self._test_face_areas_normals_helper("cpu", dtype=torch.double)

     def test_nonfloats_cuda(self):
-        self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
+        device = get_random_cuda_device()
+        self._test_face_areas_normals_helper(device, dtype=torch.double)

     @staticmethod
     def face_areas_normals_with_init(

View File

@@ -4,7 +4,7 @@ import unittest
 import torch
 import torch.nn as nn
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d import _C
 from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
 from pytorch3d.structures.meshes import Meshes
@@ -14,7 +14,7 @@ from pytorch3d.utils import ico_sphere
 class TestGraphConv(TestCaseMixin, unittest.TestCase):
     def test_undirected(self):
         dtype = torch.float32
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         verts = torch.tensor(
             [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
         )
@@ -97,7 +97,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
         self.assertClose(y, expected_y)

     def test_backward(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         mesh = ico_sphere()
         verts = mesh.verts_packed()
         edges = mesh.edges_packed()
@@ -118,7 +118,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
         self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")

     def test_cpu_cuda_tensor_error(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         verts = torch.tensor(
             [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
         )
@@ -134,7 +134,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
         Check that gather_scatter cuda version throws an error if cpu tensors
         are given as input.
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         mesh = ico_sphere()
         verts = mesh.verts_packed()
         edges = mesh.edges_packed()

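For context, a short sketch of the GraphConv call pattern these tests exercise, with the layer parameters and its inputs placed on one explicitly selected GPU; the fixed device string is an assumption, since the tests use a randomly chosen device.

    import torch
    from pytorch3d.ops.graph_conv import GraphConv

    device = torch.device("cuda:0")  # stand-in for a randomly chosen GPU
    verts = torch.tensor(
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
    )
    edges = torch.tensor([[0, 1], [0, 2]], dtype=torch.int64, device=device)
    conv = GraphConv(3, 1).to(device)  # layer weights moved to the same device
    out = conv(verts, edges)           # (3, 1): one output feature per vertex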
View File

@@ -4,7 +4,7 @@ import unittest
 from itertools import product

 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
@@ -89,7 +89,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         self._knn_vs_python_square_helper(device)

     def test_knn_vs_python_square_cuda(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         self._knn_vs_python_square_helper(device)

     def _knn_vs_python_ragged_helper(self, device):
@@ -133,11 +133,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         self._knn_vs_python_ragged_helper(device)

     def test_knn_vs_python_ragged_cuda(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         self._knn_vs_python_ragged_helper(device)

     def test_knn_gather(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, P1, P2, K, D = 4, 16, 12, 8, 3
         x = torch.rand((N, P1, D), device=device)
         y = torch.rand((N, P2, D), device=device)

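A brief sketch of the knn usage covered above, run on one explicit device. The device string and the .idx field access on the returned value are assumptions based on the imports shown in this diff, not a definitive API reference.

    import torch
    from pytorch3d.ops.knn import knn_gather, knn_points

    device = torch.device("cuda:0")  # stand-in for a randomly chosen GPU
    N, P1, P2, K, D = 4, 16, 12, 8, 3
    x = torch.rand((N, P1, D), device=device)
    y = torch.rand((N, P2, D), device=device)
    out = knn_points(x, y, K=K)         # K nearest neighbors of x in y
    neighbors = knn_gather(y, out.idx)  # gathered neighbors, (N, P1, K, D)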
View File

@@ -3,7 +3,7 @@
 import unittest

 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d.ops import packed_to_padded, padded_to_packed
 from pytorch3d.structures.meshes import Meshes
@@ -126,13 +126,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         self._test_packed_to_padded_helper(16, "cpu")

     def test_packed_to_padded_flat_cuda(self):
-        self._test_packed_to_padded_helper(0, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper(0, device)

     def test_packed_to_padded_D1_cuda(self):
-        self._test_packed_to_padded_helper(1, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper(1, device)

     def test_packed_to_padded_D16_cuda(self):
-        self._test_packed_to_padded_helper(16, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper(16, device)

     def _test_padded_to_packed_helper(self, D, device):
         """
@@ -191,13 +194,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         self._test_padded_to_packed_helper(16, "cpu")

     def test_padded_to_packed_flat_cuda(self):
-        self._test_padded_to_packed_helper(0, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper(0, device)

     def test_padded_to_packed_D1_cuda(self):
-        self._test_padded_to_packed_helper(1, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper(1, device)

     def test_padded_to_packed_D16_cuda(self):
-        self._test_padded_to_packed_helper(16, "cuda:0")
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper(16, device)

     def test_invalid_inputs_shapes(self, device="cuda:0"):
         with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):

View File

@@ -4,7 +4,7 @@ import unittest
 import numpy as np
 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d import _C
 from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
 from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
@@ -203,7 +203,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         & PointEdgeArrayDistanceBackward
         """
         P, E = 16, 32
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         points = torch.rand((P, 3), dtype=torch.float32, device=device)
         edges = torch.rand((E, 2, 3), dtype=torch.float32, device=device)
@@ -246,9 +246,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         Test CUDA implementation for PointEdgeDistanceForward
         & PointEdgeDistanceBackward
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
         # make points packed a leaf node
         points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -327,9 +327,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         Test CUDA implementation for EdgePointDistanceForward
         & EdgePointDistanceBackward
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
         # make points packed a leaf node
         points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -409,9 +409,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         """
         Test point_mesh_edge_distance from pytorch3d.loss
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
         # clone and detach for another backward pass through the op
         verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -480,7 +480,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         & PointFaceArrayDistanceBackward
         """
         P, T = 16, 32
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         points = torch.rand((P, 3), dtype=torch.float32, device=device)
         tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device)
@@ -525,9 +525,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         Test CUDA implementation for PointFaceDistanceForward
         & PointFaceDistanceBackward
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
         # make points packed a leaf node
         points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -608,9 +608,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         Test CUDA implementation for FacePointDistanceForward
         & FacePointDistanceBackward
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
         # make points packed a leaf node
         points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@@ -690,9 +690,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         """
         Test point_mesh_face_distance from pytorch3d.loss
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         N, V, F, P = 4, 32, 16, 24
-        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)
         # clone and detach for another backward pass through the op
         verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@@ -751,7 +751,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
     @staticmethod
     def point_mesh_edge(N: int, V: int, F: int, P: int, device: str):
         device = torch.device(device)
-        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
+            N, V, F, P, device=device
+        )
         torch.cuda.synchronize()

         def loss():
@@ -763,7 +765,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
     @staticmethod
     def point_mesh_face(N: int, V: int, F: int, P: int, device: str):
         device = torch.device(device)
-        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
+        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
+            N, V, F, P, device=device
+        )
         torch.cuda.synchronize()

         def loss():

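The tests above now pass device=device into init_meshes_clouds. That helper's body is outside this diff, so the following is only a hedged, hypothetical stand-in illustrating the intent: the random meshes and pointclouds are built directly on the requested device rather than on the default GPU.

    import torch
    from pytorch3d.structures import Meshes, Pointclouds

    def init_meshes_clouds_sketch(N: int, V: int, F: int, P: int, device="cpu"):
        # Hypothetical helper: every tensor is allocated on `device`, so no
        # implicit transfers to the default GPU happen inside the tests.
        device = torch.device(device)
        verts = [torch.rand((V, 3), device=device) for _ in range(N)]
        faces = [torch.randint(V, (F, 3), device=device) for _ in range(N)]
        points = [torch.rand((P, 3), device=device) for _ in range(N)]
        return Meshes(verts=verts, faces=faces), Pointclouds(points=points)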
View File

@@ -4,7 +4,7 @@ import functools
 import unittest

 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d import _C
 from pytorch3d.renderer.mesh.rasterize_meshes import (
     rasterize_meshes,
@@ -32,7 +32,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

     def test_simple_cuda_naive(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
         self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
         self._test_behind_camera(rasterize_meshes, device, bin_size=0)
@@ -40,7 +40,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

     def test_simple_cuda_binned(self):
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         self._simple_triangle_raster(rasterize_meshes, device, bin_size=5)
         self._simple_blurry_raster(rasterize_meshes, device, bin_size=5)
         self._test_behind_camera(rasterize_meshes, device, bin_size=5)
@@ -54,7 +54,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         blur_radius = 0.1 ** 2
         faces_per_pixel = 3
-        for d in ["cpu", "cuda"]:
+        for d in ["cpu", get_random_cuda_device()]:
             device = torch.device(d)
             compare_grads = True
             # Mesh with a single face.
@@ -164,7 +164,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         verts1.requires_grad = True
         meshes_cpu = Meshes(verts=[verts1], faces=[faces1])
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         meshes_cuda = ico_sphere(0, device)
         verts2, faces2 = meshes_cuda.get_mesh_verts_faces(0)
         verts2.requires_grad = True
@@ -186,7 +186,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         return self._test_coarse_rasterize(torch.device("cpu"))

     def test_coarse_cuda(self):
-        return self._test_coarse_rasterize(torch.device("cuda:0"))
+        return self._test_coarse_rasterize(get_random_cuda_device())

     def test_cpp_vs_cuda_naive_vs_cuda_binned(self):
         # Make sure that the backward pass runs for all pathways
@@ -221,7 +221,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         grad1 = verts.grad.data.cpu().clone()
         # Option II: CUDA, naive
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         meshes = ico_sphere(0, device)
         verts, faces = meshes.get_mesh_verts_faces(0)
         verts.requires_grad = True
@@ -229,9 +229,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         args = (meshes, image_size, radius, faces_per_pixel, 0, 0)
         idx2, zbuf2, bary2, dist2 = rasterize_meshes(*args)
-        grad_zbuf = grad_zbuf.cuda()
-        grad_dist = grad_dist.cuda()
-        grad_bary = grad_bary.cuda()
+        grad_zbuf = grad_zbuf.to(device)
+        grad_dist = grad_dist.to(device)
+        grad_bary = grad_bary.to(device)
         loss = (
             (zbuf2 * grad_zbuf).sum()
             + (dist2 * grad_dist).sum()
@@ -244,7 +244,6 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         grad2 = verts.grad.data.cpu().clone()
         # Option III: CUDA, binned
-        device = torch.device("cuda:0")
         meshes = ico_sphere(0, device)
         verts, faces = meshes.get_mesh_verts_faces(0)
         verts.requires_grad = True
@@ -302,7 +301,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
             bin_size,
             max_faces_per_bin,
         )
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         meshes = meshes.clone().to(device)
         faces = meshes.faces_packed()
@@ -356,8 +355,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         verts1, faces1 = meshes.get_mesh_verts_faces(0)
         verts1.requires_grad = True
         meshes1 = Meshes(verts=[verts1], faces=[faces1])
-        verts2 = verts1.detach().cuda().requires_grad_(True)
-        faces2 = faces1.detach().clone().cuda()
+        device = get_random_cuda_device()
+        verts2 = verts1.detach().to(device).requires_grad_(True)
+        faces2 = faces1.detach().clone().to(device)
         meshes2 = Meshes(verts=[verts2], faces=[faces2])
         kwargs = {"image_size": 64, "perspective_correct": True}
@@ -367,7 +367,8 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)

     def test_cuda_naive_vs_binned_perspective_correct(self):
-        meshes = ico_sphere(2, device=torch.device("cuda"))
+        device = get_random_cuda_device()
+        meshes = ico_sphere(2, device=device)
         verts1, faces1 = meshes.get_mesh_verts_faces(0)
         verts1.requires_grad = True
         meshes1 = Meshes(verts=[verts1], faces=[faces1])
@@ -1029,7 +1030,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         max_faces_per_bin: int,
     ):
-        meshes = ico_sphere(ico_level, torch.device("cuda:0"))
+        meshes = ico_sphere(ico_level, get_random_cuda_device())
         meshes_batch = meshes.extend(num_meshes)
         torch.cuda.synchronize()

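Note the switch from .cuda() to .to(device) for the saved gradients: .cuda() with no argument always targets the current default CUDA device, while .to(device) follows whichever GPU the test selected. A minimal illustration; the device value here is a placeholder, not the test's actual choice.

    import torch

    device = torch.device("cuda:0")    # stand-in for the randomly chosen GPU
    grad_zbuf = torch.rand(10)
    on_default = grad_zbuf.cuda()      # current default device, e.g. cuda:0
    on_chosen = grad_zbuf.to(device)   # exactly the device the test selected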
View File

@@ -5,7 +5,7 @@ import unittest
 import numpy as np
 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d import _C
 from pytorch3d.renderer.points.rasterize_points import (
     rasterize_points,
@@ -25,7 +25,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
         self._simple_test_case(rasterize_points, device)

     def test_naive_simple_cuda(self):
-        device = torch.device("cuda")
+        device = get_random_cuda_device()
         self._simple_test_case(rasterize_points, device, bin_size=0)

     def test_python_behind_camera(self):
@@ -37,7 +37,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
         self._test_behind_camera(rasterize_points, torch.device("cpu"))

     def test_cuda_behind_camera(self):
-        self._test_behind_camera(rasterize_points, torch.device("cuda"), bin_size=0)
+        device = get_random_cuda_device()
+        self._test_behind_camera(rasterize_points, device, bin_size=0)

     def test_cpp_vs_naive_vs_binned(self):
         # Make sure that the backward pass runs for all pathways
@@ -373,7 +374,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
         return self._test_coarse_rasterize(torch.device("cpu"))

     def test_coarse_cuda(self):
-        return self._test_coarse_rasterize(torch.device("cuda"))
+        device = get_random_cuda_device()
+        return self._test_coarse_rasterize(device)

     def test_compare_coarse_cpu_vs_cuda(self):
         torch.manual_seed(231)
@@ -405,7 +407,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
         )
         bp_cpu = _C._rasterize_points_coarse(*args)
-        pointclouds_cuda = pointclouds.to("cuda:0")
+        device = get_random_cuda_device()
+        pointclouds_cuda = pointclouds.to(device)
         points_packed = pointclouds_cuda.points_packed()
         cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
         num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()

View File

@@ -5,7 +5,7 @@ import unittest
 from pathlib import Path

 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d.ops import sample_points_from_meshes
 from pytorch3d.structures.meshes import Meshes
 from pytorch3d.utils.ico_sphere import ico_sphere
@@ -42,7 +42,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
         Check sample_points_from_meshes raises an exception if all meshes are
         invalid.
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         verts1 = torch.tensor([], dtype=torch.float32, device=device)
         faces1 = torch.tensor([], dtype=torch.int64, device=device)
         meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1])
@@ -56,7 +56,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
         For an ico_sphere, the sampled vertices should lie on a unit sphere.
         For an empty mesh, the samples and normals should be 0.
         """
-        device = torch.device("cuda:0")
+        device = get_random_cuda_device()
         # Unit simplex.
         verts_pyramid = torch.tensor(