From 9ad98c87c314877541187724a620c81332339a87 Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Fri, 1 Oct 2021 11:57:07 -0700
Subject: [PATCH] Cuda function for points2vols

Summary: Added CUDA implementation to match the new, still unused, C++
function for the core of points2vols.

Reviewed By: nikhilaravi

Differential Revision: D29548608

fbshipit-source-id: 16ebb61787fcb4c70461f9215a86ad5f97aecb4e
---
 .../points_to_volumes/points_to_volumes.cu | 347 ++++++++++++++++++
 .../points_to_volumes/points_to_volumes.h  |  64 +++-
 tests/test_points_to_volumes.py            |  24 ++
 3 files changed, 433 insertions(+), 2 deletions(-)
 create mode 100644 pytorch3d/csrc/points_to_volumes/points_to_volumes.cu

diff --git a/pytorch3d/csrc/points_to_volumes/points_to_volumes.cu b/pytorch3d/csrc/points_to_volumes/points_to_volumes.cu
new file mode 100644
index 00000000..cde26c33
--- /dev/null
+++ b/pytorch3d/csrc/points_to_volumes/points_to_volumes.cu
@@ -0,0 +1,347 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cmath>
+
+using torch::PackedTensorAccessor64;
+using torch::RestrictPtrTraits;
+
+// A chunk of work is blocksize-many points.
+// There are N clouds in the batch, and P points in each cloud.
+// The number of potential chunks to do per cloud is (1+(P-1)/blocksize),
+// which we call chunks_per_cloud.
+// These (N*chunks_per_cloud) chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
+// In chunk i, we work on cloud (i/chunks_per_cloud) on points starting from
+// blocksize*(i%chunks_per_cloud).
+
+// Explanation of the calculation is in the cpp file.
+
+// EightDirections(t) runs t(a,b,c) for every combination of boolean a, b, c.
+template <class T>
+static __device__ void EightDirections(T&& t) {
+  t(false, false, false);
+  t(false, false, true);
+  t(false, true, false);
+  t(false, true, true);
+  t(true, false, false);
+  t(true, false, true);
+  t(true, true, false);
+  t(true, true, true);
+}
+
+__global__ void PointsToVolumesForwardKernel(
+    const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_3d,
+    const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_features,
+    PackedTensorAccessor64<float, 5, RestrictPtrTraits> volume_densities,
+    PackedTensorAccessor64<float, 5, RestrictPtrTraits> volume_features,
+    PackedTensorAccessor64<int64_t, 2, RestrictPtrTraits> grid_sizes,
+    PackedTensorAccessor64<float, 2, RestrictPtrTraits> mask,
+    const float point_weight,
+    const bool align_corners,
+    const bool splat,
+    const int64_t batch_size,
+    const int64_t P,
+    const int64_t n_features) {
+  const int64_t chunks_per_cloud = (1 + (P - 1) / blockDim.x);
+  const int64_t chunks_to_do = batch_size * chunks_per_cloud;
+  const int scale_offset = align_corners ? 1 : 0;
+  const float offset = align_corners ? 
0 : 0.5; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t batch_index = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t point_idx = start_point + threadIdx.x; + if (point_idx >= P) { + continue; + } + if (mask[batch_index][point_idx] == 0) { + continue; + } + auto volume_densities_aa = volume_densities[batch_index][0]; + auto volume_features_aa = volume_features[batch_index]; + auto point = points_3d[batch_index][point_idx]; + auto point_features = points_features[batch_index][point_idx]; + const int64_t grid_size_x = grid_sizes[batch_index][2]; + const int64_t grid_size_y = grid_sizes[batch_index][1]; + const int64_t grid_size_z = grid_sizes[batch_index][0]; + auto increment_location = + [&](int64_t x, int64_t y, int64_t z, float weight) { + if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) { + return; + } + if (x < 0 || y < 0 || z < 0) { + return; + } + + atomicAdd(&volume_densities_aa[z][y][x], weight * point_weight); + + for (int64_t feature_idx = 0; feature_idx < n_features; + ++feature_idx) { + atomicAdd( + &volume_features_aa[feature_idx][z][y][x], + point_features[feature_idx] * weight * point_weight); + } + }; + if (!splat) { + long x = std::lround( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset); + long y = std::lround( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset); + long z = std::lround( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset); + increment_location(x, y, z, 1); + } else { + float x = 0, y = 0, z = 0; + float rx = std::modf( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x); + float ry = std::modf( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y); + float rz = std::modf( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z); + auto handle_point = [&](bool up_x, bool up_y, bool up_z) { + float weight = + (up_x ? rx : 1 - rx) * (up_y ? ry : 1 - ry) * (up_z ? 
rz : 1 - rz);
+        increment_location(x + up_x, y + up_y, z + up_z, weight);
+      };
+      EightDirections(handle_point);
+    }
+  }
+}
+
+void PointsToVolumesForwardCuda(
+    const torch::Tensor& points_3d,
+    const torch::Tensor& points_features,
+    const torch::Tensor& volume_densities,
+    const torch::Tensor& volume_features,
+    const torch::Tensor& grid_sizes,
+    const torch::Tensor& mask,
+    const float point_weight,
+    const bool align_corners,
+    const bool splat) {
+  // Check inputs are on the same device
+  at::TensorArg points_3d_t{points_3d, "points_3d", 1},
+      points_features_t{points_features, "points_features", 2},
+      volume_densities_t{volume_densities, "volume_densities", 3},
+      volume_features_t{volume_features, "volume_features", 4},
+      grid_sizes_t{grid_sizes, "grid_sizes", 5}, mask_t{mask, "mask", 6};
+  at::CheckedFrom c = "PointsToVolumesForwardCuda";
+  at::checkAllSameGPU(
+      c,
+      {points_3d_t,
+       points_features_t,
+       volume_densities_t,
+       volume_features_t,
+       grid_sizes_t,
+       mask_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points_3d.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  const int blocks = 1024;
+  const int threads = 32;
+
+  const int64_t batch_size = points_3d.size(0);
+  const int64_t P = points_3d.size(1);
+  const int64_t n_features = points_features.size(2);
+
+  PointsToVolumesForwardKernel<<<blocks, threads, 0, stream>>>(
+      points_3d.packed_accessor64<float, 3, RestrictPtrTraits>(),
+      points_features.packed_accessor64<float, 3, RestrictPtrTraits>(),
+      volume_densities.packed_accessor64<float, 5, RestrictPtrTraits>(),
+      volume_features.packed_accessor64<float, 5, RestrictPtrTraits>(),
+      grid_sizes.packed_accessor64<int64_t, 2, RestrictPtrTraits>(),
+      mask.packed_accessor64<float, 2, RestrictPtrTraits>(),
+      point_weight,
+      align_corners,
+      splat,
+      batch_size,
+      P,
+      n_features);
+}
+
+__global__ void PointsToVolumesBackwardKernel(
+    const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_3d,
+    const PackedTensorAccessor64<float, 3, RestrictPtrTraits> points_features,
+    const PackedTensorAccessor64<int64_t, 2, RestrictPtrTraits> grid_sizes,
+    const PackedTensorAccessor64<float, 2, RestrictPtrTraits> mask,
+    PackedTensorAccessor64<float, 5, RestrictPtrTraits> grad_volume_densities,
+    PackedTensorAccessor64<float, 5, RestrictPtrTraits> grad_volume_features,
+    PackedTensorAccessor64<float, 3, RestrictPtrTraits> grad_points_3d,
+    PackedTensorAccessor64<float, 3, RestrictPtrTraits> grad_points_features,
+    const float point_weight,
+    const bool align_corners,
+    const bool splat,
+    const int64_t batch_size,
+    const int64_t P,
+    const int64_t n_features) {
+  const int64_t chunks_per_cloud = (1 + (P - 1) / blockDim.x);
+  const int64_t chunks_to_do = batch_size * chunks_per_cloud;
+  const int scale_offset = align_corners ? 1 : 0;
+  const float offset = align_corners ? 0 : 0.5;
+  // Note that the gradients belonging to each point are only touched by
+  // a single thread in one of our "chunks", which is in a single block.
+  // So unlike in the forward pass, there's no need for atomics here.
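+  // As in the forward pass, a point coordinate p in [-1, 1] maps to the
+  // fractional voxel coordinate
+  //   g = (p + 1) * 0.5 * (grid_size - scale_offset) - offset,
+  // and in splat mode the weight given along x to the voxel on the up_x side
+  // is (up_x ? rx : 1 - rx) with rx = frac(g_x). Therefore
+  //   d(weight_x)/d(p_x) = (up_x ? 1 : -1) * 0.5 * (grid_size_x - scale_offset),
+  // which is the factor used when accumulating grad_point below.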
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t batch_index = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t point_idx = start_point + threadIdx.x; + if (point_idx >= P) { + continue; + } + if (mask[batch_index][point_idx] == 0) { + continue; + } + auto point = points_3d[batch_index][point_idx]; + auto point_features = points_features[batch_index][point_idx]; + auto grad_point = grad_points_3d[batch_index][point_idx]; + auto grad_point_features = grad_points_features[batch_index][point_idx]; + auto grad_volume_densities_a = grad_volume_densities[batch_index][0]; + auto grad_volume_features_a = grad_volume_features[batch_index]; + const int64_t grid_size_x = grid_sizes[batch_index][2]; + const int64_t grid_size_y = grid_sizes[batch_index][1]; + const int64_t grid_size_z = grid_sizes[batch_index][0]; + + auto increment_location = + [&](int64_t x, int64_t y, int64_t z, float weight) { + if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) { + return false; + } + if (x < 0 || y < 0 || z < 0) { + return false; + } + + // This is a forward line, for comparison + // volume_densities_aa[z][y][x] += weight * point_weight; + + for (int64_t feature_idx = 0; feature_idx < n_features; + ++feature_idx) { + // This is a forward line, for comparison + // volume_features_aa[feature_idx][z][y][x] += + // point_features[feature_idx] * weight * point_weight; + grad_point_features[feature_idx] += + grad_volume_features_a[feature_idx][z][y][x] * weight * + point_weight; + } + return true; + }; + + if (!splat) { + long x = std::lround( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset); + long y = std::lround( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset); + long z = std::lround( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset); + increment_location(x, y, z, 1); + } else { + float x = 0, y = 0, z = 0; + float rx = std::modf( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x); + float ry = std::modf( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y); + float rz = std::modf( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z); + auto handle_point = [&](bool up_x, bool up_y, bool up_z) { + float weight_x = (up_x ? rx : 1 - rx); + float weight_y = (up_y ? ry : 1 - ry); + float weight_z = (up_z ? rz : 1 - rz); + float weight = weight_x * weight_y * weight_z; + if (increment_location(x + up_x, y + up_y, z + up_z, weight)) { + // weight * point_weight has been added to + // volume_densities_aa[z+up_z][y+up_y][x+up_x] + // Also for each feature_idx, + // point_features[feature_idx] * weight * point_weight + // has been added to + // volume_features_aa[feature_idx][z+up_z][y+up_y][x+up_x] + + double source_gradient = + grad_volume_densities_a[z + up_z][y + up_y][x + up_x]; + for (int64_t feature_idx = 0; feature_idx < n_features; + ++feature_idx) { + source_gradient += point_features[feature_idx] * + grad_volume_features_a[feature_idx][z + up_z][y + up_y] + [x + up_x]; + } + grad_point[0] += source_gradient * (up_x ? 1 : -1) * weight_y * + weight_z * 0.5 * (grid_size_x - scale_offset) * point_weight; + grad_point[1] += source_gradient * (up_y ? 1 : -1) * weight_x * + weight_z * 0.5 * (grid_size_y - scale_offset) * point_weight; + grad_point[2] += source_gradient * (up_z ? 
1 : -1) * weight_x *
+              weight_y * 0.5 * (grid_size_z - scale_offset) * point_weight;
+        }
+      };
+      EightDirections(handle_point);
+    }
+  }
+}
+
+void PointsToVolumesBackwardCuda(
+    const torch::Tensor& points_3d,
+    const torch::Tensor& points_features,
+    const torch::Tensor& grid_sizes,
+    const torch::Tensor& mask,
+    const float point_weight,
+    const bool align_corners,
+    const bool splat,
+    const torch::Tensor& grad_volume_densities,
+    const torch::Tensor& grad_volume_features,
+    const torch::Tensor& grad_points_3d,
+    const torch::Tensor& grad_points_features) {
+  // Check inputs are on the same device
+  at::TensorArg points_3d_t{points_3d, "points_3d", 1},
+      points_features_t{points_features, "points_features", 2},
+      grid_sizes_t{grid_sizes, "grid_sizes", 3}, mask_t{mask, "mask", 4},
+      grad_volume_densities_t{
+          grad_volume_densities, "grad_volume_densities", 8},
+      grad_volume_features_t{grad_volume_features, "grad_volume_features", 9},
+      grad_points_3d_t{grad_points_3d, "grad_points_3d", 10},
+      grad_points_features_t{grad_points_features, "grad_points_features", 11};
+
+  at::CheckedFrom c = "PointsToVolumesBackwardCuda";
+  at::checkAllSameGPU(
+      c,
+      {points_3d_t,
+       points_features_t,
+       grid_sizes_t,
+       mask_t,
+       grad_volume_densities_t,
+       grad_volume_features_t,
+       grad_points_3d_t,
+       grad_points_features_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points_3d.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  const int blocks = 1024;
+  const int threads = 32;
+
+  const int64_t batch_size = points_3d.size(0);
+  const int64_t P = points_3d.size(1);
+  const int64_t n_features = points_features.size(2);
+
+  PointsToVolumesBackwardKernel<<<blocks, threads, 0, stream>>>(
+      points_3d.packed_accessor64<float, 3, RestrictPtrTraits>(),
+      points_features.packed_accessor64<float, 3, RestrictPtrTraits>(),
+      grid_sizes.packed_accessor64<int64_t, 2, RestrictPtrTraits>(),
+      mask.packed_accessor64<float, 2, RestrictPtrTraits>(),
+      grad_volume_densities.packed_accessor64<float, 5, RestrictPtrTraits>(),
+      grad_volume_features.packed_accessor64<float, 5, RestrictPtrTraits>(),
+      grad_points_3d.packed_accessor64<float, 3, RestrictPtrTraits>(),
+      grad_points_features.packed_accessor64<float, 3, RestrictPtrTraits>(),
+      point_weight,
+      align_corners,
+      splat,
+      batch_size,
+      P,
+      n_features);
+}
diff --git a/pytorch3d/csrc/points_to_volumes/points_to_volumes.h b/pytorch3d/csrc/points_to_volumes/points_to_volumes.h
index a53c2ab4..9a93c905 100644
--- a/pytorch3d/csrc/points_to_volumes/points_to_volumes.h
+++ b/pytorch3d/csrc/points_to_volumes/points_to_volumes.h
@@ -57,6 +57,17 @@ void PointsToVolumesForwardCpu(
     bool align_corners,
     bool splat);
 
+void PointsToVolumesForwardCuda(
+    const torch::Tensor& points_3d,
+    const torch::Tensor& points_features,
+    const torch::Tensor& volume_densities,
+    const torch::Tensor& volume_features,
+    const torch::Tensor& grid_sizes,
+    const torch::Tensor& mask,
+    float point_weight,
+    bool align_corners,
+    bool splat);
+
 inline void PointsToVolumesForward(
     const torch::Tensor& points_3d,
     const torch::Tensor& points_features,
@@ -69,7 +80,23 @@ inline void PointsToVolumesForward(
     bool splat) {
   if (points_3d.is_cuda()) {
 #ifdef WITH_CUDA
-    AT_ERROR("CUDA not implemented yet");
+    CHECK_CUDA(points_3d);
+    CHECK_CUDA(points_features);
+    CHECK_CUDA(volume_densities);
+    CHECK_CUDA(volume_features);
+    CHECK_CUDA(grid_sizes);
+    CHECK_CUDA(mask);
+    PointsToVolumesForwardCuda(
+        points_3d,
+        points_features,
+        volume_densities,
+        volume_features,
+        grid_sizes,
+        mask,
+        point_weight,
+        align_corners,
+        splat);
+    return;
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
@@ -101,6 +128,19 @@ void PointsToVolumesBackwardCpu(
     const torch::Tensor& 
grad_points_3d, const torch::Tensor& grad_points_features); +void PointsToVolumesBackwardCuda( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + float point_weight, + bool align_corners, + bool splat, + const torch::Tensor& grad_volume_densities, + const torch::Tensor& grad_volume_features, + const torch::Tensor& grad_points_3d, + const torch::Tensor& grad_points_features); + inline void PointsToVolumesBackward( const torch::Tensor& points_3d, const torch::Tensor& points_features, @@ -115,7 +155,27 @@ inline void PointsToVolumesBackward( const torch::Tensor& grad_points_features) { if (points_3d.is_cuda()) { #ifdef WITH_CUDA - AT_ERROR("CUDA not implemented yet"); + CHECK_CUDA(points_3d); + CHECK_CUDA(points_features); + CHECK_CUDA(grid_sizes); + CHECK_CUDA(mask); + CHECK_CUDA(grad_volume_densities); + CHECK_CUDA(grad_volume_features); + CHECK_CUDA(grad_points_3d); + CHECK_CUDA(grad_points_features); + PointsToVolumesBackwardCuda( + points_3d, + points_features, + grid_sizes, + mask, + point_weight, + align_corners, + splat, + grad_volume_densities, + grad_volume_features, + grad_points_3d, + grad_points_features); + return; #else AT_ERROR("Not compiled with GPU support."); #endif diff --git a/tests/test_points_to_volumes.py b/tests/test_points_to_volumes.py index 2ca4352a..e417781e 100644 --- a/tests/test_points_to_volumes.py +++ b/tests/test_points_to_volumes.py @@ -420,6 +420,18 @@ class TestRawFunction(TestCaseMixin, unittest.TestCase): def test_grad_round_cpu(self): self.do_gradcheck(torch.device("cpu"), False, False) + def test_grad_corners_splat_cuda(self): + self.do_gradcheck(torch.device("cuda:0"), True, True) + + def test_grad_corners_round_cuda(self): + self.do_gradcheck(torch.device("cuda:0"), False, True) + + def test_grad_splat_cuda(self): + self.do_gradcheck(torch.device("cuda:0"), True, False) + + def test_grad_round_cuda(self): + self.do_gradcheck(torch.device("cuda:0"), False, False) + def do_gradcheck(self, device, splat: bool, align_corners: bool): """ Use gradcheck to verify the gradient of _points_to_volumes @@ -492,6 +504,18 @@ class TestRawFunction(TestCaseMixin, unittest.TestCase): def test_single_splat_cpu(self): self.single_point(torch.device("cpu"), True, False) + def test_single_corners_round_cuda(self): + self.single_point(torch.device("cuda:0"), False, True) + + def test_single_corners_splat_cuda(self): + self.single_point(torch.device("cuda:0"), True, True) + + def test_single_round_cuda(self): + self.single_point(torch.device("cuda:0"), False, False) + + def test_single_splat_cuda(self): + self.single_point(torch.device("cuda:0"), True, False) + def single_point(self, device, splat: bool, align_corners: bool): """ Check the outcome of _points_to_volumes where a single point