mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Remove point mesh edge kernels
Summary: Removes the now-unnecessary kernels from point mesh edge file Migrates all point mesh functionality into one file. Reviewed By: gkioxari Differential Revision: D24550086 fbshipit-source-id: f924996cd38a7c2c1cf189d8a01611de4506cfa3
This commit is contained in:
parent
8dcfe30f66
commit
804235b05a
@ -15,8 +15,7 @@
|
||||
#include "interp_face_attrs/interp_face_attrs.h"
|
||||
#include "knn/knn.h"
|
||||
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
||||
#include "point_mesh/point_mesh_edge.h"
|
||||
#include "point_mesh/point_mesh_face.h"
|
||||
#include "point_mesh/point_mesh_cuda.h"
|
||||
#include "rasterize_meshes/rasterize_meshes.h"
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
#include "utils/warp_reduce.cuh"
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceDistance *
|
||||
// * Generic Forward/Backward Kernels *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void DistanceForwardKernel(
|
||||
@ -202,16 +202,6 @@ std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& tris,
|
||||
const at::Tensor& tris_first_idx,
|
||||
const int64_t max_points) {
|
||||
return DistanceForwardCuda(
|
||||
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
|
||||
}
|
||||
|
||||
__global__ void DistanceBackwardKernel(
|
||||
const float* __restrict__ objects, // (O * oD * 3)
|
||||
const size_t objects_size, // O
|
||||
@ -365,6 +355,20 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& tris,
|
||||
const at::Tensor& tris_first_idx,
|
||||
const int64_t max_points) {
|
||||
return DistanceForwardCuda(
|
||||
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& tris,
|
||||
@ -395,9 +399,54 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
|
||||
return DistanceBackwardCuda(tris, 3, points, 1, idx_tris, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& segms_first_idx,
|
||||
const int64_t max_points) {
|
||||
return DistanceForwardCuda(
|
||||
points, 1, points_first_idx, segms, 2, segms_first_idx, max_points);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& idx_points,
|
||||
const at::Tensor& grad_dists) {
|
||||
return DistanceBackwardCuda(points, 1, segms, 2, idx_points, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * EdgePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& segms_first_idx,
|
||||
const int64_t max_segms) {
|
||||
return DistanceForwardCuda(
|
||||
segms, 2, segms_first_idx, points, 1, points_first_idx, max_segms);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& idx_segms,
|
||||
const at::Tensor& grad_dists) {
|
||||
return DistanceBackwardCuda(segms, 2, points, 1, idx_segms, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceArrayDistance *
|
||||
// ****************************************************************************
|
||||
// TODO: Create wrapper function and merge kernel with other array kernel
|
||||
|
||||
__global__ void PointFaceArrayForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
@ -565,3 +614,164 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeArrayDistance *
|
||||
// ****************************************************************************
|
||||
// TODO: Create wrapper function and merge kernel with other array kernel
|
||||
|
||||
__global__ void PointEdgeArrayForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
float* __restrict__ dists, // (P, S)
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
float3 a = segms_f3[s * 2 + 0];
|
||||
float3 b = segms_f3[s * 2 + 1];
|
||||
|
||||
float3 point = points_f3[p];
|
||||
float dist = PointLine3DistanceForward(point, a, b);
|
||||
dists[p * S + s] = dist;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor PointEdgeArrayDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
|
||||
at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, segms_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
|
||||
at::Tensor dists = at::zeros({P, S}, points.options());
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return dists;
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
dists.data_ptr<float>(),
|
||||
P,
|
||||
S);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return dists;
|
||||
}
|
||||
|
||||
__global__ void PointEdgeArrayBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const float* __restrict__ grad_dists, // (P, S)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
const float3 a = segms_f3[s * 2 + 0];
|
||||
const float3 b = segms_f3[s * 2 + 1];
|
||||
|
||||
const float3 point = points_f3[p];
|
||||
const float grad_dist = grad_dists[p * S + s];
|
||||
const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_a = thrust::get<1>(grads);
|
||||
const float3 grad_b = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& grad_dists) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
|
||||
grad_dists_t{grad_dists, "grad_dists", 3};
|
||||
at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
|
||||
|
||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
||||
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
|
||||
|
||||
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
grad_dists.contiguous().data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
P,
|
||||
S);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
@ -241,6 +241,242 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
|
||||
return FacePointDistanceBackwardCpu(points, tris, idx_tris, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to the closest
|
||||
// mesh edge belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
|
||||
// segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
|
||||
// index for each example in the batch
|
||||
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
|
||||
// the maximum number of points in the batch and is used to set
|
||||
// the grid dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P,), where dists[p] is the squared euclidean
|
||||
// distance of points[p] to the closest edge in the same example in the
|
||||
// batch.
|
||||
// idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
|
||||
// edge in the batch.
|
||||
// So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
|
||||
// where d(u, v0, v1) is the distance of u from the segment spanned by
|
||||
// (v0, v1).
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(points_first_idx);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(segms_first_idx);
|
||||
return PointEdgeDistanceForwardCuda(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeDistanceForwardCpu(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||
}
|
||||
|
||||
// Backward pass for PointEdgeDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// idx_points: LongTensor of shape (P,) containing the indices
|
||||
// of the closest edge in the example in the batch.
|
||||
// This is computed by the forward pass.
|
||||
// grad_dists: FloatTensor of shape (P,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(idx_points);
|
||||
CHECK_CUDA(grad_dists);
|
||||
return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * EdgePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each edge segment to the closest
|
||||
// point belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
|
||||
// segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
|
||||
// index for each example in the batch
|
||||
// max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
|
||||
// the maximum number of edges in the batch and is used to set
|
||||
// the block dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (S,), where dists[s] is the squared
|
||||
// euclidean distance of s-th edge to the closest point in the
|
||||
// corresponding example in the batch.
|
||||
// idxs: LongTensor of shape (S,), where idxs[s] is the index of the closest
|
||||
// point in the example in the batch.
|
||||
// So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
|
||||
// d(u, v0, v1) is the distance of u from the segment spanned by (v0, v1)
|
||||
//
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(points_first_idx);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(segms_first_idx);
|
||||
return EdgePointDistanceForwardCuda(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return EdgePointDistanceForwardCpu(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||
}
|
||||
|
||||
// Backward pass for EdgePointDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// idx_segms: LongTensor of shape (S,) containing the indices
|
||||
// of the closest point in the example in the batch.
|
||||
// This is computed by the forward pass
|
||||
// grad_dists: FloatTensor of shape (S,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(idx_segms);
|
||||
CHECK_CUDA(grad_dists);
|
||||
return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceArrayDistance *
|
||||
// ****************************************************************************
|
||||
@ -328,3 +564,92 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
|
||||
}
|
||||
return PointFaceArrayDistanceBackwardCpu(points, tris, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to each edge
|
||||
// segment in segms.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
|
||||
// edge segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
|
||||
// euclidean distance of points[p] to the segment spanned by
|
||||
// (segms[s, 0], segms[s, 1])
|
||||
//
|
||||
// For pointcloud and meshes of batch size N, this function requires N
|
||||
// computations. The memory occupied is O(NPS) which can become quite large.
|
||||
// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
|
||||
// will require for the forward pass 5.8G of memory to store dists.
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor PointEdgeArrayDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms);
|
||||
#endif
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms);
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
return PointEdgeArrayDistanceForwardCuda(points, segms);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeArrayDistanceForwardCpu(points, segms);
|
||||
}
|
||||
|
||||
// Backward pass for PointEdgeArrayDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// grad_dists: FloatTensor of shape (P, S)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(grad_dists);
|
||||
return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
|
||||
}
|
@ -1,651 +0,0 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "utils/float_math.cuh"
|
||||
#include "utils/geometry_utils.cuh"
|
||||
#include "utils/warp_reduce.cuh"
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointEdgeForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ segms_first_idx, // (B,)
|
||||
float* __restrict__ dist_points, // (P,)
|
||||
int64_t* __restrict__ idx_points, // (P,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for segments in batch_idx
|
||||
const int64_t starts = segms_first_idx[batch_idx];
|
||||
const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x; // thread idx
|
||||
|
||||
// Each block will compute one element of the output idx_points[startp + i],
|
||||
// dist_points[startp + i]. Within the block we will use threads to compute
|
||||
// the distances between points[startp + i] and segms[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [starts, ends]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of points in batch_idx, then do nothing
|
||||
if (i < (endp - startp)) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + i];
|
||||
|
||||
// Compute the distances between points[startp + i] and segms[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [starts, ends].
|
||||
// Here each thread will reduce over (ends-starts) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (ends - starts); j += blockDim.x) {
|
||||
const float3 v0 = segms_f3[(starts + j) * 2 + 0];
|
||||
const float3 v1 = segms_f3[(starts + j) * 2 + 1];
|
||||
float dist = PointLine3DistanceForward(p_f3, v0, v1);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (starts + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_points[startp + i] = min_idxs[0];
|
||||
dist_points[startp + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& segms_first_idx,
|
||||
const int64_t max_points) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
points_first_idx_t{points_first_idx, "points_first_idx", 2},
|
||||
segms_t{segms, "segms", 3},
|
||||
segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
|
||||
at::CheckedFrom c = "PointEdgeDistanceForwardCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
TORCH_CHECK(segms_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
at::Tensor dists = at::zeros({P,}, points.options());
|
||||
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_points, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
points_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
segms_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
S);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
__global__ void PointEdgeBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ idx_points, // (P,)
|
||||
const float* __restrict__ grad_dists, // (P,)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t P) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t p = tid; p < P; p += stride) {
|
||||
const float3 p_f3 = points_f3[p];
|
||||
|
||||
const int64_t sidx = idx_points[p];
|
||||
const float3 v0 = segms_f3[sidx * 2 + 0];
|
||||
const float3 v1 = segms_f3[sidx * 2 + 1];
|
||||
|
||||
const float grad_dist = grad_dists[p];
|
||||
|
||||
const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_v0 = thrust::get<1>(grads);
|
||||
const float3 grad_v1 = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& idx_points,
|
||||
const at::Tensor& grad_dists) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
idx_points_t{idx_points, "idx_points", 2}, segms_t{segms, "segms", 3},
|
||||
grad_dists_t{grad_dists, "grad_dists", 4};
|
||||
at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, idx_points_t, segms_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
TORCH_CHECK(idx_points.size(0) == P);
|
||||
TORCH_CHECK(grad_dists.size(0) == P);
|
||||
|
||||
// clang-format off
|
||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
||||
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
|
||||
// clang-format on
|
||||
|
||||
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
idx_points.contiguous().data_ptr<int64_t>(),
|
||||
grad_dists.contiguous().data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
P);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * EdgePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void EdgePointForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ segms_first_idx, // (B,)
|
||||
float* __restrict__ dist_segms, // (S,)
|
||||
int64_t* __restrict__ idx_segms, // (S,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch_idx
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for segms in batch_idx
|
||||
const int64_t starts = segms_first_idx[batch_idx];
|
||||
const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x; // thread index
|
||||
|
||||
// Each block will compute one element of the output idx_segms[starts + i],
|
||||
// dist_segms[starts + i]. Within the block we will use threads to compute
|
||||
// the distances between segms[starts + i] and points[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [startp, endp]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of segms in batch_idx, then do nothing
|
||||
if (i < (ends - starts)) {
|
||||
const float3 v0 = segms_f3[(starts + i) * 2 + 0];
|
||||
const float3 v1 = segms_f3[(starts + i) * 2 + 1];
|
||||
|
||||
// Compute the distances between segms[starts + i] and points[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [startp, endp].
|
||||
// Here each thread will reduce over (endp-startp) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + j];
|
||||
|
||||
float dist = PointLine3DistanceForward(p_f3, v0, v1);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_segms[starts + i] = min_idxs[0];
|
||||
dist_segms[starts + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& segms_first_idx,
|
||||
const int64_t max_segms) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
points_first_idx_t{points_first_idx, "points_first_idx", 2},
|
||||
segms_t{segms, "segms", 3},
|
||||
segms_first_idx_t{segms_first_idx, "segms_first_idx", 4};
|
||||
at::CheckedFrom c = "EdgePointDistanceForwardCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {points_t, points_first_idx_t, segms_t, segms_first_idx_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
TORCH_CHECK(segms_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
at::Tensor dists = at::zeros({S,}, segms.options());
|
||||
at::Tensor idxs = at::zeros({S,}, segms_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_segms, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
points_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
segms_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
S);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
__global__ void EdgePointBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ idx_segms, // (S,)
|
||||
const float* __restrict__ grad_dists, // (S,)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t s = tid; s < S; s += stride) {
|
||||
const float3 v0 = segms_f3[s * 2 + 0];
|
||||
const float3 v1 = segms_f3[s * 2 + 1];
|
||||
|
||||
const int64_t pidx = idx_segms[s];
|
||||
|
||||
const float3 p_f3 = points_f3[pidx];
|
||||
|
||||
const float grad_dist = grad_dists[s];
|
||||
|
||||
const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_v0 = thrust::get<1>(grads);
|
||||
const float3 grad_v1 = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& idx_segms,
|
||||
const at::Tensor& grad_dists) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
idx_segms_t{idx_segms, "idx_segms", 2}, segms_t{segms, "segms", 3},
|
||||
grad_dists_t{grad_dists, "grad_dists", 4};
|
||||
at::CheckedFrom c = "PointEdgeDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, idx_segms_t, segms_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
TORCH_CHECK(idx_segms.size(0) == S);
|
||||
TORCH_CHECK(grad_dists.size(0) == S);
|
||||
|
||||
// clang-format off
|
||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
||||
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
|
||||
// clang-format on
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
idx_segms.contiguous().data_ptr<int64_t>(),
|
||||
grad_dists.contiguous().data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
S);
|
||||
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointEdgeArrayForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
float* __restrict__ dists, // (P, S)
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
float3 a = segms_f3[s * 2 + 0];
|
||||
float3 b = segms_f3[s * 2 + 1];
|
||||
|
||||
float3 point = points_f3[p];
|
||||
float dist = PointLine3DistanceForward(point, a, b);
|
||||
dists[p * S + s] = dist;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor PointEdgeArrayDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2};
|
||||
at::CheckedFrom c = "PointEdgeArrayDistanceForwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, segms_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
|
||||
at::Tensor dists = at::zeros({P, S}, points.options());
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return dists;
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
dists.data_ptr<float>(),
|
||||
P,
|
||||
S);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return dists;
|
||||
}
|
||||
|
||||
__global__ void PointEdgeArrayBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const float* __restrict__ grad_dists, // (P, S)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
const float3 a = segms_f3[s * 2 + 0];
|
||||
const float3 b = segms_f3[s * 2 + 1];
|
||||
|
||||
const float3 point = points_f3[p];
|
||||
const float grad_dist = grad_dists[p * S + s];
|
||||
const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_a = thrust::get<1>(grads);
|
||||
const float3 grad_b = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& segms,
|
||||
const at::Tensor& grad_dists) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1}, segms_t{segms, "segms", 2},
|
||||
grad_dists_t{grad_dists, "grad_dists", 3};
|
||||
at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
|
||||
|
||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
||||
at::Tensor grad_segms = at::zeros({S, 2, 3}, segms.options());
|
||||
|
||||
if (grad_points.numel() == 0 || grad_segms.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
segms.contiguous().data_ptr<float>(),
|
||||
grad_dists.contiguous().data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
P,
|
||||
S);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
@ -1,332 +0,0 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to the closest
|
||||
// mesh edge belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
|
||||
// segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
|
||||
// index for each example in the batch
|
||||
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
|
||||
// the maximum number of points in the batch and is used to set
|
||||
// the grid dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P,), where dists[p] is the squared euclidean
|
||||
// distance of points[p] to the closest edge in the same example in the
|
||||
// batch.
|
||||
// idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
|
||||
// edge in the batch.
|
||||
// So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
|
||||
// where d(u, v0, v1) is the distance of u from the segment spanned by
|
||||
// (v0, v1).
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(points_first_idx);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(segms_first_idx);
|
||||
return PointEdgeDistanceForwardCuda(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeDistanceForwardCpu(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||
}
|
||||
|
||||
// Backward pass for PointEdgeDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// idx_points: LongTensor of shape (P,) containing the indices
|
||||
// of the closest edge in the example in the batch.
|
||||
// This is computed by the forward pass.
|
||||
// grad_dists: FloatTensor of shape (P,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(idx_points);
|
||||
CHECK_CUDA(grad_dists);
|
||||
return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * EdgePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each edge segment to the closest
|
||||
// point belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
|
||||
// segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
|
||||
// index for each example in the batch
|
||||
// max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
|
||||
// the maximum number of edges in the batch and is used to set
|
||||
// the block dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (S,), where dists[s] is the squared
|
||||
// euclidean distance of s-th edge to the closest point in the
|
||||
// corresponding example in the batch.
|
||||
// idxs: LongTensor of shape (S,), where idxs[s] is the index of the closest
|
||||
// point in the example in the batch.
|
||||
// So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
|
||||
// d(u, v0, v1) is the distance of u from the segment spanned by (v0, v1)
|
||||
//
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(points_first_idx);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(segms_first_idx);
|
||||
return EdgePointDistanceForwardCuda(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return EdgePointDistanceForwardCpu(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||
}
|
||||
|
||||
// Backward pass for EdgePointDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// idx_segms: LongTensor of shape (S,) containing the indices
|
||||
// of the closest point in the example in the batch.
|
||||
// This is computed by the forward pass
|
||||
// grad_dists: FloatTensor of shape (S,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(idx_segms);
|
||||
CHECK_CUDA(grad_dists);
|
||||
return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to each edge
|
||||
// segment in segms.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
|
||||
// edge segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
|
||||
// euclidean distance of points[p] to the segment spanned by
|
||||
// (segms[s, 0], segms[s, 1])
|
||||
//
|
||||
// For pointcloud and meshes of batch size N, this function requires N
|
||||
// computations. The memory occupied is O(NPS) which can become quite large.
|
||||
// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
|
||||
// will require for the forward pass 5.8G of memory to store dists.
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor PointEdgeArrayDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms);
|
||||
#endif
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms);
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
return PointEdgeArrayDistanceForwardCuda(points, segms);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeArrayDistanceForwardCpu(points, segms);
|
||||
}
|
||||
|
||||
// Backward pass for PointEdgeArrayDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// grad_dists: FloatTensor of shape (P, S)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(segms);
|
||||
CHECK_CUDA(grad_dists);
|
||||
return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user