diff --git a/pytorch3d/csrc/compositing/alpha_composite.cu b/pytorch3d/csrc/compositing/alpha_composite.cu index 601a1d2b..27c5d7e2 100644 --- a/pytorch3d/csrc/compositing/alpha_composite.cu +++ b/pytorch3d/csrc/compositing/alpha_composite.cu @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -136,6 +138,17 @@ at::Tensor alphaCompositeCudaForward( const at::Tensor& features, const at::Tensor& alphas, const at::Tensor& points_idx) { + // Check inputs are on the same device + at::TensorArg features_t{features, "features", 1}, + alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3}; + at::CheckedFrom c = "alphaCompositeCudaForward"; + at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t}); + at::checkAllSameType(c, {features_t, alphas_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(features.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t batch_size = points_idx.size(0); const int64_t C = features.size(0); const int64_t H = points_idx.size(2); @@ -143,19 +156,24 @@ at::Tensor alphaCompositeCudaForward( auto result = at::zeros({batch_size, C, H, W}, features.options()); + if (result.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return result; + } + const dim3 threadsPerBlock(64); const dim3 numBlocks(batch_size, 1024 / batch_size + 1); // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. - alphaCompositeCudaForwardKernel<<>>( + alphaCompositeCudaForwardKernel<<>>( // clang-format off result.packed_accessor64(), features.packed_accessor64(), alphas.packed_accessor64(), points_idx.packed_accessor64()); // clang-format on - + AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -164,9 +182,26 @@ std::tuple alphaCompositeCudaBackward( const at::Tensor& features, const at::Tensor& alphas, const at::Tensor& points_idx) { + // Check inputs are on the same device + at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1}, + features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3}, + points_idx_t{points_idx, "points_idx", 4}; + at::CheckedFrom c = "alphaCompositeCudaBackward"; + at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t}); + at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(features.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto grad_features = at::zeros_like(features); auto grad_alphas = at::zeros_like(alphas); + if (grad_features.numel() == 0 || grad_alphas.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_features, grad_alphas); + } + const int64_t bs = alphas.size(0); const dim3 threadsPerBlock(64); @@ -174,7 +209,7 @@ std::tuple alphaCompositeCudaBackward( // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. 
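Note on the hunks above and below: the angle-bracket contents of the added #include lines, the kernel launch configurations, and the accessor calls were lost when this patch was pasted, so "<<>>" stands in for a full launch configuration and "packed_accessor64()" for a templated accessor. The structure each forward/backward entry point gains in this patch is: same-GPU and same-dtype checks via at::TensorArg, a CUDAGuard pinned to the input's device, the current CUDA stream, an early return when the output is empty, a launch on that stream, and AT_CUDA_CHECK(cudaGetLastError()) afterwards. A minimal self-contained sketch of that shape follows; the header names, the toy kernel, and the accessor template parameters are illustrative assumptions, not the original code.

    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAContext.h>  // getCurrentCUDAStream (assumed header)
    #include <c10/cuda/CUDAGuard.h>     // CUDAGuard (assumed header)

    // Toy kernel standing in for alphaCompositeCudaForwardKernel.
    __global__ void ExampleForwardKernel(
        at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> result,
        const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features) {
      const int64_t n = blockIdx.x;
      for (int64_t c = threadIdx.x; c < result.size(1); c += blockDim.x) {
        result[n][c] = features[n][c];
      }
    }

    at::Tensor ExampleForwardCuda(
        const at::Tensor& features,
        const at::Tensor& alphas) {
      // Check inputs are on the same device and share a floating dtype.
      at::TensorArg features_t{features, "features", 1},
          alphas_t{alphas, "alphas", 2};
      at::CheckedFrom c = "ExampleForwardCuda";
      at::checkAllSameGPU(c, {features_t, alphas_t});
      at::checkAllSameType(c, {features_t, alphas_t});

      // Pin the device and pick up the current stream for the launch.
      at::cuda::CUDAGuard device_guard(features.device());
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      auto result = at::zeros_like(features);
      if (result.numel() == 0) {
        AT_CUDA_CHECK(cudaGetLastError());
        return result;  // nothing to launch for empty inputs
      }

      const dim3 threadsPerBlock(64);
      const dim3 numBlocks(features.size(0));  // the real kernels use a 2D grid
      // Launch on the current stream (the stripped launches above presumably
      // read <<<numBlocks, threadsPerBlock, 0, stream>>>), then surface errors.
      ExampleForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
          result.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
          features.packed_accessor64<float, 2, at::RestrictPtrTraits>());
      AT_CUDA_CHECK(cudaGetLastError());
      return result;
    }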
- alphaCompositeCudaBackwardKernel<<>>( + alphaCompositeCudaBackwardKernel<<>>( // clang-format off grad_features.packed_accessor64(), grad_alphas.packed_accessor64(), @@ -183,6 +218,6 @@ std::tuple alphaCompositeCudaBackward( alphas.packed_accessor64(), points_idx.packed_accessor64()); // clang-format on - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_features, grad_alphas); } diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index e78e8e47..d3d094ff 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -151,6 +153,17 @@ at::Tensor weightedSumNormCudaForward( const at::Tensor& features, const at::Tensor& alphas, const at::Tensor& points_idx) { + // Check inputs are on the same device + at::TensorArg features_t{features, "features", 1}, + alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3}; + at::CheckedFrom c = "weightedSumNormCudaForward"; + at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t}); + at::checkAllSameType(c, {features_t, alphas_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(features.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t batch_size = points_idx.size(0); const int64_t C = features.size(0); const int64_t H = points_idx.size(2); @@ -158,19 +171,25 @@ at::Tensor weightedSumNormCudaForward( auto result = at::zeros({batch_size, C, H, W}, features.options()); + if (result.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return result; + } + const dim3 threadsPerBlock(64); const dim3 numBlocks(batch_size, 1024 / batch_size + 1); // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. // clang-format off - weightedSumNormCudaForwardKernel<<>>( + weightedSumNormCudaForwardKernel<<>>( result.packed_accessor64(), features.packed_accessor64(), alphas.packed_accessor64(), points_idx.packed_accessor64()); // clang-format on + AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -179,9 +198,26 @@ std::tuple weightedSumNormCudaBackward( const at::Tensor& features, const at::Tensor& alphas, const at::Tensor& points_idx) { + // Check inputs are on the same device + at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1}, + features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3}, + points_idx_t{points_idx, "points_idx", 4}; + at::CheckedFrom c = "weightedSumNormCudaBackward"; + at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t}); + at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(features.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto grad_features = at::zeros_like(features); auto grad_alphas = at::zeros_like(alphas); + if (grad_features.numel() == 0 || grad_alphas.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_features, grad_alphas); + } + const int64_t bs = points_idx.size(0); const dim3 threadsPerBlock(64); @@ -189,7 +225,7 @@ std::tuple weightedSumNormCudaBackward( // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. 
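The TODO repeated above ("add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports doubles") refers to the fact that these backward kernels accumulate gradients with atomicAdd on float buffers, and a native double-precision atomicAdd only exists on devices of compute capability 6.0 and newer. Supporting doubles on older hardware would need the standard compare-and-swap emulation, sketched below; this helper is not part of the patch, only context for the comment.

    // Software emulation of atomicAdd for doubles (as given in the CUDA
    // programming guide); needed on architectures older than sm_60, which
    // have no native double-precision atomicAdd instruction.
    __device__ double atomicAddDouble(double* address, double val) {
      unsigned long long int* address_as_ull =
          reinterpret_cast<unsigned long long int*>(address);
      unsigned long long int old = *address_as_ull, assumed;
      do {
        assumed = old;
        old = atomicCAS(
            address_as_ull,
            assumed,
            __double_as_longlong(val + __longlong_as_double(assumed)));
      } while (assumed != old);  // retry if another thread wrote in between
      return __longlong_as_double(old);
    }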
- weightedSumNormCudaBackwardKernel<<>>( + weightedSumNormCudaBackwardKernel<<>>( // clang-format off grad_features.packed_accessor64(), grad_alphas.packed_accessor64(), @@ -198,6 +234,6 @@ std::tuple weightedSumNormCudaBackward( alphas.packed_accessor64(), points_idx.packed_accessor64()); // clang-format on - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_features, grad_alphas); } diff --git a/pytorch3d/csrc/compositing/weighted_sum.cu b/pytorch3d/csrc/compositing/weighted_sum.cu index fd551192..862aea0a 100644 --- a/pytorch3d/csrc/compositing/weighted_sum.cu +++ b/pytorch3d/csrc/compositing/weighted_sum.cu @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -110,6 +112,17 @@ at::Tensor weightedSumCudaForward( const at::Tensor& features, const at::Tensor& alphas, const at::Tensor& points_idx) { + // Check inputs are on the same device + at::TensorArg features_t{features, "features", 1}, + alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3}; + at::CheckedFrom c = "weightedSumCudaForward"; + at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t}); + at::checkAllSameType(c, {features_t, alphas_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(features.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t batch_size = points_idx.size(0); const int64_t C = features.size(0); const int64_t H = points_idx.size(2); @@ -117,19 +130,24 @@ at::Tensor weightedSumCudaForward( auto result = at::zeros({batch_size, C, H, W}, features.options()); + if (result.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return result; + } + const dim3 threadsPerBlock(64); const dim3 numBlocks(batch_size, 1024 / batch_size + 1); // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. - weightedSumCudaForwardKernel<<>>( + weightedSumCudaForwardKernel<<>>( // clang-format off result.packed_accessor64(), features.packed_accessor64(), alphas.packed_accessor64(), points_idx.packed_accessor64()); // clang-format on - + AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -138,9 +156,26 @@ std::tuple weightedSumCudaBackward( const at::Tensor& features, const at::Tensor& alphas, const at::Tensor& points_idx) { + // Check inputs are on the same device + at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1}, + features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3}, + points_idx_t{points_idx, "points_idx", 4}; + at::CheckedFrom c = "weightedSumCudaBackward"; + at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t}); + at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(features.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto grad_features = at::zeros_like(features); auto grad_alphas = at::zeros_like(alphas); + if (grad_features.numel() == 0 || grad_alphas.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_features, grad_alphas); + } + const int64_t bs = points_idx.size(0); const dim3 threadsPerBlock(64); @@ -148,7 +183,7 @@ std::tuple weightedSumCudaBackward( // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. 
- weightedSumCudaBackwardKernel<<>>( + weightedSumCudaBackwardKernel<<>>( // clang-format off grad_features.packed_accessor64(), grad_alphas.packed_accessor64(), @@ -157,6 +192,6 @@ std::tuple weightedSumCudaBackward( alphas.packed_accessor64(), points_idx.packed_accessor64()); // clang-format on - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_features, grad_alphas); } diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 21d32fb7..4dc9454a 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -23,7 +23,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #endif m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_backward", &KNearestNeighborBackward); - m.def("gather_scatter", &gather_scatter); + m.def("gather_scatter", &GatherScatter); m.def("rasterize_points", &RasterizePoints); m.def("rasterize_points_backward", &RasterizePointsBackward); m.def("rasterize_meshes_backward", &RasterizeMeshesBackward); diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu index c500eb12..e1ee2261 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. #include +#include +#include #include template @@ -213,14 +215,30 @@ std::tuple FaceAreasNormalsForwardCuda( const auto V = verts.size(0); const auto F = faces.size(0); + // Check inputs are on the same device + at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2}; + at::CheckedFrom c = "FaceAreasNormalsForwardCuda"; + at::checkAllSameGPU(c, {verts_t, faces_t}); + at::checkAllSameType(c, {verts_t, faces_t}); + + // Set the device for the kernel launch based on the device of verts + at::cuda::CUDAGuard device_guard(verts.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + at::Tensor areas = at::empty({F}, verts.options()); at::Tensor normals = at::empty({F, 3}, verts.options()); + if (areas.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(areas, normals); + } + const int blocks = 64; const int threads = 512; + AT_DISPATCH_FLOATING_TYPES( verts.scalar_type(), "face_areas_normals_forward_cuda", ([&] { - FaceAreasNormalsForwardKernel<<>>( + FaceAreasNormalsForwardKernel<<>>( verts.data_ptr(), faces.data_ptr(), areas.data_ptr(), @@ -228,7 +246,7 @@ std::tuple FaceAreasNormalsForwardCuda( V, F); })); - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(areas, normals); } @@ -237,16 +255,33 @@ at::Tensor FaceAreasNormalsBackwardCuda( const at::Tensor grad_normals, const at::Tensor verts, const at::Tensor faces) { + // Check inputs are on the same device + at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2}, + grad_areas_t{verts, "grad_areas", 3}, + grad_normals_t{verts, "grad_normals", 4}; + at::CheckedFrom c = "FaceAreasNormalsBackwardCuda"; + at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t}); + at::checkAllSameType(c, {verts_t, faces_t, grad_areas_t, grad_normals_t}); + + // Set the device for the kernel launch based on the device of verts + at::cuda::CUDAGuard device_guard(verts.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const auto V = verts.size(0); const auto F = faces.size(0); at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options()); + if (grad_verts.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_verts; + } + 
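For the kernels that already support several dtypes (the face areas/normals forward pass above, and the packed/padded conversions later in this patch), the dispatch-plus-stream pattern looks roughly like the sketch below. The kernel and its arguments are placeholders; the point is that the launch inside AT_DISPATCH_FLOATING_TYPES now names the stream explicitly and is followed by AT_CUDA_CHECK(cudaGetLastError()).

    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>

    // Placeholder kernel: doubles every element of a contiguous buffer.
    template <typename scalar_t>
    __global__ void ScaleKernel(const scalar_t* in, scalar_t* out, int64_t n) {
      const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = in[i] * scalar_t(2);
      }
    }

    at::Tensor ScaleCuda(const at::Tensor& input) {  // assumes contiguous input
      at::cuda::CUDAGuard device_guard(input.device());
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      auto output = at::empty_like(input);
      if (output.numel() == 0) {
        AT_CUDA_CHECK(cudaGetLastError());
        return output;
      }

      const int threads = 512;
      const int blocks = (input.numel() + threads - 1) / threads;
      // One instantiation per floating dtype; each launch goes to `stream`.
      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scale_cuda", ([&] {
        ScaleKernel<scalar_t><<<blocks, threads, 0, stream>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            input.numel());
      }));
      AT_CUDA_CHECK(cudaGetLastError());
      return output;
    }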
const int blocks = 64; const int threads = 512; // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports // doubles. Currently, support is for floats only. - FaceAreasNormalsBackwardKernel<<>>( + FaceAreasNormalsBackwardKernel<<>>( grad_areas.data_ptr(), grad_normals.data_ptr(), verts.data_ptr(), @@ -255,5 +290,6 @@ at::Tensor FaceAreasNormalsBackwardCuda( V, F); + AT_CUDA_CHECK(cudaGetLastError()); return grad_verts; } diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h index ad5d5065..3188ad3b 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h @@ -3,6 +3,7 @@ #pragma once #include #include +#include "utils/pytorch3d_cutils.h" // Compute areas of mesh faces using packed representation. // @@ -46,6 +47,8 @@ std::tuple FaceAreasNormalsForward( const at::Tensor faces) { if (verts.is_cuda() && faces.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(verts); + CHECK_CONTIGUOUS_CUDA(faces); return FaceAreasNormalsForwardCuda(verts, faces); #else AT_ERROR("Not compiled with GPU support."); @@ -62,6 +65,10 @@ at::Tensor FaceAreasNormalsBackward( const at::Tensor faces) { if (verts.is_cuda() && faces.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(verts); + CHECK_CONTIGUOUS_CUDA(faces); + CHECK_CONTIGUOUS_CUDA(grad_areas); + CHECK_CONTIGUOUS_CUDA(grad_normals); return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces); #else AT_ERROR("Not compiled with GPU support."); diff --git a/pytorch3d/csrc/gather_scatter/gather_scatter.cu b/pytorch3d/csrc/gather_scatter/gather_scatter.cu index 826a6ab7..4740a00e 100644 --- a/pytorch3d/csrc/gather_scatter/gather_scatter.cu +++ b/pytorch3d/csrc/gather_scatter/gather_scatter.cu @@ -1,9 +1,11 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. #include +#include +#include // TODO(T47953967) to make this cuda kernel support all datatypes. -__global__ void gather_scatter_kernel( +__global__ void GatherScatterCudaKernel( const float* __restrict__ input, const int64_t* __restrict__ edges, float* __restrict__ output, @@ -41,11 +43,20 @@ __global__ void gather_scatter_kernel( } } -at::Tensor gather_scatter_cuda( +at::Tensor GatherScatterCuda( const at::Tensor input, const at::Tensor edges, bool directed, bool backward) { + // Check inputs are on the same device + at::TensorArg input_t{input, "input", 1}, edges_t{edges, "edges", 2}; + at::CheckedFrom c = "GatherScatterCuda"; + at::checkAllSameGPU(c, {input_t, edges_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const auto num_vertices = input.size(0); const auto input_feature_dim = input.size(1); const auto num_edges = edges.size(0); @@ -55,7 +66,12 @@ at::Tensor gather_scatter_cuda( const size_t max_blocks = 1920; const size_t blocks = num_edges < max_blocks ? 
num_edges : max_blocks; - gather_scatter_kernel<<>>( + if (output.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return output; + } + + GatherScatterCudaKernel<<>>( input.data_ptr(), edges.data_ptr(), output.data_ptr(), @@ -64,6 +80,6 @@ at::Tensor gather_scatter_cuda( num_vertices, input_feature_dim, num_edges); - + AT_CUDA_CHECK(cudaGetLastError()); return output; } diff --git a/pytorch3d/csrc/gather_scatter/gather_scatter.h b/pytorch3d/csrc/gather_scatter/gather_scatter.h index e5199c71..53f3d1ac 100644 --- a/pytorch3d/csrc/gather_scatter/gather_scatter.h +++ b/pytorch3d/csrc/gather_scatter/gather_scatter.h @@ -2,6 +2,7 @@ #pragma once #include +#include "utils/pytorch3d_cutils.h" // Fused gather scatter operation for aggregating features of neighbor nodes // in a graph. This gather scatter operation is specific to graphs as edge @@ -20,21 +21,23 @@ // output: float32 Tensor of same shape as input. // Cuda implementation. -at::Tensor gather_scatter_cuda( +at::Tensor GatherScatterCuda( const at::Tensor input, const at::Tensor edges, bool directed, bool backward); // Exposed implementation. -at::Tensor gather_scatter( +at::Tensor GatherScatter( const at::Tensor input, const at::Tensor edges, bool directed, bool backward) { if (input.is_cuda() && edges.is_cuda()) { #ifdef WITH_CUDA - return gather_scatter_cuda(input, edges, directed, backward); + CHECK_CONTIGUOUS_CUDA(input); + CHECK_CONTIGUOUS_CUDA(edges); + return GatherScatterCuda(input, edges, directed, backward); #else AT_ERROR("Not compiled with GPU support."); #endif diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu index b045788c..9725e7b5 100644 --- a/pytorch3d/csrc/knn/knn.cu +++ b/pytorch3d/csrc/knn/knn.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
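The header changes in this patch (gather_scatter.h above, and the face_areas_normals, packed_to_padded, and point_mesh headers elsewhere) pull in "utils/pytorch3d_cutils.h" and call CHECK_CONTIGUOUS_CUDA on every tensor before dispatching to the CUDA implementation, so a non-contiguous or CPU tensor raises a readable error instead of reaching data_ptr(). That utility header is not part of this diff; the macros are conventionally defined along the lines of the sketch below, and the CPU fallback shown here is a placeholder rather than the library's actual CPU path.

    #include <ATen/ATen.h>

    // Conventional definitions assumed for the macros used in the headers.
    #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
    #define CHECK_CONTIGUOUS(x) \
      TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
    #define CHECK_CONTIGUOUS_CUDA(x) \
      CHECK_CUDA(x);                 \
      CHECK_CONTIGUOUS(x)

    // Declared in gather_scatter.h; implemented in gather_scatter.cu above.
    at::Tensor GatherScatterCuda(
        const at::Tensor input,
        const at::Tensor edges,
        bool directed,
        bool backward);

    // Host-side entry point: validate, then forward to the CUDA implementation.
    at::Tensor GatherScatter(
        const at::Tensor input,
        const at::Tensor edges,
        bool directed,
        bool backward) {
      if (input.is_cuda() && edges.is_cuda()) {
    #ifdef WITH_CUDA
        CHECK_CONTIGUOUS_CUDA(input);
        CHECK_CONTIGUOUS_CUDA(edges);
        return GatherScatterCuda(input, edges, directed, backward);
    #else
        AT_ERROR("Not compiled with GPU support.");
    #endif
      }
      AT_ERROR("CPU path omitted from this sketch.");
    }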
#include +#include +#include #include #include #include @@ -114,7 +116,8 @@ struct KNearestNeighborV1Functor { const size_t P1, const size_t P2, const size_t K) { - KNearestNeighborKernelV1<<>>( + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV1<<>>( points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K); } }; @@ -178,7 +181,8 @@ struct KNearestNeighborKernelV2Functor { const int64_t N, const int64_t P1, const int64_t P2) { - KNearestNeighborKernelV2<<>>( + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV2<<>>( points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2); } }; @@ -245,7 +249,8 @@ struct KNearestNeighborKernelV3Functor { const size_t N, const size_t P1, const size_t P2) { - KNearestNeighborKernelV3<<>>( + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV3<<>>( points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2); } }; @@ -296,17 +301,33 @@ std::tuple KNearestNeighborIdxCuda( const at::Tensor& lengths2, int K, int version) { + // Check inputs are on the same device + at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, + lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}; + at::CheckedFrom c = "KNearestNeighborIdxCuda"; + at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t}); + at::checkAllSameType(c, {p1_t, p2_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(p1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const auto N = p1.size(0); const auto P1 = p1.size(1); const auto P2 = p2.size(1); const auto D = p2.size(2); const int64_t K_64 = K; - AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension"); + TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension"); auto long_dtype = p1.options().dtype(at::kLong); auto idxs = at::zeros({N, P1, K}, long_dtype); auto dists = at::zeros({N, P1, K}, p1.options()); + if (idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(idxs, dists); + } + if (version < 0) { version = ChooseVersion(D, K); } else if (!KnnCheckVersion(version, D, K)) { @@ -328,7 +349,7 @@ std::tuple KNearestNeighborIdxCuda( if (version == 0) { AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { KNearestNeighborKernelV0 - <<>>( + <<>>( p1.data_ptr(), p2.data_ptr(), lengths1.data_ptr(), @@ -409,7 +430,7 @@ std::tuple KNearestNeighborIdxCuda( P2); })); } - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(idxs, dists); } @@ -465,27 +486,45 @@ std::tuple KNearestNeighborBackwardCuda( const at::Tensor& lengths2, const at::Tensor& idxs, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, + lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}, + idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6}; + at::CheckedFrom c = "KNearestNeighborBackwardCuda"; + at::checkAllSameGPU( + c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t}); + at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(p1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const auto N = p1.size(0); const auto P1 = p1.size(1); const auto P2 = p2.size(1); const auto D = p2.size(2); const auto K = idxs.size(2); - AT_ASSERTM(p2.size(2) == D, 
"Point sets must have the same last dimension"); - AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension"); - AT_ASSERTM( + TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension"); + TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension"); + TORCH_CHECK( idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1"); - AT_ASSERTM(grad_dists.size(0) == N); - AT_ASSERTM(grad_dists.size(1) == P1); - AT_ASSERTM(grad_dists.size(2) == K); + TORCH_CHECK(grad_dists.size(0) == N); + TORCH_CHECK(grad_dists.size(1) == P1); + TORCH_CHECK(grad_dists.size(2) == K); auto grad_p1 = at::zeros({N, P1, D}, p1.options()); auto grad_p2 = at::zeros({N, P2, D}, p2.options()); + if (grad_p1.numel() == 0 || grad_p2.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_p1, grad_p2); + } + const int blocks = 64; const int threads = 512; - KNearestNeighborBackwardKernel<<>>( + KNearestNeighborBackwardKernel<<>>( p1.data_ptr(), p2.data_ptr(), lengths1.data_ptr(), @@ -500,5 +539,6 @@ std::tuple KNearestNeighborBackwardCuda( K, D); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_p1, grad_p2); } diff --git a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu index c2347eb8..09e408a7 100644 --- a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu +++ b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. #include +#include +#include // Kernel for inputs_packed of shape (F, D), where D > 1 template @@ -114,21 +116,36 @@ at::Tensor PackedToPaddedCuda( const at::Tensor inputs_packed, const at::Tensor first_idxs, const int64_t max_size) { + // Check inputs are on the same device + at::TensorArg inputs_packed_t{inputs_packed, "inputs_packed", 1}, + first_idxs_t{first_idxs, "first_idxs", 2}; + at::CheckedFrom c = "PackedToPaddedCuda"; + at::checkAllSameGPU(c, {inputs_packed_t, first_idxs_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(inputs_packed.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t num_inputs = inputs_packed.size(0); const int64_t batch_size = first_idxs.size(0); - AT_ASSERTM( + TORCH_CHECK( inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor"); const int64_t D = inputs_packed.size(1); at::Tensor inputs_padded = at::zeros({batch_size, max_size, D}, inputs_packed.options()); + if (inputs_padded.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return inputs_padded; + } + const int threads = 512; const int blocks = batch_size; if (D == 1) { AT_DISPATCH_FLOATING_TYPES( inputs_packed.scalar_type(), "packed_to_padded_d1_kernel", ([&] { - PackedToPaddedKernelD1<<>>( + PackedToPaddedKernelD1<<>>( inputs_packed.data_ptr(), first_idxs.data_ptr(), inputs_padded.data_ptr(), @@ -139,7 +156,7 @@ at::Tensor PackedToPaddedCuda( } else { AT_DISPATCH_FLOATING_TYPES( inputs_packed.scalar_type(), "packed_to_padded_kernel", ([&] { - PackedToPaddedKernel<<>>( + PackedToPaddedKernel<<>>( inputs_packed.data_ptr(), first_idxs.data_ptr(), inputs_padded.data_ptr(), @@ -150,6 +167,7 @@ at::Tensor PackedToPaddedCuda( })); } + AT_CUDA_CHECK(cudaGetLastError()); return inputs_padded; } @@ -157,11 +175,21 @@ at::Tensor PaddedToPackedCuda( const at::Tensor inputs_padded, const at::Tensor first_idxs, 
const int64_t num_inputs) { + // Check inputs are on the same device + at::TensorArg inputs_padded_t{inputs_padded, "inputs_padded", 1}, + first_idxs_t{first_idxs, "first_idxs", 2}; + at::CheckedFrom c = "PaddedToPackedCuda"; + at::checkAllSameGPU(c, {inputs_padded_t, first_idxs_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(inputs_padded.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t batch_size = inputs_padded.size(0); const int64_t max_size = inputs_padded.size(1); - AT_ASSERTM(batch_size == first_idxs.size(0), "sizes mismatch"); - AT_ASSERTM( + TORCH_CHECK(batch_size == first_idxs.size(0), "sizes mismatch"); + TORCH_CHECK( inputs_padded.dim() == 3, "inputs_padded must be a 3-dimensional tensor"); const int64_t D = inputs_padded.size(2); @@ -169,13 +197,18 @@ at::Tensor PaddedToPackedCuda( at::Tensor inputs_packed = at::zeros({num_inputs, D}, inputs_padded.options()); + if (inputs_packed.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return inputs_packed; + } + const int threads = 512; const int blocks = batch_size; if (D == 1) { AT_DISPATCH_FLOATING_TYPES( inputs_padded.scalar_type(), "padded_to_packed_d1_kernel", ([&] { - PaddedToPackedKernelD1<<>>( + PaddedToPackedKernelD1<<>>( inputs_padded.data_ptr(), first_idxs.data_ptr(), inputs_packed.data_ptr(), @@ -186,7 +219,7 @@ at::Tensor PaddedToPackedCuda( } else { AT_DISPATCH_FLOATING_TYPES( inputs_padded.scalar_type(), "padded_to_packed_kernel", ([&] { - PaddedToPackedKernel<<>>( + PaddedToPackedKernel<<>>( inputs_padded.data_ptr(), first_idxs.data_ptr(), inputs_packed.data_ptr(), @@ -197,5 +230,6 @@ at::Tensor PaddedToPackedCuda( })); } + AT_CUDA_CHECK(cudaGetLastError()); return inputs_packed; } diff --git a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h index c272bb3e..234cf084 100644 --- a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h +++ b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h @@ -2,6 +2,7 @@ #pragma once #include +#include "utils/pytorch3d_cutils.h" // PackedToPadded // Converts a packed tensor into a padded tensor, restoring the batch dimension. @@ -74,6 +75,8 @@ at::Tensor PackedToPadded( const int64_t max_size) { if (inputs_packed.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(inputs_packed); + CHECK_CONTIGUOUS_CUDA(first_idxs); return PackedToPaddedCuda(inputs_packed, first_idxs, max_size); #else AT_ERROR("Not compiled with GPU support."); @@ -89,6 +92,8 @@ at::Tensor PaddedToPacked( const int64_t num_inputs) { if (inputs_padded.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(inputs_padded); + CHECK_CONTIGUOUS_CUDA(first_idxs); return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs); #else AT_ERROR("Not compiled with GPU support."); diff --git a/pytorch3d/csrc/point_mesh/point_mesh_edge.cu b/pytorch3d/csrc/point_mesh/point_mesh_edge.cu index de2acb86..5b438a10 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_edge.cu +++ b/pytorch3d/csrc/point_mesh/point_mesh_edge.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
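A side change running through knn.cu, packed_to_padded_tensor.cu, and the point_mesh files is the switch from AT_ASSERTM to TORCH_CHECK for input validation. TORCH_CHECK is the macro PyTorch recommends for argument checking: it throws a c10::Error (surfaced as a RuntimeError in Python) with the given message, and it also accepts a bare condition, which matters for the call sites above that pass no message. For example:

    TORCH_CHECK(
        inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
    TORCH_CHECK(batch_size == first_idxs.size(0), "sizes mismatch");
    TORCH_CHECK(grad_dists.size(0) == N);  // no message: a default one is generated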
#include +#include +#include #include #include #include @@ -103,26 +105,45 @@ std::tuple PointEdgeDistanceForwardCuda( const at::Tensor& segms, const at::Tensor& segms_first_idx, const int64_t max_points) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + points_first_idx_t{points_first_idx, "points_first_idx", 2}, + segms_t{segms, "segms", 3}, + segms_first_idx_t{segms_first_idx, "segms_first_idx", 4}; + at::CheckedFrom c = "PointEdgeDistanceForwardCuda"; + at::checkAllSameGPU( + c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t}); + at::checkAllSameType(c, {points_t, segms_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t S = segms.size(0); const int64_t B = points_first_idx.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (segms.size(1) == 2) && (segms.size(2) == 3), "segms must be of shape Sx2x3"); - AT_ASSERTM(segms_first_idx.size(0) == B); + TORCH_CHECK(segms_first_idx.size(0) == B); // clang-format off at::Tensor dists = at::zeros({P,}, points.options()); at::Tensor idxs = at::zeros({P,}, points_first_idx.options()); // clang-format on + if (dists.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(dists, idxs); + } + const int threads = 128; const dim3 blocks(max_points, B); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); - PointEdgeForwardKernel<<>>( + PointEdgeForwardKernel<<>>( points.data_ptr(), points_first_idx.data_ptr(), segms.data_ptr(), @@ -132,7 +153,7 @@ std::tuple PointEdgeDistanceForwardCuda( B, P, S); - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(dists, idxs); } @@ -183,25 +204,42 @@ std::tuple PointEdgeDistanceBackwardCuda( const at::Tensor& segms, const at::Tensor& idx_points, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3}, + grad_dists_t{grad_dists, "grad_dists", 4}; + at::CheckedFrom c = "PointEdgeDistanceBackwardCuda"; + at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, segms_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t S = segms.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (segms.size(1) == 2) && (segms.size(2) == 3), "segms must be of shape Sx2x3"); - AT_ASSERTM(idx_points.size(0) == P); - AT_ASSERTM(grad_dists.size(0) == P); + TORCH_CHECK(idx_points.size(0) == P); + TORCH_CHECK(grad_dists.size(0) == P); // clang-format off at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options()); // clang-format on + if (grad_points.numel() == 0 || grad_segms.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_points, grad_segms); + } + const int blocks = 64; const int threads = 512; - PointEdgeBackwardKernel<<>>( + 
PointEdgeBackwardKernel<<>>( points.data_ptr(), segms.data_ptr(), idx_points.data_ptr(), @@ -210,6 +248,7 @@ std::tuple PointEdgeDistanceBackwardCuda( grad_segms.data_ptr(), P); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_points, grad_segms); } @@ -308,26 +347,45 @@ std::tuple EdgePointDistanceForwardCuda( const at::Tensor& segms, const at::Tensor& segms_first_idx, const int64_t max_segms) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + points_first_idx_t{points_first_idx, "points_first_idx", 2}, + segms_t{segms, "segms", 3}, + segms_first_idx_t{segms_first_idx, "segms_first_idx", 4}; + at::CheckedFrom c = "EdgePointDistanceForwardCuda"; + at::checkAllSameGPU( + c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t}); + at::checkAllSameType(c, {points_t, segms_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t S = segms.size(0); const int64_t B = points_first_idx.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (segms.size(1) == 2) && (segms.size(2) == 3), "segms must be of shape Sx2x3"); - AT_ASSERTM(segms_first_idx.size(0) == B); + TORCH_CHECK(segms_first_idx.size(0) == B); // clang-format off at::Tensor dists = at::zeros({S,}, segms.options()); at::Tensor idxs = at::zeros({S,}, segms_first_idx.options()); // clang-format on + if (dists.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(dists, idxs); + } + const int threads = 128; const dim3 blocks(max_segms, B); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); - EdgePointForwardKernel<<>>( + EdgePointForwardKernel<<>>( points.data_ptr(), points_first_idx.data_ptr(), segms.data_ptr(), @@ -337,7 +395,7 @@ std::tuple EdgePointDistanceForwardCuda( B, P, S); - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(dists, idxs); } @@ -389,15 +447,27 @@ std::tuple EdgePointDistanceBackwardCuda( const at::Tensor& segms, const at::Tensor& idx_segms, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3}, + grad_dists_t{grad_dists, "grad_dists", 4}; + at::CheckedFrom c = "PointEdgeDistanceBackwardCuda"; + at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, segms_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t S = segms.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (segms.size(1) == 2) && (segms.size(2) == 3), "segms must be of shape Sx2x3"); - AT_ASSERTM(idx_segms.size(0) == S); - AT_ASSERTM(grad_dists.size(0) == S); + TORCH_CHECK(idx_segms.size(0) == S); + TORCH_CHECK(grad_dists.size(0) == S); // clang-format off at::Tensor grad_points = at::zeros({P, 3}, points.options()); @@ -407,7 +477,7 @@ std::tuple EdgePointDistanceBackwardCuda( const int blocks = 64; const int threads = 512; - 
EdgePointBackwardKernel<<>>( + EdgePointBackwardKernel<<>>( points.data_ptr(), segms.data_ptr(), idx_segms.data_ptr(), @@ -451,26 +521,42 @@ __global__ void PointEdgeArrayForwardKernel( at::Tensor PointEdgeArrayDistanceForwardCuda( const at::Tensor& points, const at::Tensor& segms) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2}; + at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda"; + at::checkAllSameGPU(c, {points_t, segms_t}); + at::checkAllSameType(c, {points_t, segms_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t S = segms.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (segms.size(1) == 2) && (segms.size(2) == 3), "segms must be of shape Sx2x3"); at::Tensor dists = at::zeros({P, S}, points.options()); + if (dists.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return dists; + } + const size_t blocks = 1024; const size_t threads = 64; - PointEdgeArrayForwardKernel<<>>( + PointEdgeArrayForwardKernel<<>>( points.data_ptr(), segms.data_ptr(), dists.data_ptr(), P, S); + AT_CUDA_CHECK(cudaGetLastError()); return dists; } @@ -520,22 +606,38 @@ std::tuple PointEdgeArrayDistanceBackwardCuda( const at::Tensor& points, const at::Tensor& segms, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2}, + grad_dists_t{grad_dists, "grad_dists", 3}; + at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda"; + at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, segms_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t S = segms.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (segms.size(1) == 2) && (segms.size(2) == 3), "segms must be of shape Sx2x3"); - AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S)); + TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S)); at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options()); + if (grad_points.numel() == 0 || grad_segms.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_points, grad_segms); + } + const size_t blocks = 1024; const size_t threads = 64; - PointEdgeArrayBackwardKernel<<>>( + PointEdgeArrayBackwardKernel<<>>( points.data_ptr(), segms.data_ptr(), grad_dists.data_ptr(), @@ -543,6 +645,6 @@ std::tuple PointEdgeArrayDistanceBackwardCuda( grad_segms.data_ptr(), P, S); - + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_points, grad_segms); } diff --git a/pytorch3d/csrc/point_mesh/point_mesh_edge.h b/pytorch3d/csrc/point_mesh/point_mesh_edge.h index de49daf9..2f72a746 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_edge.h +++ b/pytorch3d/csrc/point_mesh/point_mesh_edge.h @@ -4,6 +4,7 @@ #include #include #include +#include "utils/pytorch3d_cutils.h" // 
**************************************************************************** // * PointEdgeDistance * @@ -53,6 +54,10 @@ std::tuple PointEdgeDistanceForward( const int64_t max_points) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(points_first_idx); + CHECK_CONTIGUOUS_CUDA(segms); + CHECK_CONTIGUOUS_CUDA(segms_first_idx); return PointEdgeDistanceForwardCuda( points, points_first_idx, segms, segms_first_idx, max_points); #else @@ -93,6 +98,10 @@ std::tuple PointEdgeDistanceBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(segms); + CHECK_CONTIGUOUS_CUDA(idx_points); + CHECK_CONTIGUOUS_CUDA(grad_dists); return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists); #else AT_ERROR("Not compiled with GPU support."); @@ -149,6 +158,10 @@ std::tuple EdgePointDistanceForward( const int64_t max_segms) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(points_first_idx); + CHECK_CONTIGUOUS_CUDA(segms); + CHECK_CONTIGUOUS_CUDA(segms_first_idx); return EdgePointDistanceForwardCuda( points, points_first_idx, segms, segms_first_idx, max_segms); #else @@ -189,6 +202,10 @@ std::tuple EdgePointDistanceBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(segms); + CHECK_CONTIGUOUS_CUDA(idx_segms); + CHECK_CONTIGUOUS_CUDA(grad_dists); return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists); #else AT_ERROR("Not compiled with GPU support."); @@ -220,7 +237,6 @@ std::tuple EdgePointDistanceBackward( // will require for the forward pass 5.8G of memory to store dists. #ifdef WITH_CUDA - torch::Tensor PointEdgeArrayDistanceForwardCuda( const torch::Tensor& points, const torch::Tensor& segms); @@ -231,6 +247,8 @@ torch::Tensor PointEdgeArrayDistanceForward( const torch::Tensor& segms) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(segms); return PointEdgeArrayDistanceForwardCuda(points, segms); #else AT_ERROR("Not compiled with GPU support."); @@ -265,6 +283,9 @@ std::tuple PointEdgeArrayDistanceBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(segms); + CHECK_CONTIGUOUS_CUDA(grad_dists); return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists); #else AT_ERROR("Not compiled with GPU support."); diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.cu b/pytorch3d/csrc/point_mesh/point_mesh_face.cu index d36e24ef..9b1b22e4 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_face.cu +++ b/pytorch3d/csrc/point_mesh/point_mesh_face.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
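The forward kernels in point_mesh_edge.cu above (PointEdgeForwardKernel and EdgePointForwardKernel) use dynamic shared memory, so with the stream change their launches spell out all four launch parameters: grid, block, shared-memory bytes, and stream. The bracket contents were stripped in the text above; restoring the PointEdgeForwardKernel launch gives roughly the following, where the data_ptr template types and the middle arguments (skipped by the diff context) are inferred rather than visible in the patch.

    const int threads = 128;
    const dim3 blocks(max_points, B);
    size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);

    // Grid, block, dynamic shared memory size, and stream are all explicit.
    PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
        points.data_ptr<float>(),
        points_first_idx.data_ptr<int64_t>(),
        segms.data_ptr<float>(),
        segms_first_idx.data_ptr<int64_t>(),  // these three arguments are
        dists.data_ptr<float>(),              // inferred; the diff context
        idxs.data_ptr<int64_t>(),             // skips over them
        B,
        P,
        S);
    AT_CUDA_CHECK(cudaGetLastError());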
#include +#include +#include #include #include #include @@ -104,26 +106,45 @@ std::tuple PointFaceDistanceForwardCuda( const at::Tensor& tris, const at::Tensor& tris_first_idx, const int64_t max_points) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + points_first_idx_t{points_first_idx, "points_first_idx", 2}, + tris_t{tris, "tris", 3}, + tris_first_idx_t{tris_first_idx, "tris_first_idx", 4}; + at::CheckedFrom c = "PointFaceDistanceForwardCuda"; + at::checkAllSameGPU( + c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t}); + at::checkAllSameType(c, {points_t, tris_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t T = tris.size(0); const int64_t B = points_first_idx.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (tris.size(1) == 3) && (tris.size(2) == 3), "tris must be of shape Tx3x3"); - AT_ASSERTM(tris_first_idx.size(0) == B); + TORCH_CHECK(tris_first_idx.size(0) == B); // clang-format off at::Tensor dists = at::zeros({P,}, points.options()); at::Tensor idxs = at::zeros({P,}, points_first_idx.options()); // clang-format on + if (dists.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(dists, idxs); + } + const int threads = 128; const dim3 blocks(max_points, B); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); - PointFaceForwardKernel<<>>( + PointFaceForwardKernel<<>>( points.data_ptr(), points_first_idx.data_ptr(), tris.data_ptr(), @@ -134,6 +155,7 @@ std::tuple PointFaceDistanceForwardCuda( P, T); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(dists, idxs); } @@ -191,25 +213,42 @@ std::tuple PointFaceDistanceBackwardCuda( const at::Tensor& tris, const at::Tensor& idx_points, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + idx_points_t{idx_points, "idx_points", 2}, tris_t{tris, "tris", 3}, + grad_dists_t{grad_dists, "grad_dists", 4}; + at::CheckedFrom c = "PointFaceDistanceBackwardCuda"; + at::checkAllSameGPU(c, {points_t, idx_points_t, tris_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, tris_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t T = tris.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (tris.size(1) == 3) && (tris.size(2) == 3), "tris must be of shape Tx3x3"); - AT_ASSERTM(idx_points.size(0) == P); - AT_ASSERTM(grad_dists.size(0) == P); + TORCH_CHECK(idx_points.size(0) == P); + TORCH_CHECK(grad_dists.size(0) == P); // clang-format off at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); // clang-format on + if (grad_points.numel() == 0 || grad_tris.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_points, grad_tris); + } + const int blocks = 64; const int threads = 512; - PointFaceBackwardKernel<<>>( + PointFaceBackwardKernel<<>>( 
points.data_ptr(), tris.data_ptr(), idx_points.data_ptr(), @@ -218,6 +257,7 @@ std::tuple PointFaceDistanceBackwardCuda( grad_tris.data_ptr(), P); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_points, grad_tris); } @@ -317,26 +357,45 @@ std::tuple FacePointDistanceForwardCuda( const at::Tensor& tris, const at::Tensor& tris_first_idx, const int64_t max_tris) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + points_first_idx_t{points_first_idx, "points_first_idx", 2}, + tris_t{tris, "tris", 3}, + tris_first_idx_t{tris_first_idx, "tris_first_idx", 4}; + at::CheckedFrom c = "FacePointDistanceForwardCuda"; + at::checkAllSameGPU( + c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t}); + at::checkAllSameType(c, {points_t, tris_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t T = tris.size(0); const int64_t B = points_first_idx.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (tris.size(1) == 3) && (tris.size(2) == 3), "tris must be of shape Tx3x3"); - AT_ASSERTM(tris_first_idx.size(0) == B); + TORCH_CHECK(tris_first_idx.size(0) == B); // clang-format off at::Tensor dists = at::zeros({T,}, tris.options()); at::Tensor idxs = at::zeros({T,}, tris_first_idx.options()); // clang-format on + if (dists.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(dists, idxs); + } + const int threads = 128; const dim3 blocks(max_tris, B); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); - FacePointForwardKernel<<>>( + FacePointForwardKernel<<>>( points.data_ptr(), points_first_idx.data_ptr(), tris.data_ptr(), @@ -347,6 +406,7 @@ std::tuple FacePointDistanceForwardCuda( P, T); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(dists, idxs); } @@ -405,25 +465,42 @@ std::tuple FacePointDistanceBackwardCuda( const at::Tensor& tris, const at::Tensor& idx_tris, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + idx_tris_t{idx_tris, "idx_tris", 2}, tris_t{tris, "tris", 3}, + grad_dists_t{grad_dists, "grad_dists", 4}; + at::CheckedFrom c = "FacePointDistanceBackwardCuda"; + at::checkAllSameGPU(c, {points_t, idx_tris_t, tris_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, tris_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t T = tris.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (tris.size(1) == 3) && (tris.size(2) == 3), "tris must be of shape Tx3x3"); - AT_ASSERTM(idx_tris.size(0) == T); - AT_ASSERTM(grad_dists.size(0) == T); + TORCH_CHECK(idx_tris.size(0) == T); + TORCH_CHECK(grad_dists.size(0) == T); // clang-format off at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); // clang-format on + if (grad_points.numel() == 0 || grad_tris.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return 
std::make_tuple(grad_points, grad_tris); + } + const int blocks = 64; const int threads = 512; - FacePointBackwardKernel<<>>( + FacePointBackwardKernel<<>>( points.data_ptr(), tris.data_ptr(), idx_tris.data_ptr(), @@ -432,6 +509,7 @@ std::tuple FacePointDistanceBackwardCuda( grad_tris.data_ptr(), T); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_points, grad_tris); } @@ -468,26 +546,42 @@ __global__ void PointFaceArrayForwardKernel( at::Tensor PointFaceArrayDistanceForwardCuda( const at::Tensor& points, const at::Tensor& tris) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2}; + at::CheckedFrom c = "PointFaceArrayDistanceForwardCuda"; + at::checkAllSameGPU(c, {points_t, tris_t}); + at::checkAllSameType(c, {points_t, tris_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t T = tris.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (tris.size(1) == 3) && (tris.size(2) == 3), "tris must be of shape Tx3x3"); at::Tensor dists = at::zeros({P, T}, points.options()); + if (dists.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return dists; + } + const size_t blocks = 1024; const size_t threads = 64; - PointFaceArrayForwardKernel<<>>( + PointFaceArrayForwardKernel<<>>( points.data_ptr(), tris.data_ptr(), dists.data_ptr(), P, T); + AT_CUDA_CHECK(cudaGetLastError()); return dists; } @@ -546,22 +640,38 @@ std::tuple PointFaceArrayDistanceBackwardCuda( const at::Tensor& points, const at::Tensor& tris, const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2}, + grad_dists_t{grad_dists, "grad_dists", 3}; + at::CheckedFrom c = "PointFaceArrayDistanceBackwardCuda"; + at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, tris_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int64_t P = points.size(0); const int64_t T = tris.size(0); - AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3"); - AT_ASSERTM( + TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( (tris.size(1) == 3) && (tris.size(2) == 3), "tris must be of shape Tx3x3"); - AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T)); + TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == T)); at::Tensor grad_points = at::zeros({P, 3}, points.options()); at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); + if (grad_points.numel() == 0 || grad_tris.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_points, grad_tris); + } + const size_t blocks = 1024; const size_t threads = 64; - PointFaceArrayBackwardKernel<<>>( + PointFaceArrayBackwardKernel<<>>( points.data_ptr(), tris.data_ptr(), grad_dists.data_ptr(), @@ -570,5 +680,6 @@ std::tuple PointFaceArrayDistanceBackwardCuda( P, T); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_points, grad_tris); } diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.h b/pytorch3d/csrc/point_mesh/point_mesh_face.h index 
e2093b1d..39b9b359 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_face.h +++ b/pytorch3d/csrc/point_mesh/point_mesh_face.h @@ -4,6 +4,7 @@ #include #include #include +#include "utils/pytorch3d_cutils.h" // **************************************************************************** // * PointFaceDistance * @@ -55,6 +56,10 @@ std::tuple PointFaceDistanceForward( const int64_t max_points) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(points_first_idx); + CHECK_CONTIGUOUS_CUDA(tris); + CHECK_CONTIGUOUS_CUDA(tris_first_idx); return PointFaceDistanceForwardCuda( points, points_first_idx, tris, tris_first_idx, max_points); #else @@ -95,6 +100,10 @@ std::tuple PointFaceDistanceBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(tris); + CHECK_CONTIGUOUS_CUDA(idx_points); + CHECK_CONTIGUOUS_CUDA(grad_dists); return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists); #else AT_ERROR("Not compiled with GPU support."); @@ -151,6 +160,10 @@ std::tuple FacePointDistanceForward( const int64_t max_tris) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(points_first_idx); + CHECK_CONTIGUOUS_CUDA(tris); + CHECK_CONTIGUOUS_CUDA(tris_first_idx); return FacePointDistanceForwardCuda( points, points_first_idx, tris, tris_first_idx, max_tris); #else @@ -191,6 +204,10 @@ std::tuple FacePointDistanceBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(tris); + CHECK_CONTIGUOUS_CUDA(idx_tris); + CHECK_CONTIGUOUS_CUDA(grad_dists); return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists); #else AT_ERROR("Not compiled with GPU support."); @@ -233,6 +250,8 @@ torch::Tensor PointFaceArrayDistanceForward( const torch::Tensor& tris) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(tris); return PointFaceArrayDistanceForwardCuda(points, tris); #else AT_ERROR("Not compiled with GPU support."); @@ -254,7 +273,6 @@ torch::Tensor PointFaceArrayDistanceForward( // #ifdef WITH_CUDA - std::tuple PointFaceArrayDistanceBackwardCuda( const torch::Tensor& points, const torch::Tensor& tris, @@ -267,6 +285,9 @@ std::tuple PointFaceArrayDistanceBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(tris); + CHECK_CONTIGUOUS_CUDA(grad_dists); return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists); #else AT_ERROR("Not compiled with GPU support."); diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index c977b0e7..a11285e0 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
#include +#include +#include #include #include #include @@ -285,14 +287,14 @@ RasterizeMeshesNaiveCuda( const int num_closest, const bool perspective_correct, const bool cull_backfaces) { - if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || - face_verts.size(2) != 3) { - AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); - } - if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) { - AT_ERROR( - "num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx"); - } + TORCH_CHECK( + face_verts.ndimension() == 3 && face_verts.size(1) == 3 && + face_verts.size(2) == 3, + "face_verts must have dimensions (num_faces, 3, 3)"); + + TORCH_CHECK( + num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0), + "num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx"); if (num_closest > kMaxPointsPerPixel) { std::stringstream ss; @@ -300,6 +302,20 @@ RasterizeMeshesNaiveCuda( AT_ERROR(ss.str()); } + // Check inputs are on the same device + at::TensorArg face_verts_t{face_verts, "face_verts", 1}, + mesh_to_faces_packed_first_idx_t{ + mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2}, + num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3}; + at::CheckedFrom c = "RasterizeMeshesNaiveCuda"; + at::checkAllSameGPU( + c, + {face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(face_verts.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int N = num_faces_per_mesh.size(0); // batch size. const int H = image_size; // Assume square images. const int W = image_size; @@ -313,10 +329,15 @@ RasterizeMeshesNaiveCuda( at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts); + if (face_idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(face_idxs, zbuf, bary, pix_dists); + } + const size_t blocks = 1024; const size_t threads = 64; - RasterizeMeshesNaiveCudaKernel<<>>( + RasterizeMeshesNaiveCudaKernel<<>>( face_verts.contiguous().data_ptr(), mesh_to_faces_packed_first_idx.contiguous().data_ptr(), num_faces_per_mesh.contiguous().data_ptr(), @@ -332,6 +353,7 @@ RasterizeMeshesNaiveCuda( pix_dists.contiguous().data_ptr(), bary.contiguous().data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(face_idxs, zbuf, bary, pix_dists); } @@ -465,6 +487,22 @@ at::Tensor RasterizeMeshesBackwardCuda( const at::Tensor& grad_bary, // (N, H, W, K, 3) const at::Tensor& grad_dists, // (N, H, W, K) const bool perspective_correct) { + // Check inputs are on the same device + at::TensorArg face_verts_t{face_verts, "face_verts", 1}, + pix_to_face_t{pix_to_face, "pix_to_face", 2}, + grad_zbuf_t{grad_zbuf, "grad_zbuf", 3}, + grad_bary_t{grad_bary, "grad_bary", 4}, + grad_dists_t{grad_dists, "grad_dists", 5}; + at::CheckedFrom c = "RasterizeMeshesBackwardCuda"; + at::checkAllSameGPU( + c, {face_verts_t, pix_to_face_t, grad_zbuf_t, grad_bary_t, grad_dists_t}); + at::checkAllSameType( + c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(face_verts.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int F = face_verts.size(0); const int N = pix_to_face.size(0); const int H = 
pix_to_face.size(1); @@ -472,10 +510,16 @@ at::Tensor RasterizeMeshesBackwardCuda( const int K = pix_to_face.size(3); at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options()); + + if (grad_face_verts.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_face_verts; + } + const size_t blocks = 1024; const size_t threads = 64; - RasterizeMeshesBackwardCudaKernel<<>>( + RasterizeMeshesBackwardCudaKernel<<>>( face_verts.contiguous().data_ptr(), pix_to_face.contiguous().data_ptr(), perspective_correct, @@ -488,6 +532,7 @@ at::Tensor RasterizeMeshesBackwardCuda( grad_dists.contiguous().data_ptr(), grad_face_verts.contiguous().data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); return grad_face_verts; } @@ -626,10 +671,24 @@ at::Tensor RasterizeMeshesCoarseCuda( const float blur_radius, const int bin_size, const int max_faces_per_bin) { - if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || - face_verts.size(2) != 3) { - AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); - } + TORCH_CHECK( + face_verts.ndimension() == 3 && face_verts.size(1) == 3 && + face_verts.size(2) == 3, + "face_verts must have dimensions (num_faces, 3, 3)"); + + // Check inputs are on the same device + at::TensorArg face_verts_t{face_verts, "face_verts", 1}, + mesh_to_face_first_idx_t{ + mesh_to_face_first_idx, "mesh_to_face_first_idx", 2}, + num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3}; + at::CheckedFrom c = "RasterizeMeshesCoarseCuda"; + at::checkAllSameGPU( + c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(face_verts.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int W = image_size; const int H = image_size; const int F = face_verts.size(0); @@ -645,12 +704,18 @@ at::Tensor RasterizeMeshesCoarseCuda( auto opts = face_verts.options().dtype(at::kInt); at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts); at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts); + + if (bin_faces.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return bin_faces; + } + const int chunk_size = 512; const size_t shared_size = num_bins * num_bins * chunk_size / 8; const size_t blocks = 64; const size_t threads = 512; - RasterizeMeshesCoarseCudaKernel<<>>( + RasterizeMeshesCoarseCudaKernel<<>>( face_verts.contiguous().data_ptr(), mesh_to_face_first_idx.contiguous().data_ptr(), num_faces_per_mesh.contiguous().data_ptr(), @@ -664,6 +729,8 @@ at::Tensor RasterizeMeshesCoarseCuda( M, faces_per_bin.contiguous().data_ptr(), bin_faces.contiguous().data_ptr()); + + AT_CUDA_CHECK(cudaGetLastError()); return bin_faces; } @@ -775,13 +842,22 @@ RasterizeMeshesFineCuda( const int faces_per_pixel, const bool perspective_correct, const bool cull_backfaces) { - if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || - face_verts.size(2) != 3) { - AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); - } - if (bin_faces.ndimension() != 4) { - AT_ERROR("bin_faces must have 4 dimensions"); - } + TORCH_CHECK( + face_verts.ndimension() == 3 && face_verts.size(1) == 3 && + face_verts.size(2) == 3, + "face_verts must have dimensions (num_faces, 3, 3)"); + TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions"); + + // Check inputs are on the same device + at::TensorArg face_verts_t{face_verts, "face_verts", 1}, + bin_faces_t{bin_faces, "bin_faces", 2}; + at::CheckedFrom c 
= "RasterizeMeshesFineCuda"; + at::checkAllSameGPU(c, {face_verts_t, bin_faces_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(face_verts.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int N = bin_faces.size(0); const int B = bin_faces.size(1); const int M = bin_faces.size(3); @@ -790,7 +866,7 @@ RasterizeMeshesFineCuda( const int W = image_size; if (K > kMaxPointsPerPixel) { - AT_ERROR("Must have num_closest <= 8"); + AT_ERROR("Must have num_closest <= 150"); } auto long_opts = face_verts.options().dtype(at::kLong); auto float_opts = face_verts.options().dtype(at::kFloat); @@ -800,10 +876,15 @@ RasterizeMeshesFineCuda( at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts); at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts); + if (face_idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(face_idxs, zbuf, bary, pix_dists); + } + const size_t blocks = 1024; const size_t threads = 64; - RasterizeMeshesFineCudaKernel<<>>( + RasterizeMeshesFineCudaKernel<<>>( face_verts.contiguous().data_ptr(), bin_faces.contiguous().data_ptr(), blur_radius, diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h index 210cecbf..4f8f4044 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h @@ -4,6 +4,7 @@ #include #include #include +#include "utils/pytorch3d_cutils.h" // **************************************************************************** // * FORWARD PASS * @@ -95,6 +96,9 @@ RasterizeMeshesNaive( // TODO: Better type checking. if (face_verts.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(face_verts); + CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx); + CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh); return RasterizeMeshesNaiveCuda( face_verts, mesh_to_face_first_idx, @@ -175,6 +179,11 @@ torch::Tensor RasterizeMeshesBackward( const bool perspective_correct) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(face_verts); + CHECK_CONTIGUOUS_CUDA(pix_to_face); + CHECK_CONTIGUOUS_CUDA(grad_zbuf); + CHECK_CONTIGUOUS_CUDA(grad_bary); + CHECK_CONTIGUOUS_CUDA(grad_dists); return RasterizeMeshesBackwardCuda( face_verts, pix_to_face, @@ -251,6 +260,9 @@ torch::Tensor RasterizeMeshesCoarse( const int max_faces_per_bin) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(face_verts); + CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx); + CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh); return RasterizeMeshesCoarseCuda( face_verts, mesh_to_face_first_idx, @@ -347,6 +359,8 @@ RasterizeMeshesFine( const bool cull_backfaces) { if (face_verts.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(face_verts); + CHECK_CONTIGUOUS_CUDA(bin_faces); return RasterizeMeshesFineCuda( face_verts, bin_faces, diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu index 57b32d0a..7279533b 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
#include +#include +#include #include #include #include @@ -145,13 +147,25 @@ std::tuple RasterizePointsNaiveCuda( const int image_size, const float radius, const int points_per_pixel) { - if (points.ndimension() != 2 || points.size(1) != 3) { - AT_ERROR("points must have dimensions (num_points, 3)"); - } - if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) { - AT_ERROR( - "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); - } + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + cloud_to_packed_first_idx_t{ + cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2}, + num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3}; + at::CheckedFrom c = "RasterizePointsNaiveCuda"; + at::checkAllSameGPU( + c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + TORCH_CHECK( + points.ndimension() == 2 && points.size(1) == 3, + "points must have dimensions (num_points, 3)"); + TORCH_CHECK( + num_points_per_cloud.size(0) == cloud_to_packed_first_idx.size(0), + "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx"); const int N = num_points_per_cloud.size(0); // batch size. const int S = image_size; @@ -169,9 +183,14 @@ std::tuple RasterizePointsNaiveCuda( at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); + if (point_idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(point_idxs, zbuf, pix_dists); + } + const size_t blocks = 1024; const size_t threads = 64; - RasterizePointsNaiveCudaKernel<<>>( + RasterizePointsNaiveCudaKernel<<>>( points.contiguous().data_ptr(), cloud_to_packed_first_idx.contiguous().data_ptr(), num_points_per_cloud.contiguous().data_ptr(), @@ -182,6 +201,8 @@ std::tuple RasterizePointsNaiveCuda( point_idxs.contiguous().data_ptr(), zbuf.contiguous().data_ptr(), pix_dists.contiguous().data_ptr()); + + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(point_idxs, zbuf, pix_dists); } @@ -323,14 +344,28 @@ at::Tensor RasterizePointsCoarseCuda( const float radius, const int bin_size, const int max_points_per_bin) { + TORCH_CHECK( + points.ndimension() == 2 && points.size(1) == 3, + "points must have dimensions (num_points, 3)"); + + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + cloud_to_packed_first_idx_t{ + cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2}, + num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3}; + at::CheckedFrom c = "RasterizePointsCoarseCuda"; + at::checkAllSameGPU( + c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int P = points.size(0); const int N = num_points_per_cloud.size(0); const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up const int M = max_points_per_bin; - if (points.ndimension() != 2 || points.size(1) != 3) { - AT_ERROR("points must have dimensions (num_points, 3)"); - } if (num_bins >= 22) { // Make sure we do not use too much shared memory. 
std::stringstream ss; @@ -340,12 +375,18 @@ at::Tensor RasterizePointsCoarseCuda( auto opts = points.options().dtype(at::kInt); at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts); at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts); + + if (bin_points.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return bin_points; + } + const int chunk_size = 512; const size_t shared_size = num_bins * num_bins * chunk_size / 8; const size_t blocks = 64; const size_t threads = 512; - RasterizePointsCoarseCudaKernel<<>>( + RasterizePointsCoarseCudaKernel<<>>( points.contiguous().data_ptr(), cloud_to_packed_first_idx.contiguous().data_ptr(), num_points_per_cloud.contiguous().data_ptr(), @@ -358,6 +399,8 @@ at::Tensor RasterizePointsCoarseCuda( M, points_per_bin.contiguous().data_ptr(), bin_points.contiguous().data_ptr()); + + AT_CUDA_CHECK(cudaGetLastError()); return bin_points; } @@ -448,13 +491,23 @@ std::tuple RasterizePointsFineCuda( const float radius, const int bin_size, const int points_per_pixel) { + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, + bin_points_t{bin_points, "bin_points", 2}; + at::CheckedFrom c = "RasterizePointsFineCuda"; + at::checkAllSameGPU(c, {points_t, bin_points_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int N = bin_points.size(0); const int B = bin_points.size(1); // num_bins const int M = bin_points.size(3); const int S = image_size; const int K = points_per_pixel; if (K > kMaxPointsPerPixel) { - AT_ERROR("Must have num_closest <= 8"); + AT_ERROR("Must have num_closest <= 150"); } auto int_opts = points.options().dtype(at::kInt); auto float_opts = points.options().dtype(at::kFloat); @@ -462,9 +515,14 @@ std::tuple RasterizePointsFineCuda( at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts); at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts); + if (point_idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(point_idxs, zbuf, pix_dists); + } + const size_t blocks = 1024; const size_t threads = 64; - RasterizePointsFineCudaKernel<<>>( + RasterizePointsFineCudaKernel<<>>( points.contiguous().data_ptr(), bin_points.contiguous().data_ptr(), radius, @@ -478,6 +536,7 @@ std::tuple RasterizePointsFineCuda( zbuf.contiguous().data_ptr(), pix_dists.contiguous().data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(point_idxs, zbuf, pix_dists); } @@ -537,6 +596,19 @@ at::Tensor RasterizePointsBackwardCuda( const at::Tensor& idxs, // (N, H, W, K) const at::Tensor& grad_zbuf, // (N, H, W, K) const at::Tensor& grad_dists) { // (N, H, W, K) + + // Check inputs are on the same device + at::TensorArg points_t{points, "points", 1}, idxs_t{idxs, "idxs", 2}, + grad_zbuf_t{grad_zbuf, "grad_zbuf", 3}, + grad_dists_t{grad_dists, "grad_dists", 4}; + at::CheckedFrom c = "RasterizePointsBackwardCuda"; + at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t}); + at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(points.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int P = points.size(0); const int N = idxs.size(0); const int H = idxs.size(1); @@ -544,10 +616,16 @@ at::Tensor RasterizePointsBackwardCuda( const int K = idxs.size(3); 
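Alongside the device handling, the hand-rolled if (...) { AT_ERROR(...); } shape checks in the rasterizer entry points above are rewritten as TORCH_CHECK, which keeps the condition and the message in one statement. A small before/after illustration (CheckPointsShape is only an illustrative wrapper, not a PyTorch3D function):

#include <torch/extension.h>

void CheckPointsShape(const at::Tensor& points) {
  // Before:
  //   if (points.ndimension() != 2 || points.size(1) != 3) {
  //     AT_ERROR("points must have dimensions (num_points, 3)");
  //   }
  // After: condition first, message second; throws c10::Error on failure,
  // which PyTorch surfaces as a regular Python exception.
  TORCH_CHECK(
      points.ndimension() == 2 && points.size(1) == 3,
      "points must have dimensions (num_points, 3)");
}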
at::Tensor grad_points = at::zeros({P, 3}, points.options()); + + if (grad_points.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_points; + } + const size_t blocks = 1024; const size_t threads = 64; - RasterizePointsBackwardCudaKernel<<>>( + RasterizePointsBackwardCudaKernel<<>>( points.contiguous().data_ptr(), idxs.contiguous().data_ptr(), N, @@ -559,5 +637,6 @@ at::Tensor RasterizePointsBackwardCuda( grad_dists.contiguous().data_ptr(), grad_points.contiguous().data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); return grad_points; } diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.h b/pytorch3d/csrc/rasterize_points/rasterize_points.h index ea59732f..9360c020 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.h +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.h @@ -4,6 +4,7 @@ #include #include #include +#include "utils/pytorch3d_cutils.h" // **************************************************************************** // * NAIVE RASTERIZATION * @@ -66,6 +67,9 @@ std::tuple RasterizePointsNaive( if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && num_points_per_cloud.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx); + CHECK_CONTIGUOUS_CUDA(num_points_per_cloud); return RasterizePointsNaiveCuda( points, cloud_to_packed_first_idx, @@ -140,6 +144,9 @@ torch::Tensor RasterizePointsCoarse( if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() && num_points_per_cloud.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx); + CHECK_CONTIGUOUS_CUDA(num_points_per_cloud); return RasterizePointsCoarseCuda( points, cloud_to_packed_first_idx, @@ -208,6 +215,8 @@ std::tuple RasterizePointsFine( const int points_per_pixel) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(bin_points); return RasterizePointsFineCuda( points, bin_points, image_size, radius, bin_size, points_per_pixel); #else @@ -257,6 +266,10 @@ torch::Tensor RasterizePointsBackward( const torch::Tensor& grad_dists) { if (points.is_cuda()) { #ifdef WITH_CUDA + CHECK_CONTIGUOUS_CUDA(points); + CHECK_CONTIGUOUS_CUDA(idxs); + CHECK_CONTIGUOUS_CUDA(grad_zbuf); + CHECK_CONTIGUOUS_CUDA(grad_dists); return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists); #else AT_ERROR("Not compiled with GPU support"); diff --git a/pytorch3d/csrc/utils/pytorch3d_cutils.h b/pytorch3d/csrc/utils/pytorch3d_cutils.h index c8d2853e..c88b7c53 100644 --- a/pytorch3d/csrc/utils/pytorch3d_cutils.h +++ b/pytorch3d/csrc/utils/pytorch3d_cutils.h @@ -3,9 +3,9 @@ #pragma once #include -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x "must be a CUDA tensor.") +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.") #define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x "must be contiguous.") + TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.") #define CHECK_CONTIGUOUS_CUDA(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) diff --git a/tests/bm_chamfer.py b/tests/bm_chamfer.py index 0dcdb803..4de5829e 100644 --- a/tests/bm_chamfer.py +++ b/tests/bm_chamfer.py @@ -8,11 +8,21 @@ from test_chamfer import TestChamfer def bm_chamfer() -> None: - kwargs_list_naive = [ - {"batch_size": 1, "P1": 32, "P2": 64, "return_normals": False}, - {"batch_size": 1, "P1": 32, "P2": 64, "return_normals": True}, - {"batch_size": 32, "P1": 32, "P2": 64, "return_normals": False}, - ] + devices = ["cpu"] + if 
torch.cuda.is_available(): + devices.append("cuda:0") + + kwargs_list_naive = [] + batch_size = [1, 32] + return_normals = [True, False] + test_cases = product(batch_size, return_normals, devices) + + for case in test_cases: + b, n, d = case + kwargs_list_naive.append( + {"batch_size": b, "P1": 32, "P2": 64, "return_normals": n, "device": d} + ) + benchmark( TestChamfer.chamfer_naive_with_init, "CHAMFER_NAIVE", @@ -21,6 +31,7 @@ def bm_chamfer() -> None: ) if torch.cuda.is_available(): + device = "cuda:0" kwargs_list = [] batch_size = [1, 32] P1 = [32, 1000, 10000] @@ -38,6 +49,7 @@ def bm_chamfer() -> None: "P2": p2, "return_normals": n, "homogeneous": h, + "device": device, } ) benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1) diff --git a/tests/common_testing.py b/tests/common_testing.py index b8816fc0..7f6898a8 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -20,6 +20,18 @@ def load_rgb_image(filename: str, data_dir: Union[str, Path]): TensorOrArray = Union[torch.Tensor, np.ndarray] +def get_random_cuda_device() -> str: + """ + Function to get a random GPU device from the + available devices. This is useful for testing + that custom cuda kernels can support inputs on + any device without having to set the device explicitly. + """ + num_devices = torch.cuda.device_count() + rand_device_id = torch.randint(high=num_devices, size=(1,)).item() + return "cuda:%d" % rand_device_id + + class TestCaseMixin(unittest.TestCase): def assertSeparate(self, tensor1, tensor2) -> None: """ diff --git a/tests/test_chamfer.py b/tests/test_chamfer.py index dd5c88e4..f059ba95 100644 --- a/tests/test_chamfer.py +++ b/tests/test_chamfer.py @@ -6,7 +6,7 @@ from collections import namedtuple import numpy as np import torch import torch.nn.functional as F -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.loss import chamfer_distance from pytorch3d.structures.pointclouds import Pointclouds @@ -81,7 +81,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): ) @staticmethod - def chamfer_distance_naive_pointclouds(p1, p2): + def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"): """ Naive iterative implementation of nearest neighbor and chamfer distance. x and y are assumed to be pointclouds objects with points and optionally normals. @@ -97,7 +97,6 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): x_normals = p1.normals_padded() y_normals = p2.normals_padded() - device = torch.device("cuda:0") return_normals = x_normals is not None and y_normals is not None # Initialize all distances to + inf @@ -163,7 +162,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): """ N, P1, D = x.shape P2 = y.size(1) - device = torch.device("cuda:0") + device = x.device return_normals = x_normals is not None and y_normals is not None dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device) @@ -203,7 +202,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): This tests only uses homogeneous pointclouds. """ N, max_P1, max_P2 = 7, 10, 18 - device = "cuda:0" + device = get_random_cuda_device() points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) p1 = points_normals.p1 p2 = points_normals.p2 @@ -237,7 +236,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): which supports heterogeneous pointcloud objects. 
""" N, max_P1, max_P2 = 3, 70, 70 - device = "cuda:0" + device = get_random_cuda_device() points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) weights = points_normals.weights x_lengths = points_normals.p1_lengths @@ -256,7 +255,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): # Chamfer with pointclouds as input. pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds( - points_normals.cloud1, points_normals.cloud2 + points_normals.cloud1, points_normals.cloud2, device=device ) # Mean reduction point loss. @@ -299,7 +298,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): def test_chamfer_pointcloud_object_withnormals(self): N = 5 P1, P2 = 100, 100 - device = "cuda:0" + device = get_random_cuda_device() reductions = [ ("sum", "sum"), @@ -359,7 +358,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): def test_chamfer_pointcloud_object_nonormals(self): N = 5 P1, P2 = 100, 100 - device = "cuda:0" + device = get_random_cuda_device() reductions = [ ("sum", "sum"), @@ -415,7 +414,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): for point_reduction = "mean" and batch_reduction = None. """ N, max_P1, max_P2 = 7, 10, 18 - device = "cuda:0" + device = get_random_cuda_device() points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) p1 = points_normals.p1 p2 = points_normals.p2 @@ -464,7 +463,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): for point_reduction = "sum" and batch_reduction = None. """ N, P1, P2 = 7, 10, 18 - device = "cuda:0" + device = get_random_cuda_device() points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) p1 = points_normals.p1 p2 = points_normals.p2 @@ -579,7 +578,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): point_reduction in ["mean", "sum"]. """ N, max_P1, max_P2 = 7, 10, 18 - device = "cuda:0" + device = get_random_cuda_device() points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device) p1 = points_normals.p1 @@ -681,7 +680,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): def test_incorrect_weights(self): N, P1, P2 = 16, 64, 128 - device = torch.device("cuda:0") + device = get_random_cuda_device() p1 = torch.rand( (N, P1, 3), dtype=torch.float32, device=device, requires_grad=True ) @@ -716,7 +715,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): def test_incorrect_inputs(self): N, P1, P2 = 7, 10, 18 - device = "cuda:0" + device = get_random_cuda_device() points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) p1 = points_normals.p1 p2 = points_normals.p2 @@ -740,11 +739,16 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): @staticmethod def chamfer_with_init( - batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool + batch_size: int, + P1: int, + P2: int, + return_normals: bool, + homogeneous: bool, + device="cpu", ): - p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds( - batch_size, P1, P2 - ) + points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device) + l1 = points_normals.p1_lengths + l2 = points_normals.p2_lengths if homogeneous: # Set lengths to None so in Chamfer it assumes # there is no padding. 
@@ -754,13 +758,13 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): def loss(): loss, loss_normals = chamfer_distance( - p1, - p2, + points_normals.p1, + points_normals.p2, x_lengths=l1, y_lengths=l2, - x_normals=p1_normals, - y_normals=p2_normals, - weights=weights, + x_normals=points_normals.n1, + y_normals=points_normals.n2, + weights=points_normals.weights, ) torch.cuda.synchronize() @@ -768,16 +772,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase): @staticmethod def chamfer_naive_with_init( - batch_size: int, P1: int, P2: int, return_normals: bool + batch_size: int, P1: int, P2: int, return_normals: bool, device="cpu" ): - p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds( - batch_size, P1, P2 - ) + points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device) torch.cuda.synchronize() def loss(): loss, loss_normals = TestChamfer.chamfer_distance_naive( - p1, p2, x_normals=p1_normals, y_normals=p2_normals + points_normals.p1, + points_normals.p2, + x_normals=points_normals.n1, + y_normals=points_normals.n2, ) torch.cuda.synchronize() diff --git a/tests/test_compositing.py b/tests/test_compositing.py index 0e396b2a..42e3ecc1 100644 --- a/tests/test_compositing.py +++ b/tests/test_compositing.py @@ -3,6 +3,7 @@ import unittest import torch +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.renderer.compositing import ( alpha_composite, norm_weighted_sum, @@ -10,7 +11,7 @@ from pytorch3d.renderer.compositing import ( ) -class TestAccumulatePoints(unittest.TestCase): +class TestAccumulatePoints(TestCaseMixin, unittest.TestCase): # NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING) @staticmethod @@ -120,7 +121,7 @@ class TestAccumulatePoints(unittest.TestCase): self._simple_wsumnorm(norm_weighted_sum, device) def test_cuda(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() self._simple_alphacomposite(alpha_composite, device) self._simple_wsum(weighted_sum, device) self._simple_wsumnorm(norm_weighted_sum, device) @@ -142,7 +143,7 @@ class TestAccumulatePoints(unittest.TestCase): C = 3 P = 32 - for d in ["cpu", "cuda"]: + for d in ["cpu", get_random_cuda_device()]: # TODO(gkioxari) add torch.float64 to types after double precision # support is added to atomicAdd for t in [torch.float32]: @@ -181,7 +182,7 @@ class TestAccumulatePoints(unittest.TestCase): res1 = fn1(*args1) res2 = fn2(*args2) - self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6)) + self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6) if not compare_grads: return @@ -200,7 +201,7 @@ class TestAccumulatePoints(unittest.TestCase): grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2] for i in range(0, len(grads1)): - self.assertTrue(torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)) + self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6) def _simple_wsum(self, accum_func, device): # Initialise variables @@ -273,7 +274,7 @@ class TestAccumulatePoints(unittest.TestCase): ] ).to(device) - self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)) + self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3) def _simple_wsumnorm(self, accum_func, device): # Initialise variables @@ -346,7 +347,7 @@ class TestAccumulatePoints(unittest.TestCase): ] ).to(device) - self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)) + self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3) def _simple_alphacomposite(self, accum_func, device): # Initialise 
variables diff --git a/tests/test_cubify.py b/tests/test_cubify.py index 5e3c0da4..158b8968 100644 --- a/tests/test_cubify.py +++ b/tests/test_cubify.py @@ -33,7 +33,9 @@ class TestCubify(unittest.TestCase): # 1st-check verts, faces = meshes.get_mesh_verts_faces(0) - self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1]))) + self.assertTrue( + torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) + ) self.assertTrue( torch.allclose( verts, @@ -78,7 +80,9 @@ class TestCubify(unittest.TestCase): ) # 2nd-check verts, faces = meshes.get_mesh_verts_faces(1) - self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1]))) + self.assertTrue( + torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) + ) self.assertTrue( torch.allclose( verts, diff --git a/tests/test_face_areas_normals.py b/tests/test_face_areas_normals.py index 4b9cb974..9ed5c598 100644 --- a/tests/test_face_areas_normals.py +++ b/tests/test_face_areas_normals.py @@ -4,7 +4,7 @@ import unittest import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.ops import mesh_face_areas_normals from pytorch3d.structures.meshes import Meshes @@ -94,13 +94,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): self._test_face_areas_normals_helper("cpu") def test_face_areas_normals_cuda(self): - self._test_face_areas_normals_helper("cuda:0") + device = get_random_cuda_device() + self._test_face_areas_normals_helper(device) def test_nonfloats_cpu(self): self._test_face_areas_normals_helper("cpu", dtype=torch.double) def test_nonfloats_cuda(self): - self._test_face_areas_normals_helper("cuda:0", dtype=torch.double) + device = get_random_cuda_device() + self._test_face_areas_normals_helper(device, dtype=torch.double) @staticmethod def face_areas_normals_with_init( diff --git a/tests/test_graph_conv.py b/tests/test_graph_conv.py index 8462cec7..dd64d82d 100644 --- a/tests/test_graph_conv.py +++ b/tests/test_graph_conv.py @@ -4,7 +4,7 @@ import unittest import torch import torch.nn as nn -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python from pytorch3d.structures.meshes import Meshes @@ -14,7 +14,7 @@ from pytorch3d.utils import ico_sphere class TestGraphConv(TestCaseMixin, unittest.TestCase): def test_undirected(self): dtype = torch.float32 - device = torch.device("cuda:0") + device = get_random_cuda_device() verts = torch.tensor( [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device ) @@ -97,7 +97,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase): self.assertClose(y, expected_y) def test_backward(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() mesh = ico_sphere() verts = mesh.verts_packed() edges = mesh.edges_packed() @@ -118,7 +118,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase): self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)") def test_cpu_cuda_tensor_error(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() verts = torch.tensor( [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device ) @@ -134,7 +134,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase): Check that gather_scatter cuda version throws an error if cpu tensors are given as input. 
""" - device = torch.device("cuda:0") + device = get_random_cuda_device() mesh = ico_sphere() verts = mesh.verts_packed() edges = mesh.edges_packed() diff --git a/tests/test_knn.py b/tests/test_knn.py index 112d1cc9..18244095 100644 --- a/tests/test_knn.py +++ b/tests/test_knn.py @@ -4,7 +4,7 @@ import unittest from itertools import product import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.ops.knn import _KNN, knn_gather, knn_points @@ -89,7 +89,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase): self._knn_vs_python_square_helper(device) def test_knn_vs_python_square_cuda(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() self._knn_vs_python_square_helper(device) def _knn_vs_python_ragged_helper(self, device): @@ -133,11 +133,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase): self._knn_vs_python_ragged_helper(device) def test_knn_vs_python_ragged_cuda(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() self._knn_vs_python_ragged_helper(device) def test_knn_gather(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() N, P1, P2, K, D = 4, 16, 12, 8, 3 x = torch.rand((N, P1, D), device=device) y = torch.rand((N, P2, D), device=device) diff --git a/tests/test_packed_to_padded.py b/tests/test_packed_to_padded.py index 28ce5d43..4f8b5176 100644 --- a/tests/test_packed_to_padded.py +++ b/tests/test_packed_to_padded.py @@ -3,7 +3,7 @@ import unittest import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.ops import packed_to_padded, padded_to_packed from pytorch3d.structures.meshes import Meshes @@ -126,13 +126,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): self._test_packed_to_padded_helper(16, "cpu") def test_packed_to_padded_flat_cuda(self): - self._test_packed_to_padded_helper(0, "cuda:0") + device = get_random_cuda_device() + self._test_packed_to_padded_helper(0, device) def test_packed_to_padded_D1_cuda(self): - self._test_packed_to_padded_helper(1, "cuda:0") + device = get_random_cuda_device() + self._test_packed_to_padded_helper(1, device) def test_packed_to_padded_D16_cuda(self): - self._test_packed_to_padded_helper(16, "cuda:0") + device = get_random_cuda_device() + self._test_packed_to_padded_helper(16, device) def _test_padded_to_packed_helper(self, D, device): """ @@ -191,13 +194,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): self._test_padded_to_packed_helper(16, "cpu") def test_padded_to_packed_flat_cuda(self): - self._test_padded_to_packed_helper(0, "cuda:0") + device = get_random_cuda_device() + self._test_padded_to_packed_helper(0, device) def test_padded_to_packed_D1_cuda(self): - self._test_padded_to_packed_helper(1, "cuda:0") + device = get_random_cuda_device() + self._test_padded_to_packed_helper(1, device) def test_padded_to_packed_D16_cuda(self): - self._test_padded_to_packed_helper(16, "cuda:0") + device = get_random_cuda_device() + self._test_padded_to_packed_helper(16, device) def test_invalid_inputs_shapes(self, device="cuda:0"): with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."): diff --git a/tests/test_point_mesh_distance.py b/tests/test_point_mesh_distance.py index 96be9854..76330d83 100644 --- a/tests/test_point_mesh_distance.py +++ b/tests/test_point_mesh_distance.py @@ -4,7 +4,7 @@ import unittest import numpy as np import torch -from common_testing import 
TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance from pytorch3d.structures import Meshes, Pointclouds, packed_to_list @@ -203,7 +203,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): & PointEdgeArrayDistanceBackward """ P, E = 16, 32 - device = torch.device("cuda:0") + device = get_random_cuda_device() points = torch.rand((P, 3), dtype=torch.float32, device=device) edges = torch.rand((E, 2, 3), dtype=torch.float32, device=device) @@ -246,9 +246,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): Test CUDA implementation for PointEdgeDistanceForward & PointEdgeDistanceBackward """ - device = torch.device("cuda:0") + device = get_random_cuda_device() N, V, F, P = 4, 32, 16, 24 - meshes, pcls = self.init_meshes_clouds(N, V, F, P) + meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device) # make points packed a leaf node points_packed = pcls.points_packed().detach().clone() # (P, 3) @@ -327,9 +327,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): Test CUDA implementation for EdgePointDistanceForward & EdgePointDistanceBackward """ - device = torch.device("cuda:0") + device = get_random_cuda_device() N, V, F, P = 4, 32, 16, 24 - meshes, pcls = self.init_meshes_clouds(N, V, F, P) + meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device) # make points packed a leaf node points_packed = pcls.points_packed().detach().clone() # (P, 3) @@ -409,9 +409,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): """ Test point_mesh_edge_distance from pytorch3d.loss """ - device = torch.device("cuda:0") + device = get_random_cuda_device() N, V, F, P = 4, 32, 16, 24 - meshes, pcls = self.init_meshes_clouds(N, V, F, P) + meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device) # clone and detach for another backward pass through the op verts_op = [verts.clone().detach() for verts in meshes.verts_list()] @@ -480,7 +480,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): & PointFaceArrayDistanceBackward """ P, T = 16, 32 - device = torch.device("cuda:0") + device = get_random_cuda_device() points = torch.rand((P, 3), dtype=torch.float32, device=device) tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device) @@ -525,9 +525,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): Test CUDA implementation for PointFaceDistanceForward & PointFaceDistanceBackward """ - device = torch.device("cuda:0") + device = get_random_cuda_device() N, V, F, P = 4, 32, 16, 24 - meshes, pcls = self.init_meshes_clouds(N, V, F, P) + meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device) # make points packed a leaf node points_packed = pcls.points_packed().detach().clone() # (P, 3) @@ -608,9 +608,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): Test CUDA implementation for FacePointDistanceForward & FacePointDistanceBackward """ - device = torch.device("cuda:0") + device = get_random_cuda_device() N, V, F, P = 4, 32, 16, 24 - meshes, pcls = self.init_meshes_clouds(N, V, F, P) + meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device) # make points packed a leaf node points_packed = pcls.points_packed().detach().clone() # (P, 3) @@ -690,9 +690,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): """ Test point_mesh_face_distance from pytorch3d.loss """ - device = torch.device("cuda:0") + device = get_random_cuda_device() N, 
V, F, P = 4, 32, 16, 24 - meshes, pcls = self.init_meshes_clouds(N, V, F, P) + meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device) # clone and detach for another backward pass through the op verts_op = [verts.clone().detach() for verts in meshes.verts_list()] @@ -751,7 +751,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): @staticmethod def point_mesh_edge(N: int, V: int, F: int, P: int, device: str): device = torch.device(device) - meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P) + meshes, pcls = TestPointMeshDistance.init_meshes_clouds( + N, V, F, P, device=device + ) torch.cuda.synchronize() def loss(): @@ -763,7 +765,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): @staticmethod def point_mesh_face(N: int, V: int, F: int, P: int, device: str): device = torch.device(device) - meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P) + meshes, pcls = TestPointMeshDistance.init_meshes_clouds( + N, V, F, P, device=device + ) torch.cuda.synchronize() def loss(): diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index 9979a041..bb2441d8 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -4,7 +4,7 @@ import functools import unittest import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C from pytorch3d.renderer.mesh.rasterize_meshes import ( rasterize_meshes, @@ -32,7 +32,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self._test_back_face_culling(rasterize_meshes, device, bin_size=0) def test_simple_cuda_naive(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() self._simple_triangle_raster(rasterize_meshes, device, bin_size=0) self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) self._test_behind_camera(rasterize_meshes, device, bin_size=0) @@ -40,7 +40,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self._test_back_face_culling(rasterize_meshes, device, bin_size=0) def test_simple_cuda_binned(self): - device = torch.device("cuda:0") + device = get_random_cuda_device() self._simple_triangle_raster(rasterize_meshes, device, bin_size=5) self._simple_blurry_raster(rasterize_meshes, device, bin_size=5) self._test_behind_camera(rasterize_meshes, device, bin_size=5) @@ -54,7 +54,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): blur_radius = 0.1 ** 2 faces_per_pixel = 3 - for d in ["cpu", "cuda"]: + for d in ["cpu", get_random_cuda_device()]: device = torch.device(d) compare_grads = True # Mesh with a single face. 
@@ -164,7 +164,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): verts1.requires_grad = True meshes_cpu = Meshes(verts=[verts1], faces=[faces1]) - device = torch.device("cuda:0") + device = get_random_cuda_device() meshes_cuda = ico_sphere(0, device) verts2, faces2 = meshes_cuda.get_mesh_verts_faces(0) verts2.requires_grad = True @@ -186,7 +186,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): return self._test_coarse_rasterize(torch.device("cpu")) def test_coarse_cuda(self): - return self._test_coarse_rasterize(torch.device("cuda:0")) + return self._test_coarse_rasterize(get_random_cuda_device()) def test_cpp_vs_cuda_naive_vs_cuda_binned(self): # Make sure that the backward pass runs for all pathways @@ -221,7 +221,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): grad1 = verts.grad.data.cpu().clone() # Option II: CUDA, naive - device = torch.device("cuda:0") + device = get_random_cuda_device() meshes = ico_sphere(0, device) verts, faces = meshes.get_mesh_verts_faces(0) verts.requires_grad = True @@ -229,9 +229,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): args = (meshes, image_size, radius, faces_per_pixel, 0, 0) idx2, zbuf2, bary2, dist2 = rasterize_meshes(*args) - grad_zbuf = grad_zbuf.cuda() - grad_dist = grad_dist.cuda() - grad_bary = grad_bary.cuda() + grad_zbuf = grad_zbuf.to(device) + grad_dist = grad_dist.to(device) + grad_bary = grad_bary.to(device) loss = ( (zbuf2 * grad_zbuf).sum() + (dist2 * grad_dist).sum() @@ -244,7 +244,6 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): grad2 = verts.grad.data.cpu().clone() # Option III: CUDA, binned - device = torch.device("cuda:0") meshes = ico_sphere(0, device) verts, faces = meshes.get_mesh_verts_faces(0) verts.requires_grad = True @@ -302,7 +301,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): bin_size, max_faces_per_bin, ) - device = torch.device("cuda:0") + device = get_random_cuda_device() meshes = meshes.clone().to(device) faces = meshes.faces_packed() @@ -356,8 +355,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): verts1, faces1 = meshes.get_mesh_verts_faces(0) verts1.requires_grad = True meshes1 = Meshes(verts=[verts1], faces=[faces1]) - verts2 = verts1.detach().cuda().requires_grad_(True) - faces2 = faces1.detach().clone().cuda() + device = get_random_cuda_device() + verts2 = verts1.detach().to(device).requires_grad_(True) + faces2 = faces1.detach().clone().to(device) meshes2 = Meshes(verts=[verts2], faces=[faces2]) kwargs = {"image_size": 64, "perspective_correct": True} @@ -367,7 +367,8 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) def test_cuda_naive_vs_binned_perspective_correct(self): - meshes = ico_sphere(2, device=torch.device("cuda")) + device = get_random_cuda_device() + meshes = ico_sphere(2, device=device) verts1, faces1 = meshes.get_mesh_verts_faces(0) verts1.requires_grad = True meshes1 = Meshes(verts=[verts1], faces=[faces1]) @@ -1029,7 +1030,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): max_faces_per_bin: int, ): - meshes = ico_sphere(ico_level, torch.device("cuda:0")) + meshes = ico_sphere(ico_level, get_random_cuda_device()) meshes_batch = meshes.extend(num_meshes) torch.cuda.synchronize() diff --git a/tests/test_rasterize_points.py b/tests/test_rasterize_points.py index 0230d164..e46dc56b 100644 --- a/tests/test_rasterize_points.py +++ b/tests/test_rasterize_points.py @@ -5,7 +5,7 
@@ import unittest import numpy as np import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C from pytorch3d.renderer.points.rasterize_points import ( rasterize_points, @@ -25,7 +25,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): self._simple_test_case(rasterize_points, device) def test_naive_simple_cuda(self): - device = torch.device("cuda") + device = get_random_cuda_device() self._simple_test_case(rasterize_points, device, bin_size=0) def test_python_behind_camera(self): @@ -37,7 +37,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): self._test_behind_camera(rasterize_points, torch.device("cpu")) def test_cuda_behind_camera(self): - self._test_behind_camera(rasterize_points, torch.device("cuda"), bin_size=0) + device = get_random_cuda_device() + self._test_behind_camera(rasterize_points, device, bin_size=0) def test_cpp_vs_naive_vs_binned(self): # Make sure that the backward pass runs for all pathways @@ -373,7 +374,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): return self._test_coarse_rasterize(torch.device("cpu")) def test_coarse_cuda(self): - return self._test_coarse_rasterize(torch.device("cuda")) + device = get_random_cuda_device() + return self._test_coarse_rasterize(device) def test_compare_coarse_cpu_vs_cuda(self): torch.manual_seed(231) @@ -405,7 +407,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): ) bp_cpu = _C._rasterize_points_coarse(*args) - pointclouds_cuda = pointclouds.to("cuda:0") + device = get_random_cuda_device() + pointclouds_cuda = pointclouds.to(device) points_packed = pointclouds_cuda.points_packed() cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx() num_points_per_cloud = pointclouds_cuda.num_points_per_cloud() diff --git a/tests/test_sample_points_from_meshes.py b/tests/test_sample_points_from_meshes.py index 6343aa72..17f08daa 100644 --- a/tests/test_sample_points_from_meshes.py +++ b/tests/test_sample_points_from_meshes.py @@ -5,7 +5,7 @@ import unittest from pathlib import Path import torch -from common_testing import TestCaseMixin +from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.ops import sample_points_from_meshes from pytorch3d.structures.meshes import Meshes from pytorch3d.utils.ico_sphere import ico_sphere @@ -42,7 +42,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): Check sample_points_from_meshes raises an exception if all meshes are invalid. """ - device = torch.device("cuda:0") + device = get_random_cuda_device() verts1 = torch.tensor([], dtype=torch.float32, device=device) faces1 = torch.tensor([], dtype=torch.int64, device=device) meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1]) @@ -56,7 +56,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): For an ico_sphere, the sampled vertices should lie on a unit sphere. For an empty mesh, the samples and normals should be 0. """ - device = torch.device("cuda:0") + device = get_random_cuda_device() # Unit simplex. verts_pyramid = torch.tensor(