mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00

Remove point mesh edge kernels

Summary: Removes the now-unnecessary kernels from the point mesh edge file. Migrates all point mesh functionality into one file.

Reviewed By: gkioxari

Differential Revision: D24550086

fbshipit-source-id: f924996cd38a7c2c1cf189d8a01611de4506cfa3

This commit is contained in:
parent 8dcfe30f66
commit 804235b05a
@@ -15,8 +15,7 @@
#include "interp_face_attrs/interp_face_attrs.h"
#include "knn/knn.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_edge.h"
#include "point_mesh/point_mesh_face.h"
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"

@@ -12,7 +12,7 @@
#include "utils/warp_reduce.cuh"

// ****************************************************************************
// *                          PointFaceDistance                               *
// *                   Generic Forward/Backward Kernels                       *
// ****************************************************************************

__global__ void DistanceForwardKernel(

@@ -202,16 +202,6 @@ std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
  return std::make_tuple(dists, idxs);
}

std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& points_first_idx,
    const at::Tensor& tris,
    const at::Tensor& tris_first_idx,
    const int64_t max_points) {
  return DistanceForwardCuda(
      points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
}

__global__ void DistanceBackwardKernel(
    const float* __restrict__ objects, // (O * oD * 3)
    const size_t objects_size, // O

@@ -365,6 +355,20 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
  return std::make_tuple(grad_points, grad_tris);
}

// ****************************************************************************
// *                          PointFaceDistance                               *
// ****************************************************************************

std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& points_first_idx,
    const at::Tensor& tris,
    const at::Tensor& tris_first_idx,
    const int64_t max_points) {
  return DistanceForwardCuda(
      points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
}
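
// Note: DistanceForwardCuda / DistanceBackwardCuda are the generic kernel
// entry points; the integer after each tensor argument is the number of
// vertices per primitive (1 for points, 3 for triangle faces, and 2 for the
// edge segments below).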

std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& tris,

@@ -395,9 +399,54 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
  return DistanceBackwardCuda(tris, 3, points, 1, idx_tris, grad_dists);
}

// ****************************************************************************
// *                          PointEdgeDistance                               *
// ****************************************************************************

std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& points_first_idx,
    const at::Tensor& segms,
    const at::Tensor& segms_first_idx,
    const int64_t max_points) {
  return DistanceForwardCuda(
      points, 1, points_first_idx, segms, 2, segms_first_idx, max_points);
}

std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& idx_points,
    const at::Tensor& grad_dists) {
  return DistanceBackwardCuda(points, 1, segms, 2, idx_points, grad_dists);
}

// ****************************************************************************
// *                          EdgePointDistance                               *
// ****************************************************************************

std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& points_first_idx,
    const at::Tensor& segms,
    const at::Tensor& segms_first_idx,
    const int64_t max_segms) {
  return DistanceForwardCuda(
      segms, 2, segms_first_idx, points, 1, points_first_idx, max_segms);
}

std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& idx_segms,
    const at::Tensor& grad_dists) {
  return DistanceBackwardCuda(segms, 2, points, 1, idx_segms, grad_dists);
}

// ****************************************************************************
// *                     PointFaceArrayDistance                               *
// ****************************************************************************
// TODO: Create wrapper function and merge kernel with other array kernel

__global__ void PointFaceArrayForwardKernel(
    const float* __restrict__ points, // (P, 3)

@@ -565,3 +614,164 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_tris);
}

// ****************************************************************************
// *                     PointEdgeArrayDistance                               *
// ****************************************************************************
// TODO: Create wrapper function and merge kernel with other array kernel

__global__ void PointEdgeArrayForwardKernel(
    const float* __restrict__ points, // (P, 3)
    const float* __restrict__ segms, // (S, 2, 3)
    float* __restrict__ dists, // (P, S)
    const size_t P,
    const size_t S) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  // Parallelize over P * S computations
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
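
  // Grid-stride loop: thread tid handles the flattened pair indices tid,
  // tid + num_threads, ..., so any launch size covers all P * S
  // (point, segment) pairs exactly once.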
  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
    const int s = t_i / P; // segment index.
    const int p = t_i % P; // point index
    float3 a = segms_f3[s * 2 + 0];
    float3 b = segms_f3[s * 2 + 1];

    float3 point = points_f3[p];
    float dist = PointLine3DistanceForward(point, a, b);
    dists[p * S + s] = dist;
  }
}

at::Tensor PointEdgeArrayDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
  at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
  at::checkAllSameGPU(c, {points_t, segms_t});
  at::checkAllSameType(c, {points_t, segms_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");

  at::Tensor dists = at::zeros({P, S}, points.options());

  if (dists.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return dists;
  }

  const size_t blocks = 1024;
  const size_t threads = 64;

  PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
      segms.contiguous().data_ptr<float>(),
      dists.data_ptr<float>(),
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return dists;
}

__global__ void PointEdgeArrayBackwardKernel(
    const float* __restrict__ points, // (P, 3)
    const float* __restrict__ segms, // (S, 2, 3)
    const float* __restrict__ grad_dists, // (P, S)
    float* __restrict__ grad_points, // (P, 3)
    float* __restrict__ grad_segms, // (S, 2, 3)
    const size_t P,
    const size_t S) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  // Parallelize over P * S computations
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
    const int s = t_i / P; // segment index.
    const int p = t_i % P; // point index
    const float3 a = segms_f3[s * 2 + 0];
    const float3 b = segms_f3[s * 2 + 1];

    const float3 point = points_f3[p];
    const float grad_dist = grad_dists[p * S + s];
    const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
    const float3 grad_point = thrust::get<0>(grads);
    const float3 grad_a = thrust::get<1>(grads);
    const float3 grad_b = thrust::get<2>(grads);
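
    // Many (p, s) pairs touch the same point row or the same segment row
    // concurrently, so the gradient scatter below has to use atomicAdd.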
    atomicAdd(grad_points + p * 3 + 0, grad_point.x);
    atomicAdd(grad_points + p * 3 + 1, grad_point.y);
    atomicAdd(grad_points + p * 3 + 2, grad_point.z);

    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);

    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
  }
}

std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& grad_dists) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
      grad_dists_t{grad_dists, "grad_dists", 3};
  at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
  at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
  TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));

  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());

  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(grad_points, grad_segms);
  }

  const size_t blocks = 1024;
  const size_t threads = 64;

  PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
      segms.contiguous().data_ptr<float>(),
      grad_dists.contiguous().data_ptr<float>(),
      grad_points.data_ptr<float>(),
      grad_segms.data_ptr<float>(),
      P,
      S);
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_segms);
}

@@ -241,6 +241,242 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
  return FacePointDistanceBackwardCpu(points, tris, idx_tris, grad_dists);
}

// ****************************************************************************
// *                      PointEdgeDistance                                   *
// ****************************************************************************

// Computes the squared euclidean distance of each p in points to the closest
// mesh edge belonging to the corresponding example in the batch of size N.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    points_first_idx: LongTensor of shape (N,) indicating the first point
//        index for each example in the batch
//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
//        segment is spanned by (segms[s, 0], segms[s, 1])
//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
//        index for each example in the batch
//    max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
//        the maximum number of points in the batch and is used to set
//        the grid dimensions in the CUDA implementation.
//
// Returns:
//    dists: FloatTensor of shape (P,), where dists[p] is the squared euclidean
//        distance of points[p] to the closest edge in the same example in the
//        batch.
//    idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
//        edge in the batch.
//        So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
//        where d(u, v0, v1) is the distance of u from the segment spanned by
//        (v0, v1).
//
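// For illustration: with N = 2 clouds of 3 and 2 points packed into a points
// tensor of shape (5, 3), points_first_idx = [0, 3]. dists[4] is then the
// squared distance from the last point to the closest edge of the second
// example, and idxs[4] indexes into the packed segms tensor.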

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_points);
#endif

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_points);

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_points) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(points_first_idx);
    CHECK_CUDA(segms);
    CHECK_CUDA(segms_first_idx);
    return PointEdgeDistanceForwardCuda(
        points, points_first_idx, segms, segms_first_idx, max_points);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeDistanceForwardCpu(
      points, points_first_idx, segms, segms_first_idx, max_points);
}

// Backward pass for PointEdgeDistance.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3)
//    idx_points: LongTensor of shape (P,) containing the indices
//        of the closest edge in the example in the batch.
//        This is computed by the forward pass.
//    grad_dists: FloatTensor of shape (P,)
//
// Returns:
//    grad_points: FloatTensor of shape (P, 3)
//    grad_segms: FloatTensor of shape (S, 2, 3)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_points,
    const torch::Tensor& grad_dists);
#endif

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_points,
    const torch::Tensor& grad_dists);

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_points,
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    CHECK_CUDA(idx_points);
    CHECK_CUDA(grad_dists);
    return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
}

// ****************************************************************************
// *                      EdgePointDistance                                   *
// ****************************************************************************

// Computes the squared euclidean distance of each edge segment to the closest
// point belonging to the corresponding example in the batch of size N.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    points_first_idx: LongTensor of shape (N,) indicating the first point
//        index for each example in the batch
//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
//        segment is spanned by (segms[s, 0], segms[s, 1])
//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
//        index for each example in the batch
//    max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
//        the maximum number of edges in the batch and is used to set
//        the block dimensions in the CUDA implementation.
//
// Returns:
//    dists: FloatTensor of shape (S,), where dists[s] is the squared
//        euclidean distance of s-th edge to the closest point in the
//        corresponding example in the batch.
//    idxs: LongTensor of shape (S,), where idxs[s] is the index of the closest
//        point in the example in the batch.
//        So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
//        d(u, v0, v1) is the distance of u from the segment spanned by (v0, v1)
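// (This is the transpose of PointEdgeDistance above: one closest-point query
// per edge segment rather than one closest-edge query per point.)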
//
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_segms);
#endif

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_segms);

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_segms) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(points_first_idx);
    CHECK_CUDA(segms);
    CHECK_CUDA(segms_first_idx);
    return EdgePointDistanceForwardCuda(
        points, points_first_idx, segms, segms_first_idx, max_segms);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return EdgePointDistanceForwardCpu(
      points, points_first_idx, segms, segms_first_idx, max_segms);
}

// Backward pass for EdgePointDistance.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3)
//    idx_segms: LongTensor of shape (S,) containing the indices
//        of the closest point in the example in the batch.
//        This is computed by the forward pass
//    grad_dists: FloatTensor of shape (S,)
//
// Returns:
//    grad_points: FloatTensor of shape (P, 3)
//    grad_segms: FloatTensor of shape (S, 2, 3)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_segms,
    const torch::Tensor& grad_dists);
#endif

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_segms,
    const torch::Tensor& grad_dists);

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_segms,
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    CHECK_CUDA(idx_segms);
    CHECK_CUDA(grad_dists);
    return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
}

// ****************************************************************************
// *                       PointFaceArrayDistance                             *
// ****************************************************************************

@@ -328,3 +564,92 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
  }
  return PointFaceArrayDistanceBackwardCpu(points, tris, grad_dists);
}

// ****************************************************************************
// *                          PointEdgeArrayDistance                          *
// ****************************************************************************

// Computes the squared euclidean distance of each p in points to each edge
// segment in segms.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
//        edge segment is spanned by (segms[s, 0], segms[s, 1])
//
// Returns:
//    dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
//        euclidean distance of points[p] to the segment spanned by
//        (segms[s, 0], segms[s, 1])
//
// For pointcloud and meshes of batch size N, this function requires N
// computations. The memory occupied is O(NPS) which can become quite large.
// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
// will require for the forward pass 5.8G of memory to store dists.
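// (Worked check, assuming float32 values: 32 * 10000 * 5000 * 4 bytes
// = 6.4e9 bytes, roughly 6 GB, in line with the figure quoted above.)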

#ifdef WITH_CUDA
torch::Tensor PointEdgeArrayDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms);
#endif

torch::Tensor PointEdgeArrayDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms);

torch::Tensor PointEdgeArrayDistanceForward(
    const torch::Tensor& points,
    const torch::Tensor& segms) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    return PointEdgeArrayDistanceForwardCuda(points, segms);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeArrayDistanceForwardCpu(points, segms);
}

// Backward pass for PointEdgeArrayDistance.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3)
//    grad_dists: FloatTensor of shape (P, S)
//
// Returns:
//   grad_points: FloatTensor of shape (P, 3)
//   grad_segms: FloatTensor of shape (S, 2, 3)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& grad_dists);
#endif

std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& grad_dists);

std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    CHECK_CUDA(grad_dists);
    return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
}

@@ -1,651 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <list>
#include <queue>
#include <tuple>
#include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh"
#include "utils/warp_reduce.cuh"

// ****************************************************************************
// *                          PointEdgeDistance                               *
// ****************************************************************************

__global__ void PointEdgeForwardKernel(
    const float* __restrict__ points, // (P, 3)
    const int64_t* __restrict__ points_first_idx, // (B,)
    const float* __restrict__ segms, // (S, 2, 3)
    const int64_t* __restrict__ segms_first_idx, // (B,)
    float* __restrict__ dist_points, // (P,)
    int64_t* __restrict__ idx_points, // (P,)
    const size_t B,
    const size_t P,
    const size_t S) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  // Single shared memory buffer which is split and cast to different types.
  extern __shared__ char shared_buf[];
  float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]

  const size_t batch_idx = blockIdx.y; // index of batch element.

  // start and end for points in batch
  const int64_t startp = points_first_idx[batch_idx];
  const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;

  // start and end for segments in batch_idx
  const int64_t starts = segms_first_idx[batch_idx];
  const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;

  const size_t i = blockIdx.x; // index of point within batch element.
  const size_t tid = threadIdx.x; // thread idx

  // Each block will compute one element of the output idx_points[startp + i],
  // dist_points[startp + i]. Within the block we will use threads to compute
  // the distances between points[startp + i] and segms[j] for all j belonging
  // in the same batch as i, i.e. j in [starts, ends]. Then use a block
  // reduction to take an argmin of the distances.
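  // In outline: each thread serially scans a strided subset of the segments,
  // keeping a local (min_dist, min_idx); the per-thread results are merged by
  // the shared-memory tree reduction below, finished with WarpReduce, and
  // thread 0 writes the block's argmin.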

  // If i exceeds the number of points in batch_idx, then do nothing
  if (i < (endp - startp)) {
    // Retrieve (startp + i) point
    const float3 p_f3 = points_f3[startp + i];

    // Compute the distances between points[startp + i] and segms[j] for
    // all j belonging in the same batch as i, i.e. j in [starts, ends].
    // Here each thread will reduce over (ends-starts) / blockDim.x in serial,
    // and store its result to shared memory
    float min_dist = FLT_MAX;
    size_t min_idx = 0;
    for (size_t j = tid; j < (ends - starts); j += blockDim.x) {
      const float3 v0 = segms_f3[(starts + j) * 2 + 0];
      const float3 v1 = segms_f3[(starts + j) * 2 + 1];
      float dist = PointLine3DistanceForward(p_f3, v0, v1);
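      // On a thread's first pass j == tid, so min_dist is seeded with dist;
      // later iterations keep the running minimum and its segment index.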
      min_dist = (j == tid) ? dist : min_dist;
      min_idx = (dist <= min_dist) ? (starts + j) : min_idx;
      min_dist = (dist <= min_dist) ? dist : min_dist;
    }
    min_dists[tid] = min_dist;
    min_idxs[tid] = min_idx;
    __syncthreads();

    // Perform reduction in shared memory.
    for (int s = blockDim.x / 2; s > 32; s >>= 1) {
      if (tid < s) {
        if (min_dists[tid] > min_dists[tid + s]) {
          min_dists[tid] = min_dists[tid + s];
          min_idxs[tid] = min_idxs[tid + s];
        }
      }
      __syncthreads();
    }

    // Unroll the last 6 iterations of the loop since they will happen
    // synchronized within a single warp.
    if (tid < 32)
      WarpReduce<float>(min_dists, min_idxs, tid);

    // Finally thread 0 writes the result to the output buffer.
    if (tid == 0) {
      idx_points[startp + i] = min_idxs[0];
      dist_points[startp + i] = min_dists[0];
    }
  }
}

std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& points_first_idx,
    const at::Tensor& segms,
    const at::Tensor& segms_first_idx,
    const int64_t max_points) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1},
      points_first_idx_t{points_first_idx, "points_first_idx", 2},
      segms_t{segms, "segms", 3},
      segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
  at::CheckedFrom c = "PointEdgeDistanceForwardCuda";
  at::checkAllSameGPU(
      c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
  at::checkAllSameType(c, {points_t, segms_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);
  const int64_t B = points_first_idx.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
  TORCH_CHECK(segms_first_idx.size(0) == B);

  // clang-format off
  at::Tensor dists = at::zeros({P,}, points.options());
  at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
  // clang-format on

  if (dists.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(dists, idxs);
  }

  const int threads = 128;
  const dim3 blocks(max_points, B);
  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
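  // The kernel splits this buffer into float[threads] (min_dists) followed by
  // int64_t[threads] (min_idxs); sizeof(size_t) >= sizeof(float), so the
  // allocation is a little larger than strictly needed, which is harmless.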

  PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
      points.contiguous().data_ptr<float>(),
      points_first_idx.contiguous().data_ptr<int64_t>(),
      segms.contiguous().data_ptr<float>(),
      segms_first_idx.contiguous().data_ptr<int64_t>(),
      dists.data_ptr<float>(),
      idxs.data_ptr<int64_t>(),
      B,
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(dists, idxs);
}

__global__ void PointEdgeBackwardKernel(
    const float* __restrict__ points, // (P, 3)
    const float* __restrict__ segms, // (S, 2, 3)
    const int64_t* __restrict__ idx_points, // (P,)
    const float* __restrict__ grad_dists, // (P,)
    float* __restrict__ grad_points, // (P, 3)
    float* __restrict__ grad_segms, // (S, 2, 3)
    const size_t P) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = gridDim.x * blockDim.x;

  for (size_t p = tid; p < P; p += stride) {
    const float3 p_f3 = points_f3[p];

    const int64_t sidx = idx_points[p];
    const float3 v0 = segms_f3[sidx * 2 + 0];
    const float3 v1 = segms_f3[sidx * 2 + 1];

    const float grad_dist = grad_dists[p];

    const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
    const float3 grad_point = thrust::get<0>(grads);
    const float3 grad_v0 = thrust::get<1>(grads);
    const float3 grad_v1 = thrust::get<2>(grads);

    atomicAdd(grad_points + p * 3 + 0, grad_point.x);
    atomicAdd(grad_points + p * 3 + 1, grad_point.y);
    atomicAdd(grad_points + p * 3 + 2, grad_point.z);

    atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 0, grad_v0.x);
    atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 1, grad_v0.y);
    atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 2, grad_v0.z);

    atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 0, grad_v1.x);
    atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 1, grad_v1.y);
    atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 2, grad_v1.z);
  }
}

std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& idx_points,
    const at::Tensor& grad_dists) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1},
      idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3},
      grad_dists_t{grad_dists, "grad_dists", 4};
  at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
  at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t});
  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
  TORCH_CHECK(idx_points.size(0) == P);
  TORCH_CHECK(grad_dists.size(0) == P);

  // clang-format off
  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
  // clang-format on

  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(grad_points, grad_segms);
  }

  const int blocks = 64;
  const int threads = 512;

  PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
      segms.contiguous().data_ptr<float>(),
      idx_points.contiguous().data_ptr<int64_t>(),
      grad_dists.contiguous().data_ptr<float>(),
      grad_points.data_ptr<float>(),
      grad_segms.data_ptr<float>(),
      P);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_segms);
}

// ****************************************************************************
// *                          EdgePointDistance                               *
// ****************************************************************************

__global__ void EdgePointForwardKernel(
    const float* __restrict__ points, // (P, 3)
    const int64_t* __restrict__ points_first_idx, // (B,)
    const float* __restrict__ segms, // (S, 2, 3)
    const int64_t* __restrict__ segms_first_idx, // (B,)
    float* __restrict__ dist_segms, // (S,)
    int64_t* __restrict__ idx_segms, // (S,)
    const size_t B,
    const size_t P,
    const size_t S) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  // Single shared memory buffer which is split and cast to different types.
  extern __shared__ char shared_buf[];
  float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]

  const size_t batch_idx = blockIdx.y; // index of batch element.

  // start and end for points in batch_idx
  const int64_t startp = points_first_idx[batch_idx];
  const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;

  // start and end for segms in batch_idx
  const int64_t starts = segms_first_idx[batch_idx];
  const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;

  const size_t i = blockIdx.x; // index of point within batch element.
  const size_t tid = threadIdx.x; // thread index

  // Each block will compute one element of the output idx_segms[starts + i],
  // dist_segms[starts + i]. Within the block we will use threads to compute
  // the distances between segms[starts + i] and points[j] for all j belonging
  // in the same batch as i, i.e. j in [startp, endp]. Then use a block
  // reduction to take an argmin of the distances.
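  // Same block-per-output argmin scheme as PointEdgeForwardKernel above, with
  // the roles of points and segments swapped: one block per edge, and threads
  // scan the points of the same batch element.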
 | 
			
		||||
 | 
			
		||||
  // If i exceeds the number of segms in batch_idx, then do nothing
 | 
			
		||||
  if (i < (ends - starts)) {
 | 
			
		||||
    const float3 v0 = segms_f3[(starts + i) * 2 + 0];
 | 
			
		||||
    const float3 v1 = segms_f3[(starts + i) * 2 + 1];
 | 
			
		||||
 | 
			
		||||
    // Compute the distances between segms[starts + i] and points[j] for
 | 
			
		||||
    // all j belonging in the same batch as i, i.e. j in [startp, endp].
 | 
			
		||||
    // Here each thread will reduce over (endp-startp) / blockDim.x in serial,
 | 
			
		||||
    // and store its result to shared memory
 | 
			
		||||
    float min_dist = FLT_MAX;
 | 
			
		||||
    size_t min_idx = 0;
 | 
			
		||||
    for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
 | 
			
		||||
      // Retrieve (startp + i) point
 | 
			
		||||
      const float3 p_f3 = points_f3[startp + j];
 | 
			
		||||
 | 
			
		||||
      float dist = PointLine3DistanceForward(p_f3, v0, v1);
 | 
			
		||||
      min_dist = (j == tid) ? dist : min_dist;
 | 
			
		||||
      min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
 | 
			
		||||
      min_dist = (dist <= min_dist) ? dist : min_dist;
 | 
			
		||||
    }
 | 
			
		||||
    min_dists[tid] = min_dist;
 | 
			
		||||
    min_idxs[tid] = min_idx;
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    // Perform reduction in shared memory.
 | 
			
		||||
    for (int s = blockDim.x / 2; s > 32; s >>= 1) {
 | 
			
		||||
      if (tid < s) {
 | 
			
		||||
        if (min_dists[tid] > min_dists[tid + s]) {
 | 
			
		||||
          min_dists[tid] = min_dists[tid + s];
 | 
			
		||||
          min_idxs[tid] = min_idxs[tid + s];
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      __syncthreads();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Unroll the last 6 iterations of the loop since they will happen
 | 
			
		||||
    // synchronized within a single warp.
 | 
			
		||||
    if (tid < 32)
 | 
			
		||||
      WarpReduce<float>(min_dists, min_idxs, tid);
 | 
			
		||||
 | 
			
		||||
    // Finally thread 0 writes the result to the output buffer.
 | 
			
		||||
    if (tid == 0) {
 | 
			
		||||
      idx_segms[starts + i] = min_idxs[0];
 | 
			
		||||
      dist_segms[starts + i] = min_dists[0];
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
 | 
			
		||||
    const at::Tensor& points,
 | 
			
		||||
    const at::Tensor& points_first_idx,
 | 
			
		||||
    const at::Tensor& segms,
 | 
			
		||||
    const at::Tensor& segms_first_idx,
 | 
			
		||||
    const int64_t max_segms) {
 | 
			
		||||
  // Check inputs are on the same device
 | 
			
		||||
  at::TensorArg points_t{points, "points", 1},
 | 
			
		||||
      points_first_idx_t{points_first_idx, "points_first_idx", 2},
 | 
			
		||||
      segms_t{segms, "segms", 3},
 | 
			
		||||
      segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
 | 
			
		||||
  at::CheckedFrom c = "EdgePointDistanceForwardCuda";
 | 
			
		||||
  at::checkAllSameGPU(
 | 
			
		||||
      c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
 | 
			
		||||
  at::checkAllSameType(c, {points_t, segms_t});
 | 
			
		||||
 | 
			
		||||
  // Set the device for the kernel launch based on the device of the input
 | 
			
		||||
  at::cuda::CUDAGuard device_guard(points.device());
 | 
			
		||||
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
 | 
			
		||||
  const int64_t P = points.size(0);
 | 
			
		||||
  const int64_t S = segms.size(0);
 | 
			
		||||
  const int64_t B = points_first_idx.size(0);
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      (segms.size(1) == 2) && (segms.size(2) == 3),
 | 
			
		||||
      "segms must be of shape Sx2x3");
 | 
			
		||||
  TORCH_CHECK(segms_first_idx.size(0) == B);
 | 
			
		||||
 | 
			
		||||
  // clang-format off
 | 
			
		||||
  at::Tensor dists = at::zeros({S,}, segms.options());
 | 
			
		||||
  at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
 | 
			
		||||
  // clang-format on
 | 
			
		||||
 | 
			
		||||
  if (dists.numel() == 0) {
 | 
			
		||||
    AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
    return std::make_tuple(dists, idxs);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const int threads = 128;
 | 
			
		||||
  const dim3 blocks(max_segms, B);
 | 
			
		||||
  size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
 | 
			
		||||
 | 
			
		||||
  EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
 | 
			
		||||
      points.contiguous().data_ptr<float>(),
 | 
			
		||||
      points_first_idx.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      segms.contiguous().data_ptr<float>(),
 | 
			
		||||
      segms_first_idx.contiguous().data_ptr<int64_t>(),
 | 
			
		||||
      dists.data_ptr<float>(),
 | 
			
		||||
      idxs.data_ptr<int64_t>(),
 | 
			
		||||
      B,
 | 
			
		||||
      P,
 | 
			
		||||
      S);
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
  return std::make_tuple(dists, idxs);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
__global__ void EdgePointBackwardKernel(
 | 
			
		||||
    const float* __restrict__ points, // (P, 3)
 | 
			
		||||
    const float* __restrict__ segms, // (S, 2, 3)
 | 
			
		||||
    const int64_t* __restrict__ idx_segms, // (S,)
 | 
			
		||||
    const float* __restrict__ grad_dists, // (S,)
 | 
			
		||||
    float* __restrict__ grad_points, // (P, 3)
 | 
			
		||||
    float* __restrict__ grad_segms, // (S, 2, 3)
 | 
			
		||||
    const size_t S) {
 | 
			
		||||
  float3* points_f3 = (float3*)points;
 | 
			
		||||
  float3* segms_f3 = (float3*)segms;
 | 
			
		||||
 | 
			
		||||
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
 | 
			
		||||
  const size_t stride = gridDim.x * blockDim.x;
 | 
			
		||||
 | 
			
		||||
  for (size_t s = tid; s < S; s += stride) {
 | 
			
		||||
    const float3 v0 = segms_f3[s * 2 + 0];
 | 
			
		||||
    const float3 v1 = segms_f3[s * 2 + 1];
 | 
			
		||||
 | 
			
		||||
    const int64_t pidx = idx_segms[s];
 | 
			
		||||
 | 
			
		||||
    const float3 p_f3 = points_f3[pidx];
 | 
			
		||||
 | 
			
		||||
    const float grad_dist = grad_dists[s];
 | 
			
		||||
 | 
			
		||||
    const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
 | 
			
		||||
    const float3 grad_point = thrust::get<0>(grads);
 | 
			
		||||
    const float3 grad_v0 = thrust::get<1>(grads);
 | 
			
		||||
    const float3 grad_v1 = thrust::get<2>(grads);
 | 
			
		||||
 | 
			
		||||
    atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
 | 
			
		||||
    atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
 | 
			
		||||
    atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
 | 
			
		||||
 | 
			
		||||
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_v0.x);
 | 
			
		||||
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_v0.y);
 | 
			
		||||
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_v0.z);
 | 
			
		||||
 | 
			
		||||
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_v1.x);
 | 
			
		||||
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_v1.y);
 | 
			
		||||
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_v1.z);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& idx_segms,
    const at::Tensor& grad_dists) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1},
      idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3},
      grad_dists_t{grad_dists, "grad_dists", 4};
  at::CheckedFrom c = "EdgePointDistanceBackwardCuda";
  at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t});
  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
  TORCH_CHECK(idx_segms.size(0) == S);
  TORCH_CHECK(grad_dists.size(0) == S);

  // clang-format off
  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
  // clang-format on

  const int blocks = 64;
  const int threads = 512;

  EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
      segms.contiguous().data_ptr<float>(),
      idx_segms.contiguous().data_ptr<int64_t>(),
      grad_dists.contiguous().data_ptr<float>(),
      grad_points.data_ptr<float>(),
      grad_segms.data_ptr<float>(),
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_segms);
}
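
// Note on the scatter above: the flat offsets follow the dense layouts
// (P, 3) and (S, 2, 3), so element (s, k, c) of grad_segms lives at
// s * 6 + k * 3 + c. A minimal sanity check of that layout (hypothetical,
// not part of this file):
//
//   at::Tensor g = at::arange(S * 2 * 3, at::kLong).reshape({S, 2, 3});
//   // g[s][k][c].item<int64_t>() == s * 6 + k * 3 + c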

// ****************************************************************************
// *                     PointEdgeArrayDistance                               *
// ****************************************************************************

__global__ void PointEdgeArrayForwardKernel(
    const float* __restrict__ points, // (P, 3)
    const float* __restrict__ segms, // (S, 2, 3)
    float* __restrict__ dists, // (P, S)
    const size_t P,
    const size_t S) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  // Parallelize over P * S computations
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
    const int s = t_i / P; // segment index
    const int p = t_i % P; // point index
    const float3 a = segms_f3[s * 2 + 0];
    const float3 b = segms_f3[s * 2 + 1];

    const float3 point = points_f3[p];
    const float dist = PointLine3DistanceForward(point, a, b);
    dists[p * S + s] = dist;
  }
}
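
// The flat loop index t_i above enumerates all P * S (point, segment) pairs
// via s = t_i / P and p = t_i % P; e.g. with P = 4, t_i = 9 picks segment
// s = 2 and point p = 1. Note the result is written in (P, S) layout, i.e.
// at dists[p * S + s].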

at::Tensor PointEdgeArrayDistanceForwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
  at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
  at::checkAllSameGPU(c, {points_t, segms_t});
  at::checkAllSameType(c, {points_t, segms_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");

  at::Tensor dists = at::zeros({P, S}, points.options());

  if (dists.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return dists;
  }

  const size_t blocks = 1024;
  const size_t threads = 64;

  PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
      segms.contiguous().data_ptr<float>(),
      dists.data_ptr<float>(),
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return dists;
}

__global__ void PointEdgeArrayBackwardKernel(
    const float* __restrict__ points, // (P, 3)
    const float* __restrict__ segms, // (S, 2, 3)
    const float* __restrict__ grad_dists, // (P, S)
    float* __restrict__ grad_points, // (P, 3)
    float* __restrict__ grad_segms, // (S, 2, 3)
    const size_t P,
    const size_t S) {
  float3* points_f3 = (float3*)points;
  float3* segms_f3 = (float3*)segms;

  // Parallelize over P * S computations
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
    const int s = t_i / P; // segment index
    const int p = t_i % P; // point index
    const float3 a = segms_f3[s * 2 + 0];
    const float3 b = segms_f3[s * 2 + 1];

    const float3 point = points_f3[p];
    const float grad_dist = grad_dists[p * S + s];
    const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
    const float3 grad_point = thrust::get<0>(grads);
    const float3 grad_a = thrust::get<1>(grads);
    const float3 grad_b = thrust::get<2>(grads);

    atomicAdd(grad_points + p * 3 + 0, grad_point.x);
    atomicAdd(grad_points + p * 3 + 1, grad_point.y);
    atomicAdd(grad_points + p * 3 + 2, grad_point.z);

    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
    atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);

    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
    atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
  }
}

std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
    const at::Tensor& points,
    const at::Tensor& segms,
    const at::Tensor& grad_dists) {
  // Check inputs are on the same device
  at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
      grad_dists_t{grad_dists, "grad_dists", 3};
  at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
  at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
  at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});

  // Set the device for the kernel launch based on the device of the input
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int64_t P = points.size(0);
  const int64_t S = segms.size(0);

  TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
  TORCH_CHECK(
      (segms.size(1) == 2) && (segms.size(2) == 3),
      "segms must be of shape Sx2x3");
  TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));

  at::Tensor grad_points = at::zeros({P, 3}, points.options());
  at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());

  if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(grad_points, grad_segms);
  }

  const size_t blocks = 1024;
  const size_t threads = 64;

  PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
      points.contiguous().data_ptr<float>(),
      segms.contiguous().data_ptr<float>(),
      grad_dists.contiguous().data_ptr<float>(),
      grad_points.data_ptr<float>(),
      grad_segms.data_ptr<float>(),
      P,
      S);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_points, grad_segms);
}
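
// Because every (point, segment) pair contributes to the shared gradient
// buffers through atomicAdd, the accumulation order is nondeterministic, so
// repeated backward passes may differ by floating-point rounding — the usual
// caveat for scatter-add style CUDA kernels.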
@ -1,332 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// *                      PointEdgeDistance                                   *
// ****************************************************************************

// Computes the squared euclidean distance of each p in points to the closest
// mesh edge belonging to the corresponding example in the batch of size N.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    points_first_idx: LongTensor of shape (N,) indicating the first point
//        index for each example in the batch
//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
//        segment is spanned by (segms[s, 0], segms[s, 1])
//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
//        index for each example in the batch
//    max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
//        the maximum number of points in the batch and is used to set
//        the grid dimensions in the CUDA implementation.
//
// Returns:
//    dists: FloatTensor of shape (P,), where dists[p] is the squared euclidean
//        distance of points[p] to the closest edge in the same example in the
//        batch.
//    idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
//        edge in the batch.
//        So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
//        where d(u, v0, v1) is the distance of u from the segment spanned by
//        (v0, v1).
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_points);
#endif

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_points);

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_points) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(points_first_idx);
    CHECK_CUDA(segms);
    CHECK_CUDA(segms_first_idx);
    return PointEdgeDistanceForwardCuda(
        points, points_first_idx, segms, segms_first_idx, max_points);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeDistanceForwardCpu(
      points, points_first_idx, segms, segms_first_idx, max_points);
}
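
// A minimal calling sketch (hypothetical tensors, not part of this header):
// for a batch of N = 2 pointclouds packed into a single (P, 3) tensor, the
// first-idx tensors hold the packed start offset of each example.
//
//   auto points = torch::rand({100, 3});            // P = 100
//   auto segms = torch::rand({40, 2, 3});           // S = 40
//   auto points_first_idx = torch::tensor({0, 60}); // example sizes 60, 40
//   auto segms_first_idx = torch::tensor({0, 25});  // example sizes 25, 15
//   auto result = PointEdgeDistanceForward(
//       points, points_first_idx, segms, segms_first_idx, /*max_points=*/60);
//   auto dists = std::get<0>(result); // (P,) squared distances
//   auto idxs = std::get<1>(result);  // (P,) closest-edge indices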

// Backward pass for PointEdgeDistance.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3)
//    idx_points: LongTensor of shape (P,) containing the indices
//        of the closest edge in the example in the batch.
//        This is computed by the forward pass.
//    grad_dists: FloatTensor of shape (P,)
//
// Returns:
//    grad_points: FloatTensor of shape (P, 3)
//    grad_segms: FloatTensor of shape (S, 2, 3)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_points,
    const torch::Tensor& grad_dists);
#endif

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_points,
    const torch::Tensor& grad_dists);

std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_points,
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    CHECK_CUDA(idx_points);
    CHECK_CUDA(grad_dists);
    return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
}
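
// Continuing the forward sketch above (hypothetical, for illustration): the
// backward entry point consumes the idxs computed by the forward pass.
//
//   auto grad_dists = torch::ones({100}); // upstream gradient, shape (P,)
//   auto grads = PointEdgeDistanceBackward(points, segms, idxs, grad_dists);
//   // std::get<0>(grads) is grad_points (P, 3);
//   // std::get<1>(grads) is grad_segms (S, 2, 3).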

// ****************************************************************************
// *                      EdgePointDistance                                   *
// ****************************************************************************

// Computes the squared euclidean distance of each edge segment to the closest
// point belonging to the corresponding example in the batch of size N.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    points_first_idx: LongTensor of shape (N,) indicating the first point
//        index for each example in the batch
//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
//        segment is spanned by (segms[s, 0], segms[s, 1])
//    segms_first_idx: LongTensor of shape (N,) indicating the first edge
//        index for each example in the batch
//    max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
//        the maximum number of edges in the batch and is used to set
//        the block dimensions in the CUDA implementation.
//
// Returns:
//    dists: FloatTensor of shape (S,), where dists[s] is the squared
//        euclidean distance of the s-th edge to the closest point in the
//        corresponding example in the batch.
//    idxs: LongTensor of shape (S,), where idxs[s] is the index of the closest
//        point in the example in the batch.
//        So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
//        d(u, v0, v1) is the distance of u from the segment spanned by (v0, v1)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_segms);
#endif

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_segms);

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
    const torch::Tensor& points,
    const torch::Tensor& points_first_idx,
    const torch::Tensor& segms,
    const torch::Tensor& segms_first_idx,
    const int64_t max_segms) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(points_first_idx);
    CHECK_CUDA(segms);
    CHECK_CUDA(segms_first_idx);
    return EdgePointDistanceForwardCuda(
        points, points_first_idx, segms, segms_first_idx, max_segms);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return EdgePointDistanceForwardCpu(
      points, points_first_idx, segms, segms_first_idx, max_segms);
}
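
// Conceptually this is the transpose of PointEdgeDistance above: it reduces
// over points for each edge rather than over edges for each point, so the
// outputs are indexed by segments, (S,), instead of by points, (P,).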

// Backward pass for EdgePointDistance.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3)
//    idx_segms: LongTensor of shape (S,) containing the indices
//        of the closest point in the example in the batch.
//        This is computed by the forward pass.
//    grad_dists: FloatTensor of shape (S,)
//
// Returns:
//    grad_points: FloatTensor of shape (P, 3)
//    grad_segms: FloatTensor of shape (S, 2, 3)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_segms,
    const torch::Tensor& grad_dists);
#endif

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_segms,
    const torch::Tensor& grad_dists);

std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& idx_segms,
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    CHECK_CUDA(idx_segms);
    CHECK_CUDA(grad_dists);
    return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
}

// ****************************************************************************
// *                          PointEdgeArrayDistance                          *
// ****************************************************************************

// Computes the squared euclidean distance of each p in points to each edge
// segment in segms.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
//        edge segment is spanned by (segms[s, 0], segms[s, 1])
//
// Returns:
//    dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
//        euclidean distance of points[p] to the segment spanned by
//        (segms[s, 0], segms[s, 1])
//
// For pointclouds and meshes of batch size N, this function requires N
// separate calls, one per example. The memory occupied is O(NPS), which can
// become quite large. For example, a medium sized batch with N = 32,
// P = 10000 and S = 5000 requires about 6 GB for the forward pass just to
// store dists (32 * 10000 * 5000 float32 values).

#ifdef WITH_CUDA
torch::Tensor PointEdgeArrayDistanceForwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms);
#endif

torch::Tensor PointEdgeArrayDistanceForwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms);

torch::Tensor PointEdgeArrayDistanceForward(
    const torch::Tensor& points,
    const torch::Tensor& segms) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    return PointEdgeArrayDistanceForwardCuda(points, segms);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeArrayDistanceForwardCpu(points, segms);
}
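
// A minimal sketch of how the pairwise output relates to the reduced
// PointEdgeDistance above (hypothetical tensors, single example):
//
//   auto points = torch::rand({100, 3});
//   auto segms = torch::rand({40, 2, 3});
//   auto pairwise = PointEdgeArrayDistanceForward(points, segms); // (100, 40)
//   auto closest = pairwise.min(/*dim=*/1);
//   // std::get<0>(closest) should reproduce the dists of
//   // PointEdgeDistanceForward for N = 1, and std::get<1>(closest) the idxs
//   // (up to tie-breaking on equal distances).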

// Backward pass for PointEdgeArrayDistance.
//
// Args:
//    points: FloatTensor of shape (P, 3)
//    segms: FloatTensor of shape (S, 2, 3)
//    grad_dists: FloatTensor of shape (P, S)
//
// Returns:
//    grad_points: FloatTensor of shape (P, 3)
//    grad_segms: FloatTensor of shape (S, 2, 3)
//

#ifdef WITH_CUDA

std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& grad_dists);
#endif

std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& grad_dists);

std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
    const torch::Tensor& points,
    const torch::Tensor& segms,
    const torch::Tensor& grad_dists) {
  if (points.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(points);
    CHECK_CUDA(segms);
    CHECK_CUDA(grad_dists);
    return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
}