diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp
index 5fae7948..012a95be 100644
--- a/pytorch3d/csrc/ext.cpp
+++ b/pytorch3d/csrc/ext.cpp
@@ -15,8 +15,7 @@
 #include "interp_face_attrs/interp_face_attrs.h"
 #include "knn/knn.h"
 #include "packed_to_padded_tensor/packed_to_padded_tensor.h"
-#include "point_mesh/point_mesh_edge.h"
-#include "point_mesh/point_mesh_face.h"
+#include "point_mesh/point_mesh_cuda.h"
 #include "rasterize_meshes/rasterize_meshes.h"
 #include "rasterize_points/rasterize_points.h"
diff --git a/pytorch3d/csrc/point_mesh/point_mesh.cpp b/pytorch3d/csrc/point_mesh/point_mesh_cpu.cpp
similarity index 100%
rename from pytorch3d/csrc/point_mesh/point_mesh.cpp
rename to pytorch3d/csrc/point_mesh/point_mesh_cpu.cpp
diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.cu b/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
similarity index 73%
rename from pytorch3d/csrc/point_mesh/point_mesh_face.cu
rename to pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
index 1ba42ddf..f762c82f 100644
--- a/pytorch3d/csrc/point_mesh/point_mesh_face.cu
+++ b/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
@@ -12,7 +12,7 @@
 #include "utils/warp_reduce.cuh"
 
 // ****************************************************************************
-// *                          PointFaceDistance                               *
+// *                    Generic Forward/Backward Kernels                      *
 // ****************************************************************************
 
 __global__ void DistanceForwardKernel(
@@ -202,16 +202,6 @@ std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
   return std::make_tuple(dists, idxs);
 }
 
-std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& points_first_idx,
-    const at::Tensor& tris,
-    const at::Tensor& tris_first_idx,
-    const int64_t max_points) {
-  return DistanceForwardCuda(
-      points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
-}
-
 __global__ void DistanceBackwardKernel(
     const float* __restrict__ objects, // (O * oD * 3)
     const size_t objects_size, // O
@@ -365,6 +355,20 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
   return std::make_tuple(grad_points, grad_tris);
 }
 
+// ****************************************************************************
+// *                          PointFaceDistance                               *
+// ****************************************************************************
+
+std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& points_first_idx,
+    const at::Tensor& tris,
+    const at::Tensor& tris_first_idx,
+    const int64_t max_points) {
+  return DistanceForwardCuda(
+      points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
+}
+
 std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
     const at::Tensor& points,
     const at::Tensor& tris,
@@ -395,9 +399,54 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
   return DistanceBackwardCuda(tris, 3, points, 1, idx_tris, grad_dists);
 }
 
+// ****************************************************************************
+// *                          PointEdgeDistance                               *
+// ****************************************************************************
+
+std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& points_first_idx,
+    const at::Tensor& segms,
+    const at::Tensor& segms_first_idx,
+    const int64_t max_points) {
+  return DistanceForwardCuda(
+      points, 1, points_first_idx, segms, 2, segms_first_idx, max_points);
+}
+
+std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& segms,
+    const at::Tensor& idx_points,
+    const at::Tensor& grad_dists) {
+  return DistanceBackwardCuda(points, 1, segms, 2, idx_points, grad_dists);
+}
+
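Note: all four batched distance variants are now thin wrappers over the shared DistanceForwardCuda / DistanceBackwardCuda implementations; the only difference is which tensor plays the "objects" role, which plays "targets", and how many 3D vertices make up one element. Read off the wrappers in this file, the mapping is:

    // variant              objects (oD)   targets (tD)
    // PointFaceDistance    points  (1)    tris    (3)
    // FacePointDistance    tris    (3)    points  (1)
    // PointEdgeDistance    points  (1)    segms   (2)
    // EdgePointDistance    segms   (2)    points  (1)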
+// ****************************************************************************
+// *                          EdgePointDistance                               *
+// ****************************************************************************
+
+std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& points_first_idx,
+    const at::Tensor& segms,
+    const at::Tensor& segms_first_idx,
+    const int64_t max_segms) {
+  return DistanceForwardCuda(
+      segms, 2, segms_first_idx, points, 1, points_first_idx, max_segms);
+}
+
+std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& segms,
+    const at::Tensor& idx_segms,
+    const at::Tensor& grad_dists) {
+  return DistanceBackwardCuda(segms, 2, points, 1, idx_segms, grad_dists);
+}
+
 // ****************************************************************************
 // *                       PointFaceArrayDistance                             *
 // ****************************************************************************
+// TODO: Create wrapper function and merge kernel with other array kernel
 
 __global__ void PointFaceArrayForwardKernel(
     const float* __restrict__ points, // (P, 3)
@@ -565,3 +614,164 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_points, grad_tris);
 }
+
+// ****************************************************************************
+// *                       PointEdgeArrayDistance                             *
+// ****************************************************************************
+// TODO: Create wrapper function and merge kernel with other array kernel
+
+__global__ void PointEdgeArrayForwardKernel(
+    const float* __restrict__ points, // (P, 3)
+    const float* __restrict__ segms, // (S, 2, 3)
+    float* __restrict__ dists, // (P, S)
+    const size_t P,
+    const size_t S) {
+  float3* points_f3 = (float3*)points;
+  float3* segms_f3 = (float3*)segms;
+
+  // Parallelize over P * S computations
+  const int num_threads = gridDim.x * blockDim.x;
+  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
+    const int s = t_i / P; // segment index
+    const int p = t_i % P; // point index
+    float3 a = segms_f3[s * 2 + 0];
+    float3 b = segms_f3[s * 2 + 1];
+
+    float3 point = points_f3[p];
+    float dist = PointLine3DistanceForward(point, a, b);
+    dists[p * S + s] = dist;
+  }
+}
+
+at::Tensor PointEdgeArrayDistanceForwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& segms) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
+  at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
+  at::checkAllSameGPU(c, {points_t, segms_t});
+  at::checkAllSameType(c, {points_t, segms_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  const int64_t P = points.size(0);
+  const int64_t S = segms.size(0);
+
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
+      (segms.size(1) == 2) && (segms.size(2) == 3),
+      "segms must be of shape Sx2x3");
+
+  at::Tensor dists = at::zeros({P, S}, points.options());
+
+  if (dists.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return dists;
+  }
+
+  const size_t blocks = 1024;
+  const size_t threads = 64;
+
+  PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
+      points.contiguous().data_ptr<float>(),
+      segms.contiguous().data_ptr<float>(),
+      dists.data_ptr<float>(),
+      P,
+      S);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+  return dists;
+}
+
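Note: the array kernel above flattens the P x S output grid into a single index and walks it with a grid-stride loop, so any (blocks, threads) launch covers all pairs. A serial sketch of the loop it parallelizes, for intuition only, using the kernel's own float3 views points_f3 / segms_f3:

    // Reference loop matching PointEdgeArrayForwardKernel's index scheme.
    for (int t_i = 0; t_i < P * S; ++t_i) {
      const int s = t_i / P; // segment index
      const int p = t_i % P; // point index
      // The distance for (point p, segment s) lands at row p, column s.
      dists[p * S + s] = PointLine3DistanceForward(
          points_f3[p], segms_f3[s * 2 + 0], segms_f3[s * 2 + 1]);
    }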
+__global__ void PointEdgeArrayBackwardKernel(
+    const float* __restrict__ points, // (P, 3)
+    const float* __restrict__ segms, // (S, 2, 3)
+    const float* __restrict__ grad_dists, // (P, S)
+    float* __restrict__ grad_points, // (P, 3)
+    float* __restrict__ grad_segms, // (S, 2, 3)
+    const size_t P,
+    const size_t S) {
+  float3* points_f3 = (float3*)points;
+  float3* segms_f3 = (float3*)segms;
+
+  // Parallelize over P * S computations
+  const int num_threads = gridDim.x * blockDim.x;
+  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
+    const int s = t_i / P; // segment index
+    const int p = t_i % P; // point index
+    const float3 a = segms_f3[s * 2 + 0];
+    const float3 b = segms_f3[s * 2 + 1];
+
+    const float3 point = points_f3[p];
+    const float grad_dist = grad_dists[p * S + s];
+    const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
+    const float3 grad_point = thrust::get<0>(grads);
+    const float3 grad_a = thrust::get<1>(grads);
+    const float3 grad_b = thrust::get<2>(grads);
+
+    atomicAdd(grad_points + p * 3 + 0, grad_point.x);
+    atomicAdd(grad_points + p * 3 + 1, grad_point.y);
+    atomicAdd(grad_points + p * 3 + 2, grad_point.z);
+
+    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
+    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
+    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);
+
+    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
+    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
+    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
+    const at::Tensor& points,
+    const at::Tensor& segms,
+    const at::Tensor& grad_dists) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
+      grad_dists_t{grad_dists, "grad_dists", 3};
+  at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  const int64_t P = points.size(0);
+  const int64_t S = segms.size(0);
+
+  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
+  TORCH_CHECK(
+      (segms.size(1) == 2) && (segms.size(2) == 3),
+      "segms must be of shape Sx2x3");
+  TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
+
+  at::Tensor grad_points = at::zeros({P, 3}, points.options());
+  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
+
+  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(grad_points, grad_segms);
+  }
+
+  const size_t blocks = 1024;
+  const size_t threads = 64;
+
+  PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
+      points.contiguous().data_ptr<float>(),
+      segms.contiguous().data_ptr<float>(),
+      grad_dists.contiguous().data_ptr<float>(),
+      grad_points.data_ptr<float>(),
+      grad_segms.data_ptr<float>(),
+      P,
+      S);
+  AT_CUDA_CHECK(cudaGetLastError());
+  return std::make_tuple(grad_points, grad_segms);
+}
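Note: in the backward kernel every (p, s) pair scatters gradients into the same grad_points row and grad_segms entry as many other pairs, so each accumulation goes through atomicAdd; a plain read-modify-write such as

    grad_points[p * 3 + 0] += grad_point.x; // racy: concurrent threads lose updates

would silently drop contributions whenever two threads touch the same point or segment at once.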
diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.h b/pytorch3d/csrc/point_mesh/point_mesh_cuda.h
similarity index 50%
rename from pytorch3d/csrc/point_mesh/point_mesh_face.h
rename to pytorch3d/csrc/point_mesh/point_mesh_cuda.h
index ec4bd344..d0ffb57d 100644
--- a/pytorch3d/csrc/point_mesh/point_mesh_face.h
+++ b/pytorch3d/csrc/point_mesh/point_mesh_cuda.h
@@ -241,6 +241,242 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
   return FacePointDistanceBackwardCpu(points, tris, idx_tris, grad_dists);
 }
 
+// ****************************************************************************
+// *                          PointEdgeDistance                               *
+// ****************************************************************************
+
+// Computes the squared euclidean distance of each p in points to the closest
+// mesh edge belonging to the corresponding example in the batch of size N.
+//
+// Args:
+//    points: FloatTensor of shape (P, 3)
+//    points_first_idx: LongTensor of shape (N,) indicating the first point
+//        index for each example in the batch
+//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
+//        segment is spanned by (segms[s, 0], segms[s, 1])
+//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
+//        index for each example in the batch
+//    max_points: Scalar equal to max(P_i) for i in [0, N - 1], i.e. the
+//        maximum number of points in the batch; used to set the grid
+//        dimensions in the CUDA implementation.
+//
+// Returns:
+//    dists: FloatTensor of shape (P,), where dists[p] is the squared
+//        euclidean distance of points[p] to the closest edge in the same
+//        example in the batch.
+//    idxs: LongTensor of shape (P,), where idxs[p] is the index of the
+//        closest edge in the batch.
+//        So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
+//        where d(u, v0, v1) is the distance of u from the segment spanned by
+//        (v0, v1).
+//
+
+#ifdef WITH_CUDA
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
+    const torch::Tensor& points,
+    const torch::Tensor& points_first_idx,
+    const torch::Tensor& segms,
+    const torch::Tensor& segms_first_idx,
+    const int64_t max_points);
+#endif
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
+    const torch::Tensor& points,
+    const torch::Tensor& points_first_idx,
+    const torch::Tensor& segms,
+    const torch::Tensor& segms_first_idx,
+    const int64_t max_points);
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
+    const torch::Tensor& points,
+    const torch::Tensor& points_first_idx,
+    const torch::Tensor& segms,
+    const torch::Tensor& segms_first_idx,
+    const int64_t max_points) {
+  if (points.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(points_first_idx);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(segms_first_idx);
+    return PointEdgeDistanceForwardCuda(
+        points, points_first_idx, segms, segms_first_idx, max_points);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return PointEdgeDistanceForwardCpu(
+      points, points_first_idx, segms, segms_first_idx, max_points);
+}
+
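A minimal caller-side sketch of the dispatch above (hypothetical tensors, single-example batch so both first-idx tensors are {0}; assumes a CUDA-enabled build):

    torch::Tensor points = torch::rand({1000, 3}, torch::kCUDA);
    torch::Tensor segms = torch::rand({500, 2, 3}, torch::kCUDA);
    auto long_opts = torch::dtype(torch::kInt64).device(torch::kCUDA);
    torch::Tensor points_first_idx = torch::zeros({1}, long_opts);
    torch::Tensor segms_first_idx = torch::zeros({1}, long_opts);
    auto out = PointEdgeDistanceForward(
        points, points_first_idx, segms, segms_first_idx, /*max_points=*/1000);
    torch::Tensor dists = std::get<0>(out); // (1000,) squared distances
    torch::Tensor idxs = std::get<1>(out); // (1000,) closest-edge indices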
+// Backward pass for PointEdgeDistance.
+//
+// Args:
+//    points: FloatTensor of shape (P, 3)
+//    segms: FloatTensor of shape (S, 2, 3)
+//    idx_points: LongTensor of shape (P,) containing the indices of the
+//        closest edge in the example in the batch.
+//        This is computed by the forward pass.
+//    grad_dists: FloatTensor of shape (P,)
+//
+// Returns:
+//    grad_points: FloatTensor of shape (P, 3)
+//    grad_segms: FloatTensor of shape (S, 2, 3)
+//
+
+#ifdef WITH_CUDA
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& idx_points,
+    const torch::Tensor& grad_dists);
+#endif
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& idx_points,
+    const torch::Tensor& grad_dists);
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& idx_points,
+    const torch::Tensor& grad_dists) {
+  if (points.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(idx_points);
+    CHECK_CUDA(grad_dists);
+    return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
+}
+
+// ****************************************************************************
+// *                          EdgePointDistance                               *
+// ****************************************************************************
+
+// Computes the squared euclidean distance of each edge segment to the closest
+// point belonging to the corresponding example in the batch of size N.
+//
+// Args:
+//    points: FloatTensor of shape (P, 3)
+//    points_first_idx: LongTensor of shape (N,) indicating the first point
+//        index for each example in the batch
+//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
+//        segment is spanned by (segms[s, 0], segms[s, 1])
+//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
+//        index for each example in the batch
+//    max_segms: Scalar equal to max(S_i) for i in [0, N - 1], i.e. the
+//        maximum number of edges in the batch; used to set the block
+//        dimensions in the CUDA implementation.
+//
+// Returns:
+//    dists: FloatTensor of shape (S,), where dists[s] is the squared
+//        euclidean distance of the s-th edge to the closest point in the
+//        corresponding example in the batch.
+//    idxs: LongTensor of shape (S,), where idxs[s] is the index of the
+//        closest point in the example in the batch.
+//        So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
+//        d(u, v0, v1) is the distance of u from the segment spanned by
+//        (v0, v1).
+//
+
+#ifdef WITH_CUDA
+
+std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
+    const torch::Tensor& points,
+    const torch::Tensor& points_first_idx,
+    const torch::Tensor& segms,
+    const torch::Tensor& segms_first_idx,
+    const int64_t max_segms);
+#endif
+
+std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
+    const torch::Tensor& points,
+    const torch::Tensor& points_first_idx,
+    const torch::Tensor& segms,
+    const torch::Tensor& segms_first_idx,
+    const int64_t max_segms);
+
+std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
+    const torch::Tensor& points,
+    const torch::Tensor& points_first_idx,
+    const torch::Tensor& segms,
+    const torch::Tensor& segms_first_idx,
+    const int64_t max_segms) {
+  if (points.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(points_first_idx);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(segms_first_idx);
+    return EdgePointDistanceForwardCuda(
+        points, points_first_idx, segms, segms_first_idx, max_segms);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return EdgePointDistanceForwardCpu(
+      points, points_first_idx, segms, segms_first_idx, max_segms);
+}
+
+// Backward pass for EdgePointDistance.
+//
+// Args:
+//    points: FloatTensor of shape (P, 3)
+//    segms: FloatTensor of shape (S, 2, 3)
+//    idx_segms: LongTensor of shape (S,) containing the indices of the
+//        closest point in the example in the batch.
+//        This is computed by the forward pass.
+//    grad_dists: FloatTensor of shape (S,)
+//
+// Returns:
+//    grad_points: FloatTensor of shape (P, 3)
+//    grad_segms: FloatTensor of shape (S, 2, 3)
+//
+
+#ifdef WITH_CUDA
+
+std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& idx_segms,
+    const torch::Tensor& grad_dists);
+#endif
+
+std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& idx_segms,
+    const torch::Tensor& grad_dists);
+
+std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& idx_segms,
+    const torch::Tensor& grad_dists) {
+  if (points.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(idx_segms);
+    CHECK_CUDA(grad_dists);
+    return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
+}
+
 // ****************************************************************************
 // *                       PointFaceArrayDistance                             *
 // ****************************************************************************
@@ -328,3 +564,92 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
   }
   return PointFaceArrayDistanceBackwardCpu(points, tris, grad_dists);
 }
+
+// ****************************************************************************
+// *                       PointEdgeArrayDistance                             *
+// ****************************************************************************
+
+// Computes the squared euclidean distance of each p in points to each edge
+// segment in segms.
+//
+// Args:
+//    points: FloatTensor of shape (P, 3)
+//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
+//        edge segment is spanned by (segms[s, 0], segms[s, 1])
+//
+// Returns:
+//    dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
+//        euclidean distance of points[p] to the segment spanned by
+//        (segms[s, 0], segms[s, 1])
+//
+// For pointclouds and meshes of batch size N, this function requires N
+// computations. The memory occupied is O(NPS) which can become quite large.
+// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
+// will require for the forward pass 5.8G of memory to store dists.
+
+#ifdef WITH_CUDA
+torch::Tensor PointEdgeArrayDistanceForwardCuda(
+    const torch::Tensor& points,
+    const torch::Tensor& segms);
+#endif
+
+torch::Tensor PointEdgeArrayDistanceForwardCpu(
+    const torch::Tensor& points,
+    const torch::Tensor& segms);
+
+torch::Tensor PointEdgeArrayDistanceForward(
+    const torch::Tensor& points,
+    const torch::Tensor& segms) {
+  if (points.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    return PointEdgeArrayDistanceForwardCuda(points, segms);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return PointEdgeArrayDistanceForwardCpu(points, segms);
+}
+
+// Backward pass for PointEdgeArrayDistance.
+//
+// Args:
+//    points: FloatTensor of shape (P, 3)
+//    segms: FloatTensor of shape (S, 2, 3)
+//    grad_dists: FloatTensor of shape (P, S)
+//
+// Returns:
+//    grad_points: FloatTensor of shape (P, 3)
+//    grad_segms: FloatTensor of shape (S, 2, 3)
+//
+
+#ifdef WITH_CUDA
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& grad_dists);
+#endif
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& grad_dists);
+
+std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
+    const torch::Tensor& points,
+    const torch::Tensor& segms,
+    const torch::Tensor& grad_dists) {
+  if (points.is_cuda()) {
+#ifdef WITH_CUDA
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(grad_dists);
+    return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
+#else
+    AT_ERROR("Not compiled with GPU support.");
+#endif
+  }
+  return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
+}
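The memory warning in the PointEdgeArrayDistance docblock can be checked directly: one float per (point, segment) pair over the whole batch is N * P * S * sizeof(float) = 32 * 10000 * 5000 * 4 bytes = 6.4e9 bytes, i.e. roughly 6 GB for the forward pass alone, in the same ballpark as the 5.8G figure quoted above, and as much again if grad_dists is materialized in the backward pass.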
diff --git a/pytorch3d/csrc/point_mesh/point_mesh_edge.cu b/pytorch3d/csrc/point_mesh/point_mesh_edge.cu
deleted file mode 100644
index 98db3bd2..00000000
--- a/pytorch3d/csrc/point_mesh/point_mesh_edge.cu
+++ /dev/null
@@ -1,651 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <algorithm>
-#include <list>
-#include <queue>
-#include <tuple>
-#include "utils/float_math.cuh"
-#include "utils/geometry_utils.cuh"
-#include "utils/warp_reduce.cuh"
-
-// ****************************************************************************
-// *                          PointEdgeDistance                               *
-// ****************************************************************************
-
-__global__ void PointEdgeForwardKernel(
-    const float* __restrict__ points, // (P, 3)
-    const int64_t* __restrict__ points_first_idx, // (B,)
-    const float* __restrict__ segms, // (S, 2, 3)
-    const int64_t* __restrict__ segms_first_idx, // (B,)
-    float* __restrict__ dist_points, // (P,)
-    int64_t* __restrict__ idx_points, // (P,)
-    const size_t B,
-    const size_t P,
-    const size_t S) {
-  float3* points_f3 = (float3*)points;
-  float3* segms_f3 = (float3*)segms;
-
-  // Single shared memory buffer which is split and cast to different types.
-  extern __shared__ char shared_buf[];
-  float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
-  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
-
-  const size_t batch_idx = blockIdx.y; // index of batch element.
-
-  // start and end for points in batch
-  const int64_t startp = points_first_idx[batch_idx];
-  const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
-
-  // start and end for segments in batch_idx
-  const int64_t starts = segms_first_idx[batch_idx];
-  const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
-
-  const size_t i = blockIdx.x; // index of point within batch element.
-  const size_t tid = threadIdx.x; // thread idx
-
-  // Each block will compute one element of the output idx_points[startp + i],
-  // dist_points[startp + i]. Within the block we will use threads to compute
-  // the distances between points[startp + i] and segms[j] for all j belonging
-  // in the same batch as i, i.e. j in [starts, ends]. Then use a block
-  // reduction to take an argmin of the distances.
-
-  // If i exceeds the number of points in batch_idx, then do nothing
-  if (i < (endp - startp)) {
-    // Retrieve (startp + i) point
-    const float3 p_f3 = points_f3[startp + i];
-
-    // Compute the distances between points[startp + i] and segms[j] for
-    // all j belonging in the same batch as i, i.e. j in [starts, ends].
-    // Here each thread will reduce over (ends-starts) / blockDim.x in serial,
-    // and store its result to shared memory
-    float min_dist = FLT_MAX;
-    size_t min_idx = 0;
-    for (size_t j = tid; j < (ends - starts); j += blockDim.x) {
-      const float3 v0 = segms_f3[(starts + j) * 2 + 0];
-      const float3 v1 = segms_f3[(starts + j) * 2 + 1];
-      float dist = PointLine3DistanceForward(p_f3, v0, v1);
-      min_dist = (j == tid) ? dist : min_dist;
-      min_idx = (dist <= min_dist) ? (starts + j) : min_idx;
-      min_dist = (dist <= min_dist) ? dist : min_dist;
-    }
-    min_dists[tid] = min_dist;
-    min_idxs[tid] = min_idx;
-    __syncthreads();
-
-    // Perform reduction in shared memory.
-    for (int s = blockDim.x / 2; s > 32; s >>= 1) {
-      if (tid < s) {
-        if (min_dists[tid] > min_dists[tid + s]) {
-          min_dists[tid] = min_dists[tid + s];
-          min_idxs[tid] = min_idxs[tid + s];
-        }
-      }
-      __syncthreads();
-    }
-
-    // Unroll the last 6 iterations of the loop since they will happen
-    // synchronized within a single warp.
-    if (tid < 32)
-      WarpReduce<float>(min_dists, min_idxs, tid);
-
-    // Finally thread 0 writes the result to the output buffer.
-    if (tid == 0) {
-      idx_points[startp + i] = min_idxs[0];
-      dist_points[startp + i] = min_dists[0];
-    }
-  }
-}
-
-std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& points_first_idx,
-    const at::Tensor& segms,
-    const at::Tensor& segms_first_idx,
-    const int64_t max_points) {
-  // Check inputs are on the same device
-  at::TensorArg points_t{points, "points", 1},
-      points_first_idx_t{points_first_idx, "points_first_idx", 2},
-      segms_t{segms, "segms", 3},
-      segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
-  at::CheckedFrom c = "PointEdgeDistanceForwardCuda";
-  at::checkAllSameGPU(
-      c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
-  at::checkAllSameType(c, {points_t, segms_t});
-
-  // Set the device for the kernel launch based on the device of the input
-  at::cuda::CUDAGuard device_guard(points.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  const int64_t P = points.size(0);
-  const int64_t S = segms.size(0);
-  const int64_t B = points_first_idx.size(0);
-
-  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
-  TORCH_CHECK(
-      (segms.size(1) == 2) && (segms.size(2) == 3),
-      "segms must be of shape Sx2x3");
-  TORCH_CHECK(segms_first_idx.size(0) == B);
-
-  // clang-format off
-  at::Tensor dists = at::zeros({P,}, points.options());
-  at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
-  // clang-format on
-
-  if (dists.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return std::make_tuple(dists, idxs);
-  }
-
-  const int threads = 128;
-  const dim3 blocks(max_points, B);
-  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
-
-  PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
-      points.contiguous().data_ptr<float>(),
-      points_first_idx.contiguous().data_ptr<int64_t>(),
-      segms.contiguous().data_ptr<float>(),
-      segms_first_idx.contiguous().data_ptr<int64_t>(),
-      dists.data_ptr<float>(),
-      idxs.data_ptr<int64_t>(),
-      B,
-      P,
-      S);
-
-  AT_CUDA_CHECK(cudaGetLastError());
-  return std::make_tuple(dists, idxs);
-}
-
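One quirk worth flagging in the launch above: the kernel splits shared_buf into blockDim.x floats followed by blockDim.x int64_t values, but shared_size is computed as threads * sizeof(size_t) + threads * sizeof(int64_t). Since sizeof(size_t) >= sizeof(float) on the supported platforms, this only over-allocates and is therefore harmless; the tight sizing would be

    size_t shared_size = threads * sizeof(float) + threads * sizeof(int64_t);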
-__global__ void PointEdgeBackwardKernel(
-    const float* __restrict__ points, // (P, 3)
-    const float* __restrict__ segms, // (S, 2, 3)
-    const int64_t* __restrict__ idx_points, // (P,)
-    const float* __restrict__ grad_dists, // (P,)
-    float* __restrict__ grad_points, // (P, 3)
-    float* __restrict__ grad_segms, // (S, 2, 3)
-    const size_t P) {
-  float3* points_f3 = (float3*)points;
-  float3* segms_f3 = (float3*)segms;
-
-  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const size_t stride = gridDim.x * blockDim.x;
-
-  for (size_t p = tid; p < P; p += stride) {
-    const float3 p_f3 = points_f3[p];
-
-    const int64_t sidx = idx_points[p];
-    const float3 v0 = segms_f3[sidx * 2 + 0];
-    const float3 v1 = segms_f3[sidx * 2 + 1];
-
-    const float grad_dist = grad_dists[p];
-
-    const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
-    const float3 grad_point = thrust::get<0>(grads);
-    const float3 grad_v0 = thrust::get<1>(grads);
-    const float3 grad_v1 = thrust::get<2>(grads);
-
-    atomicAdd(grad_points + p * 3 + 0, grad_point.x);
-    atomicAdd(grad_points + p * 3 + 1, grad_point.y);
-    atomicAdd(grad_points + p * 3 + 2, grad_point.z);
-
-    atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 0, grad_v0.x);
-    atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 1, grad_v0.y);
-    atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 2, grad_v0.z);
-
-    atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 0, grad_v1.x);
-    atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 1, grad_v1.y);
-    atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 2, grad_v1.z);
-  }
-}
-
-std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& segms,
-    const at::Tensor& idx_points,
-    const at::Tensor& grad_dists) {
-  // Check inputs are on the same device
-  at::TensorArg points_t{points, "points", 1},
-      idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3},
-      grad_dists_t{grad_dists, "grad_dists", 4};
-  at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
-  at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t});
-  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
-
-  // Set the device for the kernel launch based on the device of the input
-  at::cuda::CUDAGuard device_guard(points.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  const int64_t P = points.size(0);
-  const int64_t S = segms.size(0);
-
-  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
-  TORCH_CHECK(
-      (segms.size(1) == 2) && (segms.size(2) == 3),
-      "segms must be of shape Sx2x3");
-  TORCH_CHECK(idx_points.size(0) == P);
-  TORCH_CHECK(grad_dists.size(0) == P);
-
-  // clang-format off
-  at::Tensor grad_points = at::zeros({P, 3}, points.options());
-  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
-  // clang-format on
-
-  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return std::make_tuple(grad_points, grad_segms);
-  }
-
-  const int blocks = 64;
-  const int threads = 512;
-
-  PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.contiguous().data_ptr<float>(),
-      segms.contiguous().data_ptr<float>(),
-      idx_points.contiguous().data_ptr<int64_t>(),
-      grad_dists.contiguous().data_ptr<float>(),
-      grad_points.data_ptr<float>(),
-      grad_segms.data_ptr<float>(),
-      P);
-
-  AT_CUDA_CHECK(cudaGetLastError());
-  return std::make_tuple(grad_points, grad_segms);
-}
-
-// ****************************************************************************
-// *                          EdgePointDistance                               *
-// ****************************************************************************
-
-__global__ void EdgePointForwardKernel(
-    const float* __restrict__ points, // (P, 3)
-    const int64_t* __restrict__ points_first_idx, // (B,)
-    const float* __restrict__ segms, // (S, 2, 3)
-    const int64_t* __restrict__ segms_first_idx, // (B,)
-    float* __restrict__ dist_segms, // (S,)
-    int64_t* __restrict__ idx_segms, // (S,)
-    const size_t B,
-    const size_t P,
-    const size_t S) {
-  float3* points_f3 = (float3*)points;
-  float3* segms_f3 = (float3*)segms;
-
-  // Single shared memory buffer which is split and cast to different types.
-  extern __shared__ char shared_buf[];
-  float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
-  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
-
-  const size_t batch_idx = blockIdx.y; // index of batch element.
-
-  // start and end for points in batch_idx
-  const int64_t startp = points_first_idx[batch_idx];
-  const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
-
-  // start and end for segms in batch_idx
-  const int64_t starts = segms_first_idx[batch_idx];
-  const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
-
-  const size_t i = blockIdx.x; // index of point within batch element.
-  const size_t tid = threadIdx.x; // thread index
-
-  // Each block will compute one element of the output idx_segms[starts + i],
-  // dist_segms[starts + i]. Within the block we will use threads to compute
-  // the distances between segms[starts + i] and points[j] for all j belonging
-  // in the same batch as i, i.e. j in [startp, endp]. Then use a block
-  // reduction to take an argmin of the distances.
-
-  // If i exceeds the number of segms in batch_idx, then do nothing
-  if (i < (ends - starts)) {
-    const float3 v0 = segms_f3[(starts + i) * 2 + 0];
-    const float3 v1 = segms_f3[(starts + i) * 2 + 1];
-
-    // Compute the distances between segms[starts + i] and points[j] for
-    // all j belonging in the same batch as i, i.e. j in [startp, endp].
-    // Here each thread will reduce over (endp-startp) / blockDim.x in serial,
-    // and store its result to shared memory
-    float min_dist = FLT_MAX;
-    size_t min_idx = 0;
-    for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
-      // Retrieve (startp + i) point
-      const float3 p_f3 = points_f3[startp + j];
-
-      float dist = PointLine3DistanceForward(p_f3, v0, v1);
-      min_dist = (j == tid) ? dist : min_dist;
-      min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
-      min_dist = (dist <= min_dist) ? dist : min_dist;
-    }
-    min_dists[tid] = min_dist;
-    min_idxs[tid] = min_idx;
-    __syncthreads();
-
-    // Perform reduction in shared memory.
-    for (int s = blockDim.x / 2; s > 32; s >>= 1) {
-      if (tid < s) {
-        if (min_dists[tid] > min_dists[tid + s]) {
-          min_dists[tid] = min_dists[tid + s];
-          min_idxs[tid] = min_idxs[tid + s];
-        }
-      }
-      __syncthreads();
-    }
-
-    // Unroll the last 6 iterations of the loop since they will happen
-    // synchronized within a single warp.
-    if (tid < 32)
-      WarpReduce<float>(min_dists, min_idxs, tid);
-
-    // Finally thread 0 writes the result to the output buffer.
-    if (tid == 0) {
-      idx_segms[starts + i] = min_idxs[0];
-      dist_segms[starts + i] = min_dists[0];
-    }
-  }
-}
-
-std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& points_first_idx,
-    const at::Tensor& segms,
-    const at::Tensor& segms_first_idx,
-    const int64_t max_segms) {
-  // Check inputs are on the same device
-  at::TensorArg points_t{points, "points", 1},
-      points_first_idx_t{points_first_idx, "points_first_idx", 2},
-      segms_t{segms, "segms", 3},
-      segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
-  at::CheckedFrom c = "EdgePointDistanceForwardCuda";
-  at::checkAllSameGPU(
-      c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
-  at::checkAllSameType(c, {points_t, segms_t});
-
-  // Set the device for the kernel launch based on the device of the input
-  at::cuda::CUDAGuard device_guard(points.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  const int64_t P = points.size(0);
-  const int64_t S = segms.size(0);
-  const int64_t B = points_first_idx.size(0);
-
-  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
-  TORCH_CHECK(
-      (segms.size(1) == 2) && (segms.size(2) == 3),
-      "segms must be of shape Sx2x3");
-  TORCH_CHECK(segms_first_idx.size(0) == B);
-
-  // clang-format off
-  at::Tensor dists = at::zeros({S,}, segms.options());
-  at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
-  // clang-format on
-
-  if (dists.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return std::make_tuple(dists, idxs);
-  }
-
-  const int threads = 128;
-  const dim3 blocks(max_segms, B);
-  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
-
-  EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
-      points.contiguous().data_ptr<float>(),
-      points_first_idx.contiguous().data_ptr<int64_t>(),
-      segms.contiguous().data_ptr<float>(),
-      segms_first_idx.contiguous().data_ptr<int64_t>(),
-      dists.data_ptr<float>(),
-      idxs.data_ptr<int64_t>(),
-      B,
-      P,
-      S);
-  AT_CUDA_CHECK(cudaGetLastError());
-  return std::make_tuple(dists, idxs);
-}
-
-__global__ void EdgePointBackwardKernel(
-    const float* __restrict__ points, // (P, 3)
-    const float* __restrict__ segms, // (S, 2, 3)
-    const int64_t* __restrict__ idx_segms, // (S,)
-    const float* __restrict__ grad_dists, // (S,)
-    float* __restrict__ grad_points, // (P, 3)
-    float* __restrict__ grad_segms, // (S, 2, 3)
-    const size_t S) {
-  float3* points_f3 = (float3*)points;
-  float3* segms_f3 = (float3*)segms;
-
-  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const size_t stride = gridDim.x * blockDim.x;
-
-  for (size_t s = tid; s < S; s += stride) {
-    const float3 v0 = segms_f3[s * 2 + 0];
-    const float3 v1 = segms_f3[s * 2 + 1];
-
-    const int64_t pidx = idx_segms[s];
-
-    const float3 p_f3 = points_f3[pidx];
-
-    const float grad_dist = grad_dists[s];
-
-    const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
-    const float3 grad_point = thrust::get<0>(grads);
-    const float3 grad_v0 = thrust::get<1>(grads);
-    const float3 grad_v1 = thrust::get<2>(grads);
-
-    atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
-    atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
-    atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
-
-    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_v0.x);
-    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_v0.y);
-    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_v0.z);
-
-    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_v1.x);
-    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_v1.y);
-    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_v1.z);
-  }
-}
-
-std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& segms,
-    const at::Tensor& idx_segms,
-    const at::Tensor& grad_dists) {
-  // Check inputs are on the same device
-  at::TensorArg points_t{points, "points", 1},
-      idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3},
-      grad_dists_t{grad_dists, "grad_dists", 4};
-  at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
-  at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t});
-  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
-
-  // Set the device for the kernel launch based on the device of the input
-  at::cuda::CUDAGuard device_guard(points.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  const int64_t P = points.size(0);
-  const int64_t S = segms.size(0);
-
-  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
-  TORCH_CHECK(
-      (segms.size(1) == 2) && (segms.size(2) == 3),
-      "segms must be of shape Sx2x3");
-  TORCH_CHECK(idx_segms.size(0) == S);
-  TORCH_CHECK(grad_dists.size(0) == S);
-
-  // clang-format off
-  at::Tensor grad_points = at::zeros({P, 3}, points.options());
-  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
-  // clang-format on
-
-  const int blocks = 64;
-  const int threads = 512;
-
-  EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.contiguous().data_ptr<float>(),
-      segms.contiguous().data_ptr<float>(),
-      idx_segms.contiguous().data_ptr<int64_t>(),
-      grad_dists.contiguous().data_ptr<float>(),
-      grad_points.data_ptr<float>(),
-      grad_segms.data_ptr<float>(),
-      S);
-
-  return std::make_tuple(grad_points, grad_segms);
-}
-
-// ****************************************************************************
-// *                       PointEdgeArrayDistance                             *
-// ****************************************************************************
-
-__global__ void PointEdgeArrayForwardKernel(
-    const float* __restrict__ points, // (P, 3)
-    const float* __restrict__ segms, // (S, 2, 3)
-    float* __restrict__ dists, // (P, S)
-    const size_t P,
-    const size_t S) {
-  float3* points_f3 = (float3*)points;
-  float3* segms_f3 = (float3*)segms;
-
-  // Parallelize over P * S computations
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
-    const int s = t_i / P; // segment index
-    const int p = t_i % P; // point index
-    float3 a = segms_f3[s * 2 + 0];
-    float3 b = segms_f3[s * 2 + 1];
-
-    float3 point = points_f3[p];
-    float dist = PointLine3DistanceForward(point, a, b);
-    dists[p * S + s] = dist;
-  }
-}
-
-at::Tensor PointEdgeArrayDistanceForwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& segms) {
-  // Check inputs are on the same device
-  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
-  at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
-  at::checkAllSameGPU(c, {points_t, segms_t});
-  at::checkAllSameType(c, {points_t, segms_t});
-
-  // Set the device for the kernel launch based on the device of the input
-  at::cuda::CUDAGuard device_guard(points.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  const int64_t P = points.size(0);
-  const int64_t S = segms.size(0);
-
-  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
-  TORCH_CHECK(
-      (segms.size(1) == 2) && (segms.size(2) == 3),
-      "segms must be of shape Sx2x3");
-
-  at::Tensor dists = at::zeros({P, S}, points.options());
-
-  if (dists.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return dists;
-  }
-
-  const size_t blocks = 1024;
-  const size_t threads = 64;
-
-  PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
-      points.contiguous().data_ptr<float>(),
-      segms.contiguous().data_ptr<float>(),
-      dists.data_ptr<float>(),
-      P,
-      S);
-
-  AT_CUDA_CHECK(cudaGetLastError());
-  return dists;
-}
-
-__global__ void PointEdgeArrayBackwardKernel(
-    const float* __restrict__ points, // (P, 3)
-    const float* __restrict__ segms, // (S, 2, 3)
-    const float* __restrict__ grad_dists, // (P, S)
-    float* __restrict__ grad_points, // (P, 3)
-    float* __restrict__ grad_segms, // (S, 2, 3)
-    const size_t P,
-    const size_t S) {
-  float3* points_f3 = (float3*)points;
-  float3* segms_f3 = (float3*)segms;
-
-  // Parallelize over P * S computations
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
-    const int s = t_i / P; // segment index
-    const int p = t_i % P; // point index
-    const float3 a = segms_f3[s * 2 + 0];
-    const float3 b = segms_f3[s * 2 + 1];
-
-    const float3 point = points_f3[p];
-    const float grad_dist = grad_dists[p * S + s];
-    const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
-    const float3 grad_point = thrust::get<0>(grads);
-    const float3 grad_a = thrust::get<1>(grads);
-    const float3 grad_b = thrust::get<2>(grads);
-
-    atomicAdd(grad_points + p * 3 + 0, grad_point.x);
-    atomicAdd(grad_points + p * 3 + 1, grad_point.y);
-    atomicAdd(grad_points + p * 3 + 2, grad_point.z);
-
-    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
-    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
-    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);
-
-    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
-    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
-    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
-  }
-}
-
-std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
-    const at::Tensor& points,
-    const at::Tensor& segms,
-    const at::Tensor& grad_dists) {
-  // Check inputs are on the same device
-  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
-      grad_dists_t{grad_dists, "grad_dists", 3};
-  at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
-  at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
-  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
-
-  // Set the device for the kernel launch based on the device of the input
-  at::cuda::CUDAGuard device_guard(points.device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  const int64_t P = points.size(0);
-  const int64_t S = segms.size(0);
-
-  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
-  TORCH_CHECK(
-      (segms.size(1) == 2) && (segms.size(2) == 3),
-      "segms must be of shape Sx2x3");
-  TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
-
-  at::Tensor grad_points = at::zeros({P, 3}, points.options());
-  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
-
-  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return std::make_tuple(grad_points, grad_segms);
-  }
-
-  const size_t blocks = 1024;
-  const size_t threads = 64;
-
-  PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.contiguous().data_ptr<float>(),
-      segms.contiguous().data_ptr<float>(),
-      grad_dists.contiguous().data_ptr<float>(),
-      grad_points.data_ptr<float>(),
-      grad_segms.data_ptr<float>(),
-      P,
-      S);
-  AT_CUDA_CHECK(cudaGetLastError());
-  return std::make_tuple(grad_points, grad_segms);
-}
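Of everything removed from point_mesh_edge.cu, only the batched PointEdge/EdgePoint kernels are actually retired: they are subsumed by the generic DistanceForwardKernel / DistanceBackwardKernel pair in point_mesh_cuda.cu. The PointEdgeArray kernels and their host wrappers above move into point_mesh_cuda.cu verbatim, pending the TODO there to merge them with the PointFaceArray kernels.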
diff --git a/pytorch3d/csrc/point_mesh/point_mesh_edge.h b/pytorch3d/csrc/point_mesh/point_mesh_edge.h
deleted file mode 100644
index b775d2e0..00000000
--- a/pytorch3d/csrc/point_mesh/point_mesh_edge.h
+++ /dev/null
@@ -1,332 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
-#pragma once
-#include <torch/extension.h>
-#include <cstdio>
-#include <tuple>
-#include "utils/pytorch3d_cutils.h"
-
-// ****************************************************************************
-// *                          PointEdgeDistance                               *
-// ****************************************************************************
-
-// Computes the squared euclidean distance of each p in points to the closest
-// mesh edge belonging to the corresponding example in the batch of size N.
-//
-// Args:
-//    points: FloatTensor of shape (P, 3)
-//    points_first_idx: LongTensor of shape (N,) indicating the first point
-//        index for each example in the batch
-//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
-//        segment is spanned by (segms[s, 0], segms[s, 1])
-//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
-//        index for each example in the batch
-//    max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
-//        the maximum number of points in the batch and is used to set
-//        the grid dimensions in the CUDA implementation.
-//
-// Returns:
-//    dists: FloatTensor of shape (P,), where dists[p] is the squared
-//        euclidean distance of points[p] to the closest edge in the same
-//        example in the batch.
-//    idxs: LongTensor of shape (P,), where idxs[p] is the index of the
-//        closest edge in the batch.
-//        So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
-//        where d(u, v0, v1) is the distance of u from the segment spanned by
-//        (v0, v1).
-//
-
-#ifdef WITH_CUDA
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
-    const torch::Tensor& points,
-    const torch::Tensor& points_first_idx,
-    const torch::Tensor& segms,
-    const torch::Tensor& segms_first_idx,
-    const int64_t max_points);
-#endif
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
-    const torch::Tensor& points,
-    const torch::Tensor& points_first_idx,
-    const torch::Tensor& segms,
-    const torch::Tensor& segms_first_idx,
-    const int64_t max_points);
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
-    const torch::Tensor& points,
-    const torch::Tensor& points_first_idx,
-    const torch::Tensor& segms,
-    const torch::Tensor& segms_first_idx,
-    const int64_t max_points) {
-  if (points.is_cuda()) {
-#ifdef WITH_CUDA
-    CHECK_CUDA(points);
-    CHECK_CUDA(points_first_idx);
-    CHECK_CUDA(segms);
-    CHECK_CUDA(segms_first_idx);
-    return PointEdgeDistanceForwardCuda(
-        points, points_first_idx, segms, segms_first_idx, max_points);
-#else
-    AT_ERROR("Not compiled with GPU support.");
-#endif
-  }
-  return PointEdgeDistanceForwardCpu(
-      points, points_first_idx, segms, segms_first_idx, max_points);
-}
-
-// Backward pass for PointEdgeDistance.
-//
-// Args:
-//    points: FloatTensor of shape (P, 3)
-//    segms: FloatTensor of shape (S, 2, 3)
-//    idx_points: LongTensor of shape (P,) containing the indices
-//        of the closest edge in the example in the batch.
-//        This is computed by the forward pass.
-//    grad_dists: FloatTensor of shape (P,)
-//
-// Returns:
-//    grad_points: FloatTensor of shape (P, 3)
-//    grad_segms: FloatTensor of shape (S, 2, 3)
-//
-
-#ifdef WITH_CUDA
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& idx_points,
-    const torch::Tensor& grad_dists);
-#endif
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& idx_points,
-    const torch::Tensor& grad_dists);
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& idx_points,
-    const torch::Tensor& grad_dists) {
-  if (points.is_cuda()) {
-#ifdef WITH_CUDA
-    CHECK_CUDA(points);
-    CHECK_CUDA(segms);
-    CHECK_CUDA(idx_points);
-    CHECK_CUDA(grad_dists);
-    return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
-#else
-    AT_ERROR("Not compiled with GPU support.");
-#endif
-  }
-  return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
-}
-
-// ****************************************************************************
-// *                          EdgePointDistance                               *
-// ****************************************************************************
-
-// Computes the squared euclidean distance of each edge segment to the closest
-// point belonging to the corresponding example in the batch of size N.
-//
-// Args:
-//    points: FloatTensor of shape (P, 3)
-//    points_first_idx: LongTensor of shape (N,) indicating the first point
-//        index for each example in the batch
-//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
-//        segment is spanned by (segms[s, 0], segms[s, 1])
-//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
-//        index for each example in the batch
-//    max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
-//        the maximum number of edges in the batch and is used to set
-//        the block dimensions in the CUDA implementation.
-//
-// Returns:
-//    dists: FloatTensor of shape (S,), where dists[s] is the squared
-//        euclidean distance of s-th edge to the closest point in the
-//        corresponding example in the batch.
-//    idxs: LongTensor of shape (S,), where idxs[s] is the index of the
-//        closest point in the example in the batch.
-//        So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
-//        d(u, v0, v1) is the distance of u from the segment spanned by
-//        (v0, v1).
-//
-
-#ifdef WITH_CUDA
-
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
-    const torch::Tensor& points,
-    const torch::Tensor& points_first_idx,
-    const torch::Tensor& segms,
-    const torch::Tensor& segms_first_idx,
-    const int64_t max_segms);
-#endif
-
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
-    const torch::Tensor& points,
-    const torch::Tensor& points_first_idx,
-    const torch::Tensor& segms,
-    const torch::Tensor& segms_first_idx,
-    const int64_t max_segms);
-
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
-    const torch::Tensor& points,
-    const torch::Tensor& points_first_idx,
-    const torch::Tensor& segms,
-    const torch::Tensor& segms_first_idx,
-    const int64_t max_segms) {
-  if (points.is_cuda()) {
-#ifdef WITH_CUDA
-    CHECK_CUDA(points);
-    CHECK_CUDA(points_first_idx);
-    CHECK_CUDA(segms);
-    CHECK_CUDA(segms_first_idx);
-    return EdgePointDistanceForwardCuda(
-        points, points_first_idx, segms, segms_first_idx, max_segms);
-#else
-    AT_ERROR("Not compiled with GPU support.");
-#endif
-  }
-  return EdgePointDistanceForwardCpu(
-      points, points_first_idx, segms, segms_first_idx, max_segms);
-}
-
-// Backward pass for EdgePointDistance.
-//
-// Args:
-//    points: FloatTensor of shape (P, 3)
-//    segms: FloatTensor of shape (S, 2, 3)
-//    idx_segms: LongTensor of shape (S,) containing the indices
-//        of the closest point in the example in the batch.
-//        This is computed by the forward pass
-//    grad_dists: FloatTensor of shape (S,)
-//
-// Returns:
-//    grad_points: FloatTensor of shape (P, 3)
-//    grad_segms: FloatTensor of shape (S, 2, 3)
-//
-
-#ifdef WITH_CUDA
-
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& idx_segms,
-    const torch::Tensor& grad_dists);
-#endif
-
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& idx_segms,
-    const torch::Tensor& grad_dists);
-
-std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& idx_segms,
-    const torch::Tensor& grad_dists) {
-  if (points.is_cuda()) {
-#ifdef WITH_CUDA
-    CHECK_CUDA(points);
-    CHECK_CUDA(segms);
-    CHECK_CUDA(idx_segms);
-    CHECK_CUDA(grad_dists);
-    return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
-#else
-    AT_ERROR("Not compiled with GPU support.");
-#endif
-  }
-  return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
-}
-
-// ****************************************************************************
-// *                       PointEdgeArrayDistance                             *
-// ****************************************************************************
-
-// Computes the squared euclidean distance of each p in points to each edge
-// segment in segms.
-//
-// Args:
-//    points: FloatTensor of shape (P, 3)
-//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
-//        edge segment is spanned by (segms[s, 0], segms[s, 1])
-//
-// Returns:
-//    dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
-//        euclidean distance of points[p] to the segment spanned by
-//        (segms[s, 0], segms[s, 1])
-//
-// For pointcloud and meshes of batch size N, this function requires N
-// computations. The memory occupied is O(NPS) which can become quite large.
-// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
-// will require for the forward pass 5.8G of memory to store dists.
-
-#ifdef WITH_CUDA
-torch::Tensor PointEdgeArrayDistanceForwardCuda(
-    const torch::Tensor& points,
-    const torch::Tensor& segms);
-#endif
-
-torch::Tensor PointEdgeArrayDistanceForwardCpu(
-    const torch::Tensor& points,
-    const torch::Tensor& segms);
-
-torch::Tensor PointEdgeArrayDistanceForward(
-    const torch::Tensor& points,
-    const torch::Tensor& segms) {
-  if (points.is_cuda()) {
-#ifdef WITH_CUDA
-    CHECK_CUDA(points);
-    CHECK_CUDA(segms);
-    return PointEdgeArrayDistanceForwardCuda(points, segms);
-#else
-    AT_ERROR("Not compiled with GPU support.");
-#endif
-  }
-  return PointEdgeArrayDistanceForwardCpu(points, segms);
-}
-
-// Backward pass for PointEdgeArrayDistance.
-//
-// Args:
-//    points: FloatTensor of shape (P, 3)
-//    segms: FloatTensor of shape (S, 2, 3)
-//    grad_dists: FloatTensor of shape (P, S)
-//
-// Returns:
-//    grad_points: FloatTensor of shape (P, 3)
-//    grad_segms: FloatTensor of shape (S, 2, 3)
-//
-
-#ifdef WITH_CUDA
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& grad_dists);
-#endif
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& grad_dists);
-
-std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
-    const torch::Tensor& points,
-    const torch::Tensor& segms,
-    const torch::Tensor& grad_dists) {
-  if (points.is_cuda()) {
-#ifdef WITH_CUDA
-    CHECK_CUDA(points);
-    CHECK_CUDA(segms);
-    CHECK_CUDA(grad_dists);
-    return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
-#else
-    AT_ERROR("Not compiled with GPU support.");
-#endif
-  }
-  return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
-}