Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)

Commit c3d636dc8c (parent c9267ab7af): Cuda updates

Summary: Updates to:
- enable cuda kernel launches on any GPU (not just the default)
- cuda and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the cuda kernels
- cuda tests now run on a random device, not just the default

Reviewed By: jcjohnson, gkioxari
Differential Revision: D21215280
fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
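Every kernel wrapper touched by the hunks below follows the same pattern: build at::TensorArg wrappers and call at::checkAllSameGPU / checkAllSameType so all inputs must live on one GPU, pin the device with at::cuda::CUDAGuard, launch on the stream returned by at::cuda::getCurrentCUDAStream() rather than the default stream, return early (after checking cudaGetLastError) when an output tensor is empty, and report launch errors with AT_CUDA_CHECK(cudaGetLastError()). The sketch below shows that pattern in isolation; it is a minimal illustration, and the kernel and wrapper names (FillValueKernel, FillValueCuda) are hypothetical, not part of PyTorch3D.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Illustrative kernel: writes `value` into every element of `out` using a
// grid-stride loop. Assumes float data, like the kernels in this diff.
__global__ void FillValueKernel(float* out, int64_t n, float value) {
  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t stride = gridDim.x * blockDim.x;
  for (int64_t i = tid; i < n; i += stride) {
    out[i] = value;
  }
}

// Hypothetical host wrapper showing the device/stream handling this commit
// adds to the real PyTorch3D kernels.
at::Tensor FillValueCuda(const at::Tensor& input, float value) {
  // Check inputs are on the same (CUDA) device; real call sites pass every
  // tensor involved in the op to at::checkAllSameGPU.
  at::TensorArg input_t{input, "input", 1};
  at::CheckedFrom c = "FillValueCuda";
  at::checkAllSameGPU(c, {input_t});

  // Set the device for the kernel launch based on the device of the input,
  // and launch on the current stream instead of the default stream.
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  at::Tensor out = at::empty_like(input);
  if (out.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return out;
  }

  const int threads = 256;
  const int blocks = 64;
  FillValueKernel<<<blocks, threads, 0, stream>>>(
      out.data_ptr<float>(), out.numel(), value);

  // Surface any launch error immediately.
  AT_CUDA_CHECK(cudaGetLastError());
  return out;
}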
@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

@@ -136,6 +138,17 @@ at::Tensor alphaCompositeCudaForward(
    const at::Tensor& features,
    const at::Tensor& alphas,
    const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg features_t{features, "features", 1},
+      alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
+  at::CheckedFrom c = "alphaCompositeCudaForward";
+  at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);

@@ -143,19 +156,24 @@ at::Tensor alphaCompositeCudaForward(
  auto result = at::zeros({batch_size, C, H, W}, features.options());

+  if (result.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return result;
+  }
+
  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
-  alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
+  alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      // clang-format off
      result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
  // clang-format on

  AT_CUDA_CHECK(cudaGetLastError());
  return result;
}

@@ -164,9 +182,26 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
    const at::Tensor& features,
    const at::Tensor& alphas,
    const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
+      features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
+      points_idx_t{points_idx, "points_idx", 4};
+  at::CheckedFrom c = "alphaCompositeCudaBackward";
+  at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  auto grad_features = at::zeros_like(features);
  auto grad_alphas = at::zeros_like(alphas);

+  if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_features, grad_alphas);
+  }
+
  const int64_t bs = alphas.size(0);

  const dim3 threadsPerBlock(64);

@@ -174,7 +209,7 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
-  alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
+  alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      // clang-format off
      grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),

@@ -183,6 +218,6 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
      alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
  // clang-format on

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_features, grad_alphas);
}

@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

@@ -151,6 +153,17 @@ at::Tensor weightedSumNormCudaForward(
    const at::Tensor& features,
    const at::Tensor& alphas,
    const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg features_t{features, "features", 1},
+      alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
+  at::CheckedFrom c = "weightedSumNormCudaForward";
+  at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);

@@ -158,19 +171,25 @@ at::Tensor weightedSumNormCudaForward(
  auto result = at::zeros({batch_size, C, H, W}, features.options());

+  if (result.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return result;
+  }
+
  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  // clang-format off
-  weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
+  weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
  // clang-format on

  AT_CUDA_CHECK(cudaGetLastError());
  return result;
}

@@ -179,9 +198,26 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
    const at::Tensor& features,
    const at::Tensor& alphas,
    const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
+      features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
+      points_idx_t{points_idx, "points_idx", 4};
+  at::CheckedFrom c = "weightedSumNormCudaBackward";
+  at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  auto grad_features = at::zeros_like(features);
  auto grad_alphas = at::zeros_like(alphas);

+  if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_features, grad_alphas);
+  }
+
  const int64_t bs = points_idx.size(0);

  const dim3 threadsPerBlock(64);

@@ -189,7 +225,7 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
-  weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
+  weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      // clang-format off
      grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),

@@ -198,6 +234,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
      alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
  // clang-format on

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_features, grad_alphas);
}

@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorAccessor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

@@ -110,6 +112,17 @@ at::Tensor weightedSumCudaForward(
    const at::Tensor& features,
    const at::Tensor& alphas,
    const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg features_t{features, "features", 1},
+      alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3};
+  at::CheckedFrom c = "weightedSumCudaForward";
+  at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);

@@ -117,19 +130,24 @@ at::Tensor weightedSumCudaForward(
  auto result = at::zeros({batch_size, C, H, W}, features.options());

+  if (result.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return result;
+  }
+
  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
-  weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
+  weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      // clang-format off
      result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
  // clang-format on

  AT_CUDA_CHECK(cudaGetLastError());
  return result;
}

@@ -138,9 +156,26 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
    const at::Tensor& features,
    const at::Tensor& alphas,
    const at::Tensor& points_idx) {
+  // Check inputs are on the same device
+  at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1},
+      features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3},
+      points_idx_t{points_idx, "points_idx", 4};
+  at::CheckedFrom c = "weightedSumCudaBackward";
+  at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t});
+  at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(features.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  auto grad_features = at::zeros_like(features);
  auto grad_alphas = at::zeros_like(alphas);

+  if (grad_features.numel() == 0 || grad_alphas.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_features, grad_alphas);
+  }
+
  const int64_t bs = points_idx.size(0);

  const dim3 threadsPerBlock(64);

@@ -148,7 +183,7 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
-  weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
+  weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
      // clang-format off
      grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),

@@ -157,6 +192,6 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
      alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
      points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
  // clang-format on

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_features, grad_alphas);
}

@@ -23,7 +23,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
  m.def("knn_points_idx", &KNearestNeighborIdx);
  m.def("knn_points_backward", &KNearestNeighborBackward);
-  m.def("gather_scatter", &gather_scatter);
+  m.def("gather_scatter", &GatherScatter);
  m.def("rasterize_points", &RasterizePoints);
  m.def("rasterize_points_backward", &RasterizePointsBackward);
  m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
#include <tuple>

template <typename scalar_t>

@@ -213,14 +215,30 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
  const auto V = verts.size(0);
  const auto F = faces.size(0);

+  // Check inputs are on the same device
+  at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2};
+  at::CheckedFrom c = "FaceAreasNormalsForwardCuda";
+  at::checkAllSameGPU(c, {verts_t, faces_t});
+  at::checkAllSameType(c, {verts_t, faces_t});
+
+  // Set the device for the kernel launch based on the device of verts
+  at::cuda::CUDAGuard device_guard(verts.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  at::Tensor areas = at::empty({F}, verts.options());
  at::Tensor normals = at::empty({F, 3}, verts.options());

+  if (areas.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(areas, normals);
+  }
+
  const int blocks = 64;
  const int threads = 512;

  AT_DISPATCH_FLOATING_TYPES(
      verts.scalar_type(), "face_areas_normals_forward_cuda", ([&] {
-        FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads>>>(
+        FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
            verts.data_ptr<scalar_t>(),
            faces.data_ptr<int64_t>(),
            areas.data_ptr<scalar_t>(),

@@ -228,7 +246,7 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
            V,
            F);
      }));

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(areas, normals);
}

@@ -237,16 +255,33 @@ at::Tensor FaceAreasNormalsBackwardCuda(
    const at::Tensor grad_normals,
    const at::Tensor verts,
    const at::Tensor faces) {
+  // Check inputs are on the same device
+  at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2},
+      grad_areas_t{verts, "grad_areas", 3},
+      grad_normals_t{verts, "grad_normals", 4};
+  at::CheckedFrom c = "FaceAreasNormalsBackwardCuda";
+  at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
+  at::checkAllSameType(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
+
+  // Set the device for the kernel launch based on the device of verts
+  at::cuda::CUDAGuard device_guard(verts.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const auto V = verts.size(0);
  const auto F = faces.size(0);

  at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());

+  if (grad_verts.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_verts;
+  }
+
  const int blocks = 64;
  const int threads = 512;
  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
-  FaceAreasNormalsBackwardKernel<<<blocks, threads>>>(
+  FaceAreasNormalsBackwardKernel<<<blocks, threads, 0, stream>>>(
      grad_areas.data_ptr<float>(),
      grad_normals.data_ptr<float>(),
      verts.data_ptr<float>(),

@@ -255,5 +290,6 @@ at::Tensor FaceAreasNormalsBackwardCuda(
      V,
      F);

  AT_CUDA_CHECK(cudaGetLastError());
  return grad_verts;
}

@@ -3,6 +3,7 @@
#pragma once
#include <torch/extension.h>
#include <tuple>
+#include "utils/pytorch3d_cutils.h"

// Compute areas of mesh faces using packed representation.
//

@@ -46,6 +47,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
    const at::Tensor faces) {
  if (verts.is_cuda() && faces.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(verts);
+    CHECK_CONTIGUOUS_CUDA(faces);
    return FaceAreasNormalsForwardCuda(verts, faces);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -62,6 +65,10 @@ at::Tensor FaceAreasNormalsBackward(
    const at::Tensor faces) {
  if (verts.is_cuda() && faces.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(verts);
+    CHECK_CONTIGUOUS_CUDA(faces);
+    CHECK_CONTIGUOUS_CUDA(grad_areas);
+    CHECK_CONTIGUOUS_CUDA(grad_normals);
    return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces);
#else
    AT_ERROR("Not compiled with GPU support.");

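The header hunks (here and later for GatherScatter, PackedToPadded/PaddedToPacked, and the point/edge/face distance dispatchers) guard each CUDA branch with CHECK_CONTIGUOUS_CUDA from utils/pytorch3d_cutils.h. That header is not part of this diff; as an assumption, the macros behave roughly like the sketch below, failing with TORCH_CHECK when a tensor is not a CUDA tensor or not contiguous.

#include <torch/extension.h>

// Assumed shape of the helpers in utils/pytorch3d_cutils.h (not shown in
// this diff): each check raises a readable error via TORCH_CHECK.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
  CHECK_CUDA(x);                 \
  CHECK_CONTIGUOUS(x)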
@@ -1,9 +1,11 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

// TODO(T47953967) to make this cuda kernel support all datatypes.
-__global__ void gather_scatter_kernel(
+__global__ void GatherScatterCudaKernel(
    const float* __restrict__ input,
    const int64_t* __restrict__ edges,
    float* __restrict__ output,

@@ -41,11 +43,20 @@ __global__ void gather_scatter_kernel(
  }
}

-at::Tensor gather_scatter_cuda(
+at::Tensor GatherScatterCuda(
    const at::Tensor input,
    const at::Tensor edges,
    bool directed,
    bool backward) {
+  // Check inputs are on the same device
+  at::TensorArg input_t{input, "input", 1}, edges_t{edges, "edges", 2};
+  at::CheckedFrom c = "GatherScatterCuda";
+  at::checkAllSameGPU(c, {input_t, edges_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(input.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const auto num_vertices = input.size(0);
  const auto input_feature_dim = input.size(1);
  const auto num_edges = edges.size(0);

@@ -55,7 +66,12 @@ at::Tensor gather_scatter_cuda(
  const size_t max_blocks = 1920;
  const size_t blocks = num_edges < max_blocks ? num_edges : max_blocks;

-  gather_scatter_kernel<<<blocks, threads>>>(
+  if (output.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return output;
+  }
+
+  GatherScatterCudaKernel<<<blocks, threads, 0, stream>>>(
      input.data_ptr<float>(),
      edges.data_ptr<int64_t>(),
      output.data_ptr<float>(),

@@ -64,6 +80,6 @@ at::Tensor gather_scatter_cuda(
      num_vertices,
      input_feature_dim,
      num_edges);

  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}

@@ -2,6 +2,7 @@
#pragma once
#include <torch/extension.h>
+#include "utils/pytorch3d_cutils.h"

// Fused gather scatter operation for aggregating features of neighbor nodes
// in a graph. This gather scatter operation is specific to graphs as edge

@@ -20,21 +21,23 @@
// output: float32 Tensor of same shape as input.

// Cuda implementation.
-at::Tensor gather_scatter_cuda(
+at::Tensor GatherScatterCuda(
    const at::Tensor input,
    const at::Tensor edges,
    bool directed,
    bool backward);

// Exposed implementation.
-at::Tensor gather_scatter(
+at::Tensor GatherScatter(
    const at::Tensor input,
    const at::Tensor edges,
    bool directed,
    bool backward) {
  if (input.is_cuda() && edges.is_cuda()) {
#ifdef WITH_CUDA
-    return gather_scatter_cuda(input, edges, directed, backward);
+    CHECK_CONTIGUOUS_CUDA(input);
+    CHECK_CONTIGUOUS_CUDA(edges);
+    return GatherScatterCuda(input, edges, directed, backward);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <iostream>
#include <tuple>

@@ -114,7 +116,8 @@ struct KNearestNeighborV1Functor {
      const size_t P1,
      const size_t P2,
      const size_t K) {
-    KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads>>>(
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
  }
};

@@ -178,7 +181,8 @@ struct KNearestNeighborKernelV2Functor {
      const int64_t N,
      const int64_t P1,
      const int64_t P2) {
-    KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads>>>(
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
  }
};

@@ -245,7 +249,8 @@ struct KNearestNeighborKernelV3Functor {
      const size_t N,
      const size_t P1,
      const size_t P2) {
-    KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads>>>(
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
  }
};

@@ -296,17 +301,33 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& lengths2,
    int K,
    int version) {
+  // Check inputs are on the same device
+  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
+  at::CheckedFrom c = "KNearestNeighborIdxCuda";
+  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
+  at::checkAllSameType(c, {p1_t, p2_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(p1.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  const auto D = p2.size(2);
  const int64_t K_64 = K;

-  AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
+  TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
  auto long_dtype = p1.options().dtype(at::kLong);
  auto idxs = at::zeros({N, P1, K}, long_dtype);
  auto dists = at::zeros({N, P1, K}, p1.options());

+  if (idxs.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(idxs, dists);
+  }
+
  if (version < 0) {
    version = ChooseVersion(D, K);
  } else if (!KnnCheckVersion(version, D, K)) {

@@ -328,7 +349,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
  if (version == 0) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
        KNearestNeighborKernelV0<scalar_t>
-            <<<blocks, threads>>>(
+            <<<blocks, threads, 0, stream>>>(
                p1.data_ptr<scalar_t>(),
                p2.data_ptr<scalar_t>(),
                lengths1.data_ptr<int64_t>(),

@@ -409,7 +430,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                P2);
        }));
  }

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(idxs, dists);
}

@@ -465,27 +486,45 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
+      idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
+  at::CheckedFrom c = "KNearestNeighborBackwardCuda";
+  at::checkAllSameGPU(
+      c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
+  at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(p1.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  const auto D = p2.size(2);
  const auto K = idxs.size(2);

-  AT_ASSERTM(p2.size(2) == D, "Point sets must have the same last dimension");
-  AT_ASSERTM(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
-  AT_ASSERTM(
+  TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
+  TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
+  TORCH_CHECK(
      idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
-  AT_ASSERTM(grad_dists.size(0) == N);
-  AT_ASSERTM(grad_dists.size(1) == P1);
-  AT_ASSERTM(grad_dists.size(2) == K);
+  TORCH_CHECK(grad_dists.size(0) == N);
+  TORCH_CHECK(grad_dists.size(1) == P1);
+  TORCH_CHECK(grad_dists.size(2) == K);

  auto grad_p1 = at::zeros({N, P1, D}, p1.options());
  auto grad_p2 = at::zeros({N, P2, D}, p2.options());

+  if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_p1, grad_p2);
+  }
+
  const int blocks = 64;
  const int threads = 512;

-  KNearestNeighborBackwardKernel<<<blocks, threads>>>(
+  KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
      p1.data_ptr<float>(),
      p2.data_ptr<float>(),
      lengths1.data_ptr<int64_t>(),

@@ -500,5 +539,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
      K,
      D);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_p1, grad_p2);
}

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>

// Kernel for inputs_packed of shape (F, D), where D > 1
template <typename scalar_t>

@@ -114,21 +116,36 @@ at::Tensor PackedToPaddedCuda(
    const at::Tensor inputs_packed,
    const at::Tensor first_idxs,
    const int64_t max_size) {
+  // Check inputs are on the same device
+  at::TensorArg inputs_packed_t{inputs_packed, "inputs_packed", 1},
+      first_idxs_t{first_idxs, "first_idxs", 2};
+  at::CheckedFrom c = "PackedToPaddedCuda";
+  at::checkAllSameGPU(c, {inputs_packed_t, first_idxs_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(inputs_packed.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t num_inputs = inputs_packed.size(0);
  const int64_t batch_size = first_idxs.size(0);

-  AT_ASSERTM(
+  TORCH_CHECK(
      inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
  const int64_t D = inputs_packed.size(1);
  at::Tensor inputs_padded =
      at::zeros({batch_size, max_size, D}, inputs_packed.options());

+  if (inputs_padded.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return inputs_padded;
+  }
+
  const int threads = 512;
  const int blocks = batch_size;
  if (D == 1) {
    AT_DISPATCH_FLOATING_TYPES(
        inputs_packed.scalar_type(), "packed_to_padded_d1_kernel", ([&] {
-          PackedToPaddedKernelD1<scalar_t><<<blocks, threads>>>(
+          PackedToPaddedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
              inputs_packed.data_ptr<scalar_t>(),
              first_idxs.data_ptr<int64_t>(),
              inputs_padded.data_ptr<scalar_t>(),

@@ -139,7 +156,7 @@ at::Tensor PackedToPaddedCuda(
  } else {
    AT_DISPATCH_FLOATING_TYPES(
        inputs_packed.scalar_type(), "packed_to_padded_kernel", ([&] {
-          PackedToPaddedKernel<scalar_t><<<blocks, threads>>>(
+          PackedToPaddedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
              inputs_packed.data_ptr<scalar_t>(),
              first_idxs.data_ptr<int64_t>(),
              inputs_padded.data_ptr<scalar_t>(),

@@ -150,6 +167,7 @@ at::Tensor PackedToPaddedCuda(
        }));
  }

  AT_CUDA_CHECK(cudaGetLastError());
  return inputs_padded;
}

@@ -157,11 +175,21 @@ at::Tensor PaddedToPackedCuda(
    const at::Tensor inputs_padded,
    const at::Tensor first_idxs,
    const int64_t num_inputs) {
+  // Check inputs are on the same device
+  at::TensorArg inputs_padded_t{inputs_padded, "inputs_padded", 1},
+      first_idxs_t{first_idxs, "first_idxs", 2};
+  at::CheckedFrom c = "PaddedToPackedCuda";
+  at::checkAllSameGPU(c, {inputs_padded_t, first_idxs_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(inputs_padded.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t batch_size = inputs_padded.size(0);
  const int64_t max_size = inputs_padded.size(1);

-  AT_ASSERTM(batch_size == first_idxs.size(0), "sizes mismatch");
-  AT_ASSERTM(
+  TORCH_CHECK(batch_size == first_idxs.size(0), "sizes mismatch");
+  TORCH_CHECK(
      inputs_padded.dim() == 3,
      "inputs_padded must be a 3-dimensional tensor");
  const int64_t D = inputs_padded.size(2);

@@ -169,13 +197,18 @@ at::Tensor PaddedToPackedCuda(
  at::Tensor inputs_packed =
      at::zeros({num_inputs, D}, inputs_padded.options());

+  if (inputs_packed.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return inputs_packed;
+  }
+
  const int threads = 512;
  const int blocks = batch_size;

  if (D == 1) {
    AT_DISPATCH_FLOATING_TYPES(
        inputs_padded.scalar_type(), "padded_to_packed_d1_kernel", ([&] {
-          PaddedToPackedKernelD1<scalar_t><<<blocks, threads>>>(
+          PaddedToPackedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
              inputs_padded.data_ptr<scalar_t>(),
              first_idxs.data_ptr<int64_t>(),
              inputs_packed.data_ptr<scalar_t>(),

@@ -186,7 +219,7 @@ at::Tensor PaddedToPackedCuda(
  } else {
    AT_DISPATCH_FLOATING_TYPES(
        inputs_padded.scalar_type(), "padded_to_packed_kernel", ([&] {
-          PaddedToPackedKernel<scalar_t><<<blocks, threads>>>(
+          PaddedToPackedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
              inputs_padded.data_ptr<scalar_t>(),
              first_idxs.data_ptr<int64_t>(),
              inputs_packed.data_ptr<scalar_t>(),

@@ -197,5 +230,6 @@ at::Tensor PaddedToPackedCuda(
        }));
  }

  AT_CUDA_CHECK(cudaGetLastError());
  return inputs_packed;
}

@@ -2,6 +2,7 @@
#pragma once
#include <torch/extension.h>
+#include "utils/pytorch3d_cutils.h"

// PackedToPadded
// Converts a packed tensor into a padded tensor, restoring the batch dimension.

@@ -74,6 +75,8 @@ at::Tensor PackedToPadded(
    const int64_t max_size) {
  if (inputs_packed.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(inputs_packed);
+    CHECK_CONTIGUOUS_CUDA(first_idxs);
    return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -89,6 +92,8 @@ at::Tensor PaddedToPacked(
    const int64_t num_inputs) {
  if (inputs_padded.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(inputs_padded);
+    CHECK_CONTIGUOUS_CUDA(first_idxs);
    return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <list>
#include <queue>

@@ -103,26 +105,45 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
    const at::Tensor& segms,
    const at::Tensor& segms_first_idx,
    const int64_t max_points) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      points_first_idx_t{points_first_idx, "points_first_idx", 2},
+      segms_t{segms, "segms", 3},
+      segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
+  at::CheckedFrom c = "PointEdgeDistanceForwardCuda";
+  at::checkAllSameGPU(
+      c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
+  at::checkAllSameType(c, {points_t, segms_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);
  const int64_t B = points_first_idx.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
-  AT_ASSERTM(segms_first_idx.size(0) == B);
+  TORCH_CHECK(segms_first_idx.size(0) == B);

  // clang-format off
  at::Tensor dists = at::zeros({P,}, points.options());
  at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
  // clang-format on

+  if (dists.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(dists, idxs);
+  }
+
  const int threads = 128;
  const dim3 blocks(max_points, B);
  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);

-  PointEdgeForwardKernel<<<blocks, threads, shared_size>>>(
+  PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
      points.data_ptr<float>(),
      points_first_idx.data_ptr<int64_t>(),
      segms.data_ptr<float>(),

@@ -132,7 +153,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
      B,
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(dists, idxs);
}

@@ -183,25 +204,42 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
    const at::Tensor& segms,
    const at::Tensor& idx_points,
    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3},
+      grad_dists_t{grad_dists, "grad_dists", 4};
+  at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
-  AT_ASSERTM(idx_points.size(0) == P);
-  AT_ASSERTM(grad_dists.size(0) == P);
+  TORCH_CHECK(idx_points.size(0) == P);
+  TORCH_CHECK(grad_dists.size(0) == P);

  // clang-format off
  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
  // clang-format on

+  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_points, grad_segms);
+  }
+
  const int blocks = 64;
  const int threads = 512;

-  PointEdgeBackwardKernel<<<blocks, threads>>>(
+  PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.data_ptr<float>(),
      segms.data_ptr<float>(),
      idx_points.data_ptr<int64_t>(),

@@ -210,6 +248,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
      grad_segms.data_ptr<float>(),
      P);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_segms);
}

@@ -308,26 +347,45 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
    const at::Tensor& segms,
    const at::Tensor& segms_first_idx,
    const int64_t max_segms) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      points_first_idx_t{points_first_idx, "points_first_idx", 2},
+      segms_t{segms, "segms", 3},
+      segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
+  at::CheckedFrom c = "EdgePointDistanceForwardCuda";
+  at::checkAllSameGPU(
+      c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
+  at::checkAllSameType(c, {points_t, segms_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);
  const int64_t B = points_first_idx.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
-  AT_ASSERTM(segms_first_idx.size(0) == B);
+  TORCH_CHECK(segms_first_idx.size(0) == B);

  // clang-format off
  at::Tensor dists = at::zeros({S,}, segms.options());
  at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
  // clang-format on

+  if (dists.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(dists, idxs);
+  }
+
  const int threads = 128;
  const dim3 blocks(max_segms, B);
  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);

-  EdgePointForwardKernel<<<blocks, threads, shared_size>>>(
+  EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
      points.data_ptr<float>(),
      points_first_idx.data_ptr<int64_t>(),
      segms.data_ptr<float>(),

@@ -337,7 +395,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
      B,
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(dists, idxs);
}

@@ -389,15 +447,27 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
    const at::Tensor& segms,
    const at::Tensor& idx_segms,
    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3},
+      grad_dists_t{grad_dists, "grad_dists", 4};
+  at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
-  AT_ASSERTM(idx_segms.size(0) == S);
-  AT_ASSERTM(grad_dists.size(0) == S);
+  TORCH_CHECK(idx_segms.size(0) == S);
+  TORCH_CHECK(grad_dists.size(0) == S);

  // clang-format off
  at::Tensor grad_points = at::zeros({P, 3}, points.options());

@@ -407,7 +477,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
  const int blocks = 64;
  const int threads = 512;

-  EdgePointBackwardKernel<<<blocks, threads>>>(
+  EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.data_ptr<float>(),
      segms.data_ptr<float>(),
      idx_segms.data_ptr<int64_t>(),

@@ -451,26 +521,42 @@ __global__ void PointEdgeArrayForwardKernel(
at::Tensor PointEdgeArrayDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
+  at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
+  at::checkAllSameGPU(c, {points_t, segms_t});
+  at::checkAllSameType(c, {points_t, segms_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");

  at::Tensor dists = at::zeros({P, S}, points.options());

+  if (dists.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return dists;
+  }
+
  const size_t blocks = 1024;
  const size_t threads = 64;

-  PointEdgeArrayForwardKernel<<<blocks, threads>>>(
+  PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
      points.data_ptr<float>(),
      segms.data_ptr<float>(),
      dists.data_ptr<float>(),
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return dists;
}

@@ -520,22 +606,38 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
+      grad_dists_t{grad_dists, "grad_dists", 3};
+  at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
-  AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
+  TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));

  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());

+  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_points, grad_segms);
+  }
+
  const size_t blocks = 1024;
  const size_t threads = 64;

-  PointEdgeArrayBackwardKernel<<<blocks, threads>>>(
+  PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.data_ptr<float>(),
      segms.data_ptr<float>(),
      grad_dists.data_ptr<float>(),

@@ -543,6 +645,6 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
      grad_segms.data_ptr<float>(),
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_segms);
}

@@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
+#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// * PointEdgeDistance *

@@ -53,6 +54,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
    const int64_t max_points) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(points_first_idx);
+    CHECK_CONTIGUOUS_CUDA(segms);
+    CHECK_CONTIGUOUS_CUDA(segms_first_idx);
    return PointEdgeDistanceForwardCuda(
        points, points_first_idx, segms, segms_first_idx, max_points);
#else

@@ -93,6 +98,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(segms);
+    CHECK_CONTIGUOUS_CUDA(idx_points);
+    CHECK_CONTIGUOUS_CUDA(grad_dists);
    return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -149,6 +158,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
    const int64_t max_segms) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(points_first_idx);
+    CHECK_CONTIGUOUS_CUDA(segms);
+    CHECK_CONTIGUOUS_CUDA(segms_first_idx);
    return EdgePointDistanceForwardCuda(
        points, points_first_idx, segms, segms_first_idx, max_segms);
#else

@@ -189,6 +202,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(segms);
+    CHECK_CONTIGUOUS_CUDA(idx_segms);
+    CHECK_CONTIGUOUS_CUDA(grad_dists);
    return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -220,7 +237,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
// will require for the forward pass 5.8G of memory to store dists.

#ifdef WITH_CUDA

torch::Tensor PointEdgeArrayDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms);

@@ -231,6 +247,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
    const torch::Tensor& segms) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(segms);
    return PointEdgeArrayDistanceForwardCuda(points, segms);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -265,6 +283,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(segms);
+    CHECK_CONTIGUOUS_CUDA(grad_dists);
    return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");

@@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <list>
#include <queue>

@@ -104,26 +106,45 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
    const at::Tensor& tris,
    const at::Tensor& tris_first_idx,
    const int64_t max_points) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      points_first_idx_t{points_first_idx, "points_first_idx", 2},
+      tris_t{tris, "tris", 3},
+      tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
+  at::CheckedFrom c = "PointFaceDistanceForwardCuda";
+  at::checkAllSameGPU(
+      c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
+  at::checkAllSameType(c, {points_t, tris_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t T = tris.size(0);
  const int64_t B = points_first_idx.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (tris.size(1) == 3) && (tris.size(2) == 3),
      "tris must be of shape Tx3x3");
-  AT_ASSERTM(tris_first_idx.size(0) == B);
+  TORCH_CHECK(tris_first_idx.size(0) == B);

  // clang-format off
  at::Tensor dists = at::zeros({P,}, points.options());
  at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
  // clang-format on

+  if (dists.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(dists, idxs);
+  }
+
  const int threads = 128;
  const dim3 blocks(max_points, B);
  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);

-  PointFaceForwardKernel<<<blocks, threads, shared_size>>>(
+  PointFaceForwardKernel<<<blocks, threads, shared_size, stream>>>(
      points.data_ptr<float>(),
      points_first_idx.data_ptr<int64_t>(),
      tris.data_ptr<float>(),

@@ -134,6 +155,7 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
      P,
      T);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(dists, idxs);
}

@@ -191,25 +213,42 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
    const at::Tensor& tris,
    const at::Tensor& idx_points,
    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      idx_points_t{idx_points, "idx_points", 2}, tris_t{tris, "tris", 3},
+      grad_dists_t{grad_dists, "grad_dists", 4};
+  at::CheckedFrom c = "PointFaceDistanceBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, idx_points_t, tris_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t T = tris.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (tris.size(1) == 3) && (tris.size(2) == 3),
      "tris must be of shape Tx3x3");
-  AT_ASSERTM(idx_points.size(0) == P);
-  AT_ASSERTM(grad_dists.size(0) == P);
+  TORCH_CHECK(idx_points.size(0) == P);
+  TORCH_CHECK(grad_dists.size(0) == P);

  // clang-format off
  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
  // clang-format on

+  if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_points, grad_tris);
+  }
+
  const int blocks = 64;
  const int threads = 512;

-  PointFaceBackwardKernel<<<blocks, threads>>>(
+  PointFaceBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.data_ptr<float>(),
      tris.data_ptr<float>(),
      idx_points.data_ptr<int64_t>(),

@@ -218,6 +257,7 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
      grad_tris.data_ptr<float>(),
      P);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_tris);
}

@@ -317,26 +357,45 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
    const at::Tensor& tris,
    const at::Tensor& tris_first_idx,
    const int64_t max_tris) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      points_first_idx_t{points_first_idx, "points_first_idx", 2},
+      tris_t{tris, "tris", 3},
+      tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
+  at::CheckedFrom c = "FacePointDistanceForwardCuda";
+  at::checkAllSameGPU(
+      c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
+  at::checkAllSameType(c, {points_t, tris_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t T = tris.size(0);
  const int64_t B = points_first_idx.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (tris.size(1) == 3) && (tris.size(2) == 3),
      "tris must be of shape Tx3x3");
-  AT_ASSERTM(tris_first_idx.size(0) == B);
+  TORCH_CHECK(tris_first_idx.size(0) == B);

  // clang-format off
  at::Tensor dists = at::zeros({T,}, tris.options());
  at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
  // clang-format on

+  if (dists.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(dists, idxs);
+  }
+
  const int threads = 128;
  const dim3 blocks(max_tris, B);
  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);

-  FacePointForwardKernel<<<blocks, threads, shared_size>>>(
+  FacePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
      points.data_ptr<float>(),
      points_first_idx.data_ptr<int64_t>(),
      tris.data_ptr<float>(),

@@ -347,6 +406,7 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
      P,
      T);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(dists, idxs);
}

@@ -405,25 +465,42 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
    const at::Tensor& tris,
    const at::Tensor& idx_tris,
    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      idx_tris_t{idx_tris, "idx_tris", 2}, tris_t{tris, "tris", 3},
+      grad_dists_t{grad_dists, "grad_dists", 4};
+  at::CheckedFrom c = "FacePointDistanceBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, idx_tris_t, tris_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
  const int64_t P = points.size(0);
  const int64_t T = tris.size(0);

-  AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
-  AT_ASSERTM(
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
      (tris.size(1) == 3) && (tris.size(2) == 3),
      "tris must be of shape Tx3x3");
-  AT_ASSERTM(idx_tris.size(0) == T);
-  AT_ASSERTM(grad_dists.size(0) == T);
+  TORCH_CHECK(idx_tris.size(0) == T);
+  TORCH_CHECK(grad_dists.size(0) == T);

  // clang-format off
  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
  // clang-format on

+  if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_points, grad_tris);
+  }
+
  const int blocks = 64;
  const int threads = 512;

-  FacePointBackwardKernel<<<blocks, threads>>>(
+  FacePointBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
idx_tris.data_ptr<int64_t>(),
|
||||
@ -432,6 +509,7 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
|
||||
grad_tris.data_ptr<float>(),
|
||||
T);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
||||
@ -468,26 +546,42 @@ __global__ void PointFaceArrayForwardKernel(
|
||||
at::Tensor PointFaceArrayDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& tris) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2};
|
||||
at::CheckedFrom c = "PointFaceArrayDistanceForwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, tris_t});
|
||||
at::checkAllSameType(c, {points_t, tris_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
|
||||
at::Tensor dists = at::zeros({P, T}, points.options());
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return dists;
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointFaceArrayForwardKernel<<<blocks, threads>>>(
|
||||
PointFaceArrayForwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
dists.data_ptr<float>(),
|
||||
P,
|
||||
T);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return dists;
|
||||
}
|
||||
|
||||
@ -546,22 +640,38 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& tris,
|
||||
const at::Tensor& grad_dists) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2},
|
||||
grad_dists_t{grad_dists, "grad_dists", 3};
|
||||
at::CheckedFrom c = "PointFaceArrayDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
|
||||
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
|
||||
|
||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
||||
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
|
||||
|
||||
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointFaceArrayBackwardKernel<<<blocks, threads>>>(
|
||||
PointFaceArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
@ -570,5 +680,6 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
|
||||
P,
|
||||
T);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// * PointFaceDistance *
@ -55,6 +56,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
const int64_t max_points) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(tris_first_idx);
return PointFaceDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_points);
#else
@ -95,6 +100,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(idx_points);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
@ -151,6 +160,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
const int64_t max_tris) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(points_first_idx);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(tris_first_idx);
return FacePointDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_tris);
#else
@ -191,6 +204,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(idx_tris);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
@ -233,6 +250,8 @@ torch::Tensor PointFaceArrayDistanceForward(
const torch::Tensor& tris) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
return PointFaceArrayDistanceForwardCuda(points, tris);
#else
AT_ERROR("Not compiled with GPU support.");
@ -254,7 +273,6 @@ torch::Tensor PointFaceArrayDistanceForward(
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
@ -267,6 +285,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(tris);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
|
@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <math.h>
#include <thrust/tuple.h>
@ -285,14 +287,14 @@ RasterizeMeshesNaiveCuda(
const int num_closest,
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) {
AT_ERROR(
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
}
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");

TORCH_CHECK(
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");

if (num_closest > kMaxPointsPerPixel) {
std::stringstream ss;
@ -300,6 +302,20 @@ RasterizeMeshesNaiveCuda(
AT_ERROR(ss.str());
}

// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_faces_packed_first_idx_t{
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
at::checkAllSameGPU(
c,
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int N = num_faces_per_mesh.size(0); // batch size.
const int H = image_size; // Assume square images.
const int W = image_size;
@ -313,10 +329,15 @@ RasterizeMeshesNaiveCuda(
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);

if (face_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}

const size_t blocks = 1024;
const size_t threads = 64;

RasterizeMeshesNaiveCudaKernel<<<blocks, threads>>>(
RasterizeMeshesNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
@ -332,6 +353,7 @@ RasterizeMeshesNaiveCuda(
pix_dists.contiguous().data_ptr<float>(),
bary.contiguous().data_ptr<float>());

AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}

@ -465,6 +487,22 @@ at::Tensor RasterizeMeshesBackwardCuda(
const at::Tensor& grad_bary, // (N, H, W, K, 3)
const at::Tensor& grad_dists, // (N, H, W, K)
const bool perspective_correct) {
// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
pix_to_face_t{pix_to_face, "pix_to_face", 2},
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
grad_bary_t{grad_bary, "grad_bary", 4},
grad_dists_t{grad_dists, "grad_dists", 5};
at::CheckedFrom c = "RasterizeMeshesBackwardCuda";
at::checkAllSameGPU(
c, {face_verts_t, pix_to_face_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
at::checkAllSameType(
c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@ -472,10 +510,16 @@ at::Tensor RasterizeMeshesBackwardCuda(
const int K = pix_to_face.size(3);

at::Tensor grad_face_verts = at::zeros({F, 3, 3}, face_verts.options());

if (grad_face_verts.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts;
}

const size_t blocks = 1024;
const size_t threads = 64;

RasterizeMeshesBackwardCudaKernel<<<blocks, threads>>>(
RasterizeMeshesBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
pix_to_face.contiguous().data_ptr<int64_t>(),
perspective_correct,
@ -488,6 +532,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
grad_dists.contiguous().data_ptr<float>(),
grad_face_verts.contiguous().data_ptr<float>());

AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts;
}

@ -626,10 +671,24 @@ at::Tensor RasterizeMeshesCoarseCuda(
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");

// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
mesh_to_face_first_idx_t{
mesh_to_face_first_idx, "mesh_to_face_first_idx", 2},
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
at::CheckedFrom c = "RasterizeMeshesCoarseCuda";
at::checkAllSameGPU(
c, {face_verts_t, mesh_to_face_first_idx_t, num_faces_per_mesh_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int W = image_size;
const int H = image_size;
const int F = face_verts.size(0);
@ -645,12 +704,18 @@ at::Tensor RasterizeMeshesCoarseCuda(
auto opts = face_verts.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);

if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}

const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size>>>(
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
face_verts.contiguous().data_ptr<float>(),
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
@ -664,6 +729,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
M,
faces_per_bin.contiguous().data_ptr<int32_t>(),
bin_faces.contiguous().data_ptr<int32_t>());

AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
}

@ -775,13 +842,22 @@ RasterizeMeshesFineCuda(
const int faces_per_pixel,
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (bin_faces.ndimension() != 4) {
AT_ERROR("bin_faces must have 4 dimensions");
}
TORCH_CHECK(
face_verts.ndimension() == 3 && face_verts.size(1) == 3 &&
face_verts.size(2) == 3,
"face_verts must have dimensions (num_faces, 3, 3)");
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");

// Check inputs are on the same device
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
bin_faces_t{bin_faces, "bin_faces", 2};
at::CheckedFrom c = "RasterizeMeshesFineCuda";
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int N = bin_faces.size(0);
const int B = bin_faces.size(1);
const int M = bin_faces.size(3);
@ -790,7 +866,7 @@ RasterizeMeshesFineCuda(
const int W = image_size;

if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8");
AT_ERROR("Must have num_closest <= 150");
}
auto long_opts = face_verts.options().dtype(at::kLong);
auto float_opts = face_verts.options().dtype(at::kFloat);
@ -800,10 +876,15 @@ RasterizeMeshesFineCuda(
at::Tensor pix_dists = at::full({N, H, W, K}, -1, float_opts);
at::Tensor bary = at::full({N, H, W, K, 3}, -1, float_opts);

if (face_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}

const size_t blocks = 1024;
const size_t threads = 64;

RasterizeMeshesFineCudaKernel<<<blocks, threads>>>(
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
face_verts.contiguous().data_ptr<float>(),
bin_faces.contiguous().data_ptr<int32_t>(),
blur_radius,
|
@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// * FORWARD PASS *
@ -95,6 +96,9 @@ RasterizeMeshesNaive(
// TODO: Better type checking.
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
return RasterizeMeshesNaiveCuda(
face_verts,
mesh_to_face_first_idx,
@ -175,6 +179,11 @@ torch::Tensor RasterizeMeshesBackward(
const bool perspective_correct) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(pix_to_face);
CHECK_CONTIGUOUS_CUDA(grad_zbuf);
CHECK_CONTIGUOUS_CUDA(grad_bary);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return RasterizeMeshesBackwardCuda(
face_verts,
pix_to_face,
@ -251,6 +260,9 @@ torch::Tensor RasterizeMeshesCoarse(
const int max_faces_per_bin) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
return RasterizeMeshesCoarseCuda(
face_verts,
mesh_to_face_first_idx,
@ -347,6 +359,8 @@ RasterizeMeshesFine(
const bool cull_backfaces) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(face_verts);
CHECK_CONTIGUOUS_CUDA(bin_faces);
return RasterizeMeshesFineCuda(
face_verts,
bin_faces,
|
@ -1,6 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <cstdio>
#include <sstream>
@ -145,13 +147,25 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
const int image_size,
const float radius,
const int points_per_pixel) {
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) {
AT_ERROR(
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
}
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
cloud_to_packed_first_idx_t{
cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
at::CheckedFrom c = "RasterizePointsNaiveCuda";
at::checkAllSameGPU(
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(
points.ndimension() == 2 && points.size(1) == 3,
"points must have dimensions (num_points, 3)");
TORCH_CHECK(
num_points_per_cloud.size(0) == cloud_to_packed_first_idx.size(0),
"num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");

const int N = num_points_per_cloud.size(0); // batch size.
const int S = image_size;
@ -169,9 +183,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);

if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}

const size_t blocks = 1024;
const size_t threads = 64;
RasterizePointsNaiveCudaKernel<<<blocks, threads>>>(
RasterizePointsNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@ -182,6 +201,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
point_idxs.contiguous().data_ptr<int32_t>(),
zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>());

AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}

@ -323,14 +344,28 @@ at::Tensor RasterizePointsCoarseCuda(
const float radius,
const int bin_size,
const int max_points_per_bin) {
TORCH_CHECK(
points.ndimension() == 2 && points.size(1) == 3,
"points must have dimensions (num_points, 3)");

// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
cloud_to_packed_first_idx_t{
cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
at::CheckedFrom c = "RasterizePointsCoarseCuda";
at::checkAllSameGPU(
c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int P = points.size(0);
const int N = num_points_per_cloud.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
const int M = max_points_per_bin;

if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
if (num_bins >= 22) {
// Make sure we do not use too much shared memory.
std::stringstream ss;
@ -340,12 +375,18 @@ at::Tensor RasterizePointsCoarseCuda(
auto opts = points.options().dtype(at::kInt);
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);

if (bin_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return bin_points;
}

const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;

RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@ -358,6 +399,8 @@ at::Tensor RasterizePointsCoarseCuda(
M,
points_per_bin.contiguous().data_ptr<int32_t>(),
bin_points.contiguous().data_ptr<int32_t>());

AT_CUDA_CHECK(cudaGetLastError());
return bin_points;
}

@ -448,13 +491,23 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
const float radius,
const int bin_size,
const int points_per_pixel) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1},
bin_points_t{bin_points, "bin_points", 2};
at::CheckedFrom c = "RasterizePointsFineCuda";
at::checkAllSameGPU(c, {points_t, bin_points_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int N = bin_points.size(0);
const int B = bin_points.size(1); // num_bins
const int M = bin_points.size(3);
const int S = image_size;
const int K = points_per_pixel;
if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 8");
AT_ERROR("Must have num_closest <= 150");
}
auto int_opts = points.options().dtype(at::kInt);
auto float_opts = points.options().dtype(at::kFloat);
@ -462,9 +515,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);

if (point_idxs.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}

const size_t blocks = 1024;
const size_t threads = 64;
RasterizePointsFineCudaKernel<<<blocks, threads>>>(
RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
bin_points.contiguous().data_ptr<int32_t>(),
radius,
@ -478,6 +536,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>());

AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(point_idxs, zbuf, pix_dists);
}

@ -537,6 +596,19 @@ at::Tensor RasterizePointsBackwardCuda(
const at::Tensor& idxs, // (N, H, W, K)
const at::Tensor& grad_zbuf, // (N, H, W, K)
const at::Tensor& grad_dists) { // (N, H, W, K)

// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, idxs_t{idxs, "idxs", 2},
grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
grad_dists_t{grad_dists, "grad_dists", 4};
at::CheckedFrom c = "RasterizePointsBackwardCuda";
at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t});
at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int P = points.size(0);
const int N = idxs.size(0);
const int H = idxs.size(1);
@ -544,10 +616,16 @@ at::Tensor RasterizePointsBackwardCuda(
const int K = idxs.size(3);

at::Tensor grad_points = at::zeros({P, 3}, points.options());

if (grad_points.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_points;
}

const size_t blocks = 1024;
const size_t threads = 64;

RasterizePointsBackwardCudaKernel<<<blocks, threads>>>(
RasterizePointsBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<float>(),
idxs.contiguous().data_ptr<int32_t>(),
N,
@ -559,5 +637,6 @@ at::Tensor RasterizePointsBackwardCuda(
grad_dists.contiguous().data_ptr<float>(),
grad_points.contiguous().data_ptr<float>());

AT_CUDA_CHECK(cudaGetLastError());
return grad_points;
}
|
@ -4,6 +4,7 @@
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// * NAIVE RASTERIZATION *
@ -66,6 +67,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
return RasterizePointsNaiveCuda(
points,
cloud_to_packed_first_idx,
@ -140,6 +144,9 @@ torch::Tensor RasterizePointsCoarse(
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
num_points_per_cloud.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
return RasterizePointsCoarseCuda(
points,
cloud_to_packed_first_idx,
@ -208,6 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
const int points_per_pixel) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(bin_points);
return RasterizePointsFineCuda(
points, bin_points, image_size, radius, bin_size, points_per_pixel);
#else
@ -257,6 +266,10 @@ torch::Tensor RasterizePointsBackward(
const torch::Tensor& grad_dists) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(points);
CHECK_CONTIGUOUS_CUDA(idxs);
CHECK_CONTIGUOUS_CUDA(grad_zbuf);
CHECK_CONTIGUOUS_CUDA(grad_dists);
return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists);
#else
AT_ERROR("Not compiled with GPU support");
|
@ -3,9 +3,9 @@
#pragma once
#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x "must be a CUDA tensor.")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x "must be contiguous.")
TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
@ -8,11 +8,21 @@ from test_chamfer import TestChamfer
|
||||
|
||||
|
||||
def bm_chamfer() -> None:
|
||||
kwargs_list_naive = [
|
||||
{"batch_size": 1, "P1": 32, "P2": 64, "return_normals": False},
|
||||
{"batch_size": 1, "P1": 32, "P2": 64, "return_normals": True},
|
||||
{"batch_size": 32, "P1": 32, "P2": 64, "return_normals": False},
|
||||
]
|
||||
devices = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
devices.append("cuda:0")
|
||||
|
||||
kwargs_list_naive = []
|
||||
batch_size = [1, 32]
|
||||
return_normals = [True, False]
|
||||
test_cases = product(batch_size, return_normals, devices)
|
||||
|
||||
for case in test_cases:
|
||||
b, n, d = case
|
||||
kwargs_list_naive.append(
|
||||
{"batch_size": b, "P1": 32, "P2": 64, "return_normals": n, "device": d}
|
||||
)
|
||||
|
||||
benchmark(
|
||||
TestChamfer.chamfer_naive_with_init,
|
||||
"CHAMFER_NAIVE",
|
||||
@ -21,6 +31,7 @@ def bm_chamfer() -> None:
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
kwargs_list = []
|
||||
batch_size = [1, 32]
|
||||
P1 = [32, 1000, 10000]
|
||||
@ -38,6 +49,7 @@ def bm_chamfer() -> None:
|
||||
"P2": p2,
|
||||
"return_normals": n,
|
||||
"homogeneous": h,
|
||||
"device": device,
|
||||
}
|
||||
)
|
||||
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
|
||||
|
@ -20,6 +20,18 @@ def load_rgb_image(filename: str, data_dir: Union[str, Path]):
|
||||
TensorOrArray = Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
def get_random_cuda_device() -> str:
|
||||
"""
|
||||
Function to get a random GPU device from the
|
||||
available devices. This is useful for testing
|
||||
that custom cuda kernels can support inputs on
|
||||
any device without having to set the device explicitly.
|
||||
"""
|
||||
num_devices = torch.cuda.device_count()
|
||||
rand_device_id = torch.randint(high=num_devices, size=(1,)).item()
|
||||
return "cuda:%d" % rand_device_id
|
||||
|
||||
|
||||
class TestCaseMixin(unittest.TestCase):
|
||||
def assertSeparate(self, tensor1, tensor2) -> None:
|
||||
"""
|
||||
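As an aside (not part of this commit's diff), the helper above is meant to be called at the top of a GPU test so inputs land on an arbitrary device rather than always "cuda:0". A minimal sketch, with a hypothetical test name and tensor shape:

def test_custom_op_on_any_gpu(self):  # hypothetical example, for illustration only
    device = get_random_cuda_device()  # e.g. "cuda:1" on a multi-GPU machine
    points = torch.rand((8, 3), device=device)
    # The custom kernel under test should accept inputs on this device,
    # not assume the default GPU.
    self.assertEqual(points.device.type, "cuda")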
|
@ -6,7 +6,7 @@ from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures.pointclouds import Pointclouds

@ -81,7 +81,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)

@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2):
def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@ -97,7 +97,6 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
x_normals = p1.normals_padded()
y_normals = p2.normals_padded()

device = torch.device("cuda:0")
return_normals = x_normals is not None and y_normals is not None

# Initialize all distances to + inf
@ -163,7 +162,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, P1, D = x.shape
P2 = y.size(1)
device = torch.device("cuda:0")
device = x.device
return_normals = x_normals is not None and y_normals is not None
dist = torch.zeros((N, P1, P2), dtype=torch.float32, device=device)

@ -203,7 +202,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
This tests only uses homogeneous pointclouds.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@ -237,7 +236,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
which supports heterogeneous pointcloud objects.
"""
N, max_P1, max_P2 = 3, 70, 70
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
@ -256,7 +255,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2
points_normals.cloud1, points_normals.cloud2, device=device
)

# Mean reduction point loss.
@ -299,7 +298,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
device = get_random_cuda_device()

reductions = [
("sum", "sum"),
@ -359,7 +358,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
def test_chamfer_pointcloud_object_nonormals(self):
N = 5
P1, P2 = 100, 100
device = "cuda:0"
device = get_random_cuda_device()

reductions = [
("sum", "sum"),
@ -415,7 +414,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for point_reduction = "mean" and batch_reduction = None.
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@ -464,7 +463,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for point_reduction = "sum" and batch_reduction = None.
"""
N, P1, P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@ -579,7 +578,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
point_reduction in ["mean", "sum"].
"""
N, max_P1, max_P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()

points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
@ -681,7 +680,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
device = torch.device("cuda:0")
device = get_random_cuda_device()
p1 = torch.rand(
(N, P1, 3), dtype=torch.float32, device=device, requires_grad=True
)
@ -716,7 +715,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

def test_incorrect_inputs(self):
N, P1, P2 = 7, 10, 18
device = "cuda:0"
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
@ -740,11 +739,16 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

@staticmethod
def chamfer_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool, homogeneous: bool
batch_size: int,
P1: int,
P2: int,
return_normals: bool,
homogeneous: bool,
device="cpu",
):
p1, p2, p1_normals, p2_normals, weights, l1, l2 = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
l1 = points_normals.p1_lengths
l2 = points_normals.p2_lengths
if homogeneous:
# Set lengths to None so in Chamfer it assumes
# there is no padding.
@ -754,13 +758,13 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

def loss():
loss, loss_normals = chamfer_distance(
p1,
p2,
points_normals.p1,
points_normals.p2,
x_lengths=l1,
y_lengths=l2,
x_normals=p1_normals,
y_normals=p2_normals,
weights=weights,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
weights=points_normals.weights,
)
torch.cuda.synchronize()

@ -768,16 +772,17 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):

@staticmethod
def chamfer_naive_with_init(
batch_size: int, P1: int, P2: int, return_normals: bool
batch_size: int, P1: int, P2: int, return_normals: bool, device="cpu"
):
p1, p2, p1_normals, p2_normals, weights, _, _ = TestChamfer.init_pointclouds(
batch_size, P1, P2
)
points_normals = TestChamfer.init_pointclouds(batch_size, P1, P2, device=device)
torch.cuda.synchronize()

def loss():
loss, loss_normals = TestChamfer.chamfer_distance_naive(
p1, p2, x_normals=p1_normals, y_normals=p2_normals
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
)
torch.cuda.synchronize()
|
@ -3,6 +3,7 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.renderer.compositing import (
|
||||
alpha_composite,
|
||||
norm_weighted_sum,
|
||||
@ -10,7 +11,7 @@ from pytorch3d.renderer.compositing import (
|
||||
)
|
||||
|
||||
|
||||
class TestAccumulatePoints(unittest.TestCase):
|
||||
class TestAccumulatePoints(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
|
||||
@staticmethod
|
||||
@ -120,7 +121,7 @@ class TestAccumulatePoints(unittest.TestCase):
|
||||
self._simple_wsumnorm(norm_weighted_sum, device)
|
||||
|
||||
def test_cuda(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._simple_alphacomposite(alpha_composite, device)
|
||||
self._simple_wsum(weighted_sum, device)
|
||||
self._simple_wsumnorm(norm_weighted_sum, device)
|
||||
@ -142,7 +143,7 @@ class TestAccumulatePoints(unittest.TestCase):
|
||||
C = 3
|
||||
P = 32
|
||||
|
||||
for d in ["cpu", "cuda"]:
|
||||
for d in ["cpu", get_random_cuda_device()]:
|
||||
# TODO(gkioxari) add torch.float64 to types after double precision
|
||||
# support is added to atomicAdd
|
||||
for t in [torch.float32]:
|
||||
@ -181,7 +182,7 @@ class TestAccumulatePoints(unittest.TestCase):
|
||||
res1 = fn1(*args1)
|
||||
res2 = fn2(*args2)
|
||||
|
||||
self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))
|
||||
self.assertClose(res1.cpu(), res2.cpu(), atol=1e-6)
|
||||
|
||||
if not compare_grads:
|
||||
return
|
||||
@ -200,7 +201,7 @@ class TestAccumulatePoints(unittest.TestCase):
|
||||
grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]
|
||||
|
||||
for i in range(0, len(grads1)):
|
||||
self.assertTrue(torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6))
|
||||
self.assertClose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)
|
||||
|
||||
def _simple_wsum(self, accum_func, device):
|
||||
# Initialise variables
|
||||
@ -273,7 +274,7 @@ class TestAccumulatePoints(unittest.TestCase):
|
||||
]
|
||||
).to(device)
|
||||
|
||||
self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
|
||||
self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
|
||||
|
||||
def _simple_wsumnorm(self, accum_func, device):
|
||||
# Initialise variables
|
||||
@ -346,7 +347,7 @@ class TestAccumulatePoints(unittest.TestCase):
|
||||
]
|
||||
).to(device)
|
||||
|
||||
self.assertTrue(torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3))
|
||||
self.assertClose(result.cpu(), true_result.cpu(), rtol=1e-3)
|
||||
|
||||
def _simple_alphacomposite(self, accum_func, device):
|
||||
# Initialise variables
|
||||
|
@ -33,7 +33,9 @@ class TestCubify(unittest.TestCase):
|
||||
|
||||
# 1st-check
|
||||
verts, faces = meshes.get_mesh_verts_faces(0)
|
||||
self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
|
||||
self.assertTrue(
|
||||
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
verts,
|
||||
@ -78,7 +80,9 @@ class TestCubify(unittest.TestCase):
|
||||
)
|
||||
# 2nd-check
|
||||
verts, faces = meshes.get_mesh_verts_faces(1)
|
||||
self.assertTrue(torch.allclose(faces.max(), torch.tensor([verts.size(0) - 1])))
|
||||
self.assertTrue(
|
||||
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
verts,
|
||||
|
@ -4,7 +4,7 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops import mesh_face_areas_normals
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
@ -94,13 +94,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
self._test_face_areas_normals_helper("cpu")
|
||||
|
||||
def test_face_areas_normals_cuda(self):
|
||||
self._test_face_areas_normals_helper("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_face_areas_normals_helper(device)
|
||||
|
||||
def test_nonfloats_cpu(self):
|
||||
self._test_face_areas_normals_helper("cpu", dtype=torch.double)
|
||||
|
||||
def test_nonfloats_cuda(self):
|
||||
self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
|
||||
device = get_random_cuda_device()
|
||||
self._test_face_areas_normals_helper(device, dtype=torch.double)
|
||||
|
||||
@staticmethod
|
||||
def face_areas_normals_with_init(
|
||||
|
@ -4,7 +4,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
@ -14,7 +14,7 @@ from pytorch3d.utils import ico_sphere
|
||||
class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
def test_undirected(self):
|
||||
dtype = torch.float32
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
verts = torch.tensor(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
|
||||
)
|
||||
@ -97,7 +97,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(y, expected_y)
|
||||
|
||||
def test_backward(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
mesh = ico_sphere()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
@ -118,7 +118,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")
|
||||
|
||||
def test_cpu_cuda_tensor_error(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
verts = torch.tensor(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
|
||||
)
|
||||
@ -134,7 +134,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
|
||||
Check that gather_scatter cuda version throws an error if cpu tensors
|
||||
are given as input.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
mesh = ico_sphere()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
|
@ -4,7 +4,7 @@ import unittest
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
|
||||
|
||||
|
||||
@ -89,7 +89,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||
self._knn_vs_python_square_helper(device)
|
||||
|
||||
def test_knn_vs_python_square_cuda(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._knn_vs_python_square_helper(device)
|
||||
|
||||
def _knn_vs_python_ragged_helper(self, device):
|
||||
@ -133,11 +133,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||
self._knn_vs_python_ragged_helper(device)
|
||||
|
||||
def test_knn_vs_python_ragged_cuda(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._knn_vs_python_ragged_helper(device)
|
||||
|
||||
def test_knn_gather(self):
|
||||
device = torch.device("cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
N, P1, P2, K, D = 4, 16, 12, 8, 3
|
||||
x = torch.rand((N, P1, D), device=device)
|
||||
y = torch.rand((N, P2, D), device=device)
|
||||
|
@ -3,7 +3,7 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops import packed_to_padded, padded_to_packed
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
@ -126,13 +126,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
|
||||
self._test_packed_to_padded_helper(16, "cpu")
|
||||
|
||||
def test_packed_to_padded_flat_cuda(self):
|
||||
self._test_packed_to_padded_helper(0, "cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_packed_to_padded_helper(0, device)
|
||||
|
||||
def test_packed_to_padded_D1_cuda(self):
|
||||
self._test_packed_to_padded_helper(1, "cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_packed_to_padded_helper(1, device)
|
||||
|
||||
def test_packed_to_padded_D16_cuda(self):
|
||||
self._test_packed_to_padded_helper(16, "cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_packed_to_padded_helper(16, device)
|
||||
|
||||
def _test_padded_to_packed_helper(self, D, device):
|
||||
"""
|
||||
@ -191,13 +194,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
|
||||
self._test_padded_to_packed_helper(16, "cpu")
|
||||
|
||||
def test_padded_to_packed_flat_cuda(self):
|
||||
self._test_padded_to_packed_helper(0, "cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_padded_to_packed_helper(0, device)
|
||||
|
||||
def test_padded_to_packed_D1_cuda(self):
|
||||
self._test_padded_to_packed_helper(1, "cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_padded_to_packed_helper(1, device)
|
||||
|
||||
def test_padded_to_packed_D16_cuda(self):
|
||||
self._test_padded_to_packed_helper(16, "cuda:0")
|
||||
device = get_random_cuda_device()
|
||||
self._test_padded_to_packed_helper(16, device)
|
||||
|
||||
def test_invalid_inputs_shapes(self, device="cuda:0"):
|
||||
with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):
|
||||
|
@ -4,7 +4,7 @@ import unittest

import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
@ -203,7 +203,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        & PointEdgeArrayDistanceBackward
        """
        P, E = 16, 32
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        points = torch.rand((P, 3), dtype=torch.float32, device=device)
        edges = torch.rand((E, 2, 3), dtype=torch.float32, device=device)

@ -246,9 +246,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for PointEdgeDistanceForward
        & PointEdgeDistanceBackward
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@ -327,9 +327,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for EdgePointDistanceForward
        & EdgePointDistanceBackward
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@ -409,9 +409,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        """
        Test point_mesh_edge_distance from pytorch3d.loss
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # clone and detach for another backward pass through the op
        verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@ -480,7 +480,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        & PointFaceArrayDistanceBackward
        """
        P, T = 16, 32
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        points = torch.rand((P, 3), dtype=torch.float32, device=device)
        tris = torch.rand((T, 3, 3), dtype=torch.float32, device=device)

@ -525,9 +525,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for PointFaceDistanceForward
        & PointFaceDistanceBackward
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@ -608,9 +608,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        Test CUDA implementation for FacePointDistanceForward
        & FacePointDistanceBackward
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # make points packed a leaf node
        points_packed = pcls.points_packed().detach().clone()  # (P, 3)
@ -690,9 +690,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
        """
        Test point_mesh_face_distance from pytorch3d.loss
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        N, V, F, P = 4, 32, 16, 24
        meshes, pcls = self.init_meshes_clouds(N, V, F, P)
        meshes, pcls = self.init_meshes_clouds(N, V, F, P, device=device)

        # clone and detach for another backward pass through the op
        verts_op = [verts.clone().detach() for verts in meshes.verts_list()]
@ -751,7 +751,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
    @staticmethod
    def point_mesh_edge(N: int, V: int, F: int, P: int, device: str):
        device = torch.device(device)
        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
            N, V, F, P, device=device
        )
        torch.cuda.synchronize()

        def loss():
@ -763,7 +765,9 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
    @staticmethod
    def point_mesh_face(N: int, V: int, F: int, P: int, device: str):
        device = torch.device(device)
        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(N, V, F, P)
        meshes, pcls = TestPointMeshDistance.init_meshes_clouds(
            N, V, F, P, device=device
        )
        torch.cuda.synchronize()

        def loss():
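Note: besides swapping the device string, the hunks above also pass device=device into init_meshes_clouds so the test meshes and point clouds are built directly on the selected GPU. A short illustrative sketch of the same pattern using only public PyTorch3D structures (sizes and names are assumptions, not taken from the test file):

import torch
from pytorch3d.loss import point_mesh_edge_distance
from pytorch3d.structures import Meshes, Pointclouds

device = torch.device("cuda:0")  # the tests obtain this from get_random_cuda_device()
verts = torch.rand((32, 3), dtype=torch.float32, device=device)
faces = torch.randint(32, size=(16, 3), dtype=torch.int64, device=device)
points = torch.rand((24, 3), dtype=torch.float32, device=device)

meshes = Meshes(verts=[verts], faces=[faces])
pcls = Pointclouds(points=[points])
# Every tensor the op sees now lives on `device`, so the kernel-side
# same-device checks pass and the kernel is launched on that GPU.
loss = point_mesh_edge_distance(meshes, pcls)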
@ -4,7 +4,7 @@ import functools
import unittest

import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import (
    rasterize_meshes,
@ -32,7 +32,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

    def test_simple_cuda_naive(self):
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
        self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
        self._test_behind_camera(rasterize_meshes, device, bin_size=0)
@ -40,7 +40,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        self._test_back_face_culling(rasterize_meshes, device, bin_size=0)

    def test_simple_cuda_binned(self):
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        self._simple_triangle_raster(rasterize_meshes, device, bin_size=5)
        self._simple_blurry_raster(rasterize_meshes, device, bin_size=5)
        self._test_behind_camera(rasterize_meshes, device, bin_size=5)
@ -54,7 +54,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        blur_radius = 0.1 ** 2
        faces_per_pixel = 3

        for d in ["cpu", "cuda"]:
        for d in ["cpu", get_random_cuda_device()]:
            device = torch.device(d)
            compare_grads = True
            # Mesh with a single face.
@ -164,7 +164,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        verts1.requires_grad = True
        meshes_cpu = Meshes(verts=[verts1], faces=[faces1])

        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        meshes_cuda = ico_sphere(0, device)
        verts2, faces2 = meshes_cuda.get_mesh_verts_faces(0)
        verts2.requires_grad = True
@ -186,7 +186,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        return self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
        return self._test_coarse_rasterize(torch.device("cuda:0"))
        return self._test_coarse_rasterize(get_random_cuda_device())

    def test_cpp_vs_cuda_naive_vs_cuda_binned(self):
        # Make sure that the backward pass runs for all pathways
@ -221,7 +221,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        grad1 = verts.grad.data.cpu().clone()

        # Option II: CUDA, naive
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        meshes = ico_sphere(0, device)
        verts, faces = meshes.get_mesh_verts_faces(0)
        verts.requires_grad = True
@ -229,9 +229,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):

        args = (meshes, image_size, radius, faces_per_pixel, 0, 0)
        idx2, zbuf2, bary2, dist2 = rasterize_meshes(*args)
        grad_zbuf = grad_zbuf.cuda()
        grad_dist = grad_dist.cuda()
        grad_bary = grad_bary.cuda()
        grad_zbuf = grad_zbuf.to(device)
        grad_dist = grad_dist.to(device)
        grad_bary = grad_bary.to(device)
        loss = (
            (zbuf2 * grad_zbuf).sum()
            + (dist2 * grad_dist).sum()
@ -244,7 +244,6 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        grad2 = verts.grad.data.cpu().clone()

        # Option III: CUDA, binned
        device = torch.device("cuda:0")
        meshes = ico_sphere(0, device)
        verts, faces = meshes.get_mesh_verts_faces(0)
        verts.requires_grad = True
@ -302,7 +301,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
            bin_size,
            max_faces_per_bin,
        )
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        meshes = meshes.clone().to(device)

        faces = meshes.faces_packed()
@ -356,8 +355,9 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        verts1, faces1 = meshes.get_mesh_verts_faces(0)
        verts1.requires_grad = True
        meshes1 = Meshes(verts=[verts1], faces=[faces1])
        verts2 = verts1.detach().cuda().requires_grad_(True)
        faces2 = faces1.detach().clone().cuda()
        device = get_random_cuda_device()
        verts2 = verts1.detach().to(device).requires_grad_(True)
        faces2 = faces1.detach().clone().to(device)
        meshes2 = Meshes(verts=[verts2], faces=[faces2])

        kwargs = {"image_size": 64, "perspective_correct": True}
@ -367,7 +367,8 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)

    def test_cuda_naive_vs_binned_perspective_correct(self):
        meshes = ico_sphere(2, device=torch.device("cuda"))
        device = get_random_cuda_device()
        meshes = ico_sphere(2, device=device)
        verts1, faces1 = meshes.get_mesh_verts_faces(0)
        verts1.requires_grad = True
        meshes1 = Meshes(verts=[verts1], faces=[faces1])
@ -1029,7 +1030,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
        max_faces_per_bin: int,
    ):

        meshes = ico_sphere(ico_level, torch.device("cuda:0"))
        meshes = ico_sphere(ico_level, get_random_cuda_device())
        meshes_batch = meshes.extend(num_meshes)
        torch.cuda.synchronize()
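Note: a recurring change in this file replaces .cuda() with .to(device). Tensor.cuda() with no argument moves data to the current default CUDA device, so with a randomly chosen test device the gradient tensors could land on a different GPU than the rasterizer outputs. A small illustrative snippet (the two-GPU device ids are assumptions):

import torch

device = torch.device("cuda:1")  # pretend this came from get_random_cuda_device()

zbuf = torch.rand(4, device=device)   # op output on the selected device
grad_zbuf = torch.rand(4)             # gradient tensor created on the CPU

g_default = grad_zbuf.cuda()          # lands on the current default device (cuda:0)
g_matched = grad_zbuf.to(device)      # lands on cuda:1, matching `zbuf`

# (zbuf * g_default).sum() would raise a device-mismatch error on a
# multi-GPU machine; (zbuf * g_matched).sum() works regardless of the device.
loss = (zbuf * g_matched).sum()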
@ -5,7 +5,7 @@ import unittest

import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
    rasterize_points,
@ -25,7 +25,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        self._simple_test_case(rasterize_points, device)

    def test_naive_simple_cuda(self):
        device = torch.device("cuda")
        device = get_random_cuda_device()
        self._simple_test_case(rasterize_points, device, bin_size=0)

    def test_python_behind_camera(self):
@ -37,7 +37,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        self._test_behind_camera(rasterize_points, torch.device("cpu"))

    def test_cuda_behind_camera(self):
        self._test_behind_camera(rasterize_points, torch.device("cuda"), bin_size=0)
        device = get_random_cuda_device()
        self._test_behind_camera(rasterize_points, device, bin_size=0)

    def test_cpp_vs_naive_vs_binned(self):
        # Make sure that the backward pass runs for all pathways
@ -373,7 +374,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        return self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
        return self._test_coarse_rasterize(torch.device("cuda"))
        device = get_random_cuda_device()
        return self._test_coarse_rasterize(device)

    def test_compare_coarse_cpu_vs_cuda(self):
        torch.manual_seed(231)
@ -405,7 +407,8 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        )
        bp_cpu = _C._rasterize_points_coarse(*args)

        pointclouds_cuda = pointclouds.to("cuda:0")
        device = get_random_cuda_device()
        pointclouds_cuda = pointclouds.to(device)
        points_packed = pointclouds_cuda.points_packed()
        cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
        num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
@ -5,7 +5,7 @@ import unittest
from pathlib import Path

import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
@ -42,7 +42,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
        Check sample_points_from_meshes raises an exception if all meshes are
        invalid.
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()
        verts1 = torch.tensor([], dtype=torch.float32, device=device)
        faces1 = torch.tensor([], dtype=torch.int64, device=device)
        meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1])
@ -56,7 +56,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
        For an ico_sphere, the sampled vertices should lie on a unit sphere.
        For an empty mesh, the samples and normals should be 0.
        """
        device = torch.device("cuda:0")
        device = get_random_cuda_device()

        # Unit simplex.
        verts_pyramid = torch.tensor(