mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-07 04:36:00 +08:00
point mesh distances
Summary: Implementation of point to mesh distances. The current diff contains two types: (a) Point to Edge (b) Point to Face ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- POINT_MESH_EDGE_4_100_300_5000_cuda:0 2745 3138 183 POINT_MESH_EDGE_4_100_300_10000_cuda:0 4408 4499 114 POINT_MESH_EDGE_4_100_3000_5000_cuda:0 4978 5070 101 POINT_MESH_EDGE_4_100_3000_10000_cuda:0 9076 9187 56 POINT_MESH_EDGE_4_1000_300_5000_cuda:0 1411 1487 355 POINT_MESH_EDGE_4_1000_300_10000_cuda:0 4829 5030 104 POINT_MESH_EDGE_4_1000_3000_5000_cuda:0 7539 7620 67 POINT_MESH_EDGE_4_1000_3000_10000_cuda:0 12088 12272 42 POINT_MESH_EDGE_8_100_300_5000_cuda:0 3106 3222 161 POINT_MESH_EDGE_8_100_300_10000_cuda:0 8561 8648 59 POINT_MESH_EDGE_8_100_3000_5000_cuda:0 6932 7021 73 POINT_MESH_EDGE_8_100_3000_10000_cuda:0 24032 24176 21 POINT_MESH_EDGE_8_1000_300_5000_cuda:0 5272 5399 95 POINT_MESH_EDGE_8_1000_300_10000_cuda:0 11348 11430 45 POINT_MESH_EDGE_8_1000_3000_5000_cuda:0 17478 17683 29 POINT_MESH_EDGE_8_1000_3000_10000_cuda:0 25961 26236 20 POINT_MESH_EDGE_16_100_300_5000_cuda:0 8244 8323 61 POINT_MESH_EDGE_16_100_300_10000_cuda:0 18018 18071 28 POINT_MESH_EDGE_16_100_3000_5000_cuda:0 19428 19544 26 POINT_MESH_EDGE_16_100_3000_10000_cuda:0 44967 45135 12 POINT_MESH_EDGE_16_1000_300_5000_cuda:0 7825 7937 64 POINT_MESH_EDGE_16_1000_300_10000_cuda:0 18504 18571 28 POINT_MESH_EDGE_16_1000_3000_5000_cuda:0 65805 66132 8 POINT_MESH_EDGE_16_1000_3000_10000_cuda:0 90885 91089 6 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- POINT_MESH_FACE_4_100_300_5000_cuda:0 1561 1685 321 POINT_MESH_FACE_4_100_300_10000_cuda:0 2818 2954 178 POINT_MESH_FACE_4_100_3000_5000_cuda:0 15893 16018 32 POINT_MESH_FACE_4_100_3000_10000_cuda:0 16350 16439 31 POINT_MESH_FACE_4_1000_300_5000_cuda:0 3179 3278 158 POINT_MESH_FACE_4_1000_300_10000_cuda:0 2353 2436 213 POINT_MESH_FACE_4_1000_3000_5000_cuda:0 16262 16336 31 POINT_MESH_FACE_4_1000_3000_10000_cuda:0 9334 9448 54 POINT_MESH_FACE_8_100_300_5000_cuda:0 4377 4493 115 POINT_MESH_FACE_8_100_300_10000_cuda:0 9728 9822 52 POINT_MESH_FACE_8_100_3000_5000_cuda:0 26428 26544 19 POINT_MESH_FACE_8_100_3000_10000_cuda:0 42238 43031 12 POINT_MESH_FACE_8_1000_300_5000_cuda:0 3891 3982 129 POINT_MESH_FACE_8_1000_300_10000_cuda:0 5363 5429 94 POINT_MESH_FACE_8_1000_3000_5000_cuda:0 20998 21084 24 POINT_MESH_FACE_8_1000_3000_10000_cuda:0 39711 39897 13 POINT_MESH_FACE_16_100_300_5000_cuda:0 5955 6001 84 POINT_MESH_FACE_16_100_300_10000_cuda:0 12082 12144 42 POINT_MESH_FACE_16_100_3000_5000_cuda:0 44996 45176 12 POINT_MESH_FACE_16_100_3000_10000_cuda:0 73042 73197 7 POINT_MESH_FACE_16_1000_300_5000_cuda:0 8292 8374 61 POINT_MESH_FACE_16_1000_300_10000_cuda:0 19442 19506 26 POINT_MESH_FACE_16_1000_3000_5000_cuda:0 36059 36194 14 POINT_MESH_FACE_16_1000_3000_10000_cuda:0 64644 64822 8 -------------------------------------------------------------------------------- ``` Reviewed By: jcjohnson Differential Revision: D20590462 fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
This commit is contained in:
committed by
Facebook GitHub Bot
parent
474c8b456a
commit
487d4d6607
548
pytorch3d/csrc/point_mesh/point_mesh_edge.cu
Normal file
548
pytorch3d/csrc/point_mesh/point_mesh_edge.cu
Normal file
@@ -0,0 +1,548 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "utils/float_math.cuh"
|
||||
#include "utils/geometry_utils.cuh"
|
||||
#include "utils/warp_reduce.cuh"
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointEdgeForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ segms_first_idx, // (B,)
|
||||
float* __restrict__ dist_points, // (P,)
|
||||
int64_t* __restrict__ idx_points, // (P,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for segments in batch_idx
|
||||
const int64_t starts = segms_first_idx[batch_idx];
|
||||
const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x; // thread idx
|
||||
|
||||
// Each block will compute one element of the output idx_points[startp + i],
|
||||
// dist_points[startp + i]. Within the block we will use threads to compute
|
||||
// the distances between points[startp + i] and segms[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [starts, ends]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of points in batch_idx, then do nothing
|
||||
if (i < (endp - startp)) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + i];
|
||||
|
||||
// Compute the distances between points[startp + i] and segms[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [starts, ends].
|
||||
// Here each thread will reduce over (ends-starts) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (ends - starts); j += blockDim.x) {
|
||||
const float3 v0 = segms_f3[(starts + j) * 2 + 0];
|
||||
const float3 v1 = segms_f3[(starts + j) * 2 + 1];
|
||||
float dist = PointLine3DistanceForward(p_f3, v0, v1);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (starts + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_points[startp + i] = min_idxs[0];
|
||||
dist_points[startp + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
AT_ASSERTM(segms_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor dists = torch::zeros({P,}, points.options());
|
||||
torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_points, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
PointEdgeForwardKernel<<<blocks, threads, shared_size>>>(
|
||||
points.data_ptr<float>(),
|
||||
points_first_idx.data_ptr<int64_t>(),
|
||||
segms.data_ptr<float>(),
|
||||
segms_first_idx.data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
S);
|
||||
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
__global__ void PointEdgeBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ idx_points, // (P,)
|
||||
const float* __restrict__ grad_dists, // (P,)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t P) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t p = tid; p < P; p += stride) {
|
||||
const float3 p_f3 = points_f3[p];
|
||||
|
||||
const int64_t sidx = idx_points[p];
|
||||
const float3 v0 = segms_f3[sidx * 2 + 0];
|
||||
const float3 v1 = segms_f3[sidx * 2 + 1];
|
||||
|
||||
const float grad_dist = grad_dists[p];
|
||||
|
||||
const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_v0 = thrust::get<1>(grads);
|
||||
const float3 grad_v1 = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_segms + sidx * 2 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
AT_ASSERTM(idx_points.size(0) == P);
|
||||
AT_ASSERTM(grad_dists.size(0) == P);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
|
||||
// clang-format on
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
PointEdgeBackwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
segms.data_ptr<float>(),
|
||||
idx_points.data_ptr<int64_t>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
P);
|
||||
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * EdgePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void EdgePointForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ segms_first_idx, // (B,)
|
||||
float* __restrict__ dist_segms, // (S,)
|
||||
int64_t* __restrict__ idx_segms, // (S,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch_idx
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for segms in batch_idx
|
||||
const int64_t starts = segms_first_idx[batch_idx];
|
||||
const int64_t ends = batch_idx + 1 < B ? segms_first_idx[batch_idx + 1] : S;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x; // thread index
|
||||
|
||||
// Each block will compute one element of the output idx_segms[starts + i],
|
||||
// dist_segms[starts + i]. Within the block we will use threads to compute
|
||||
// the distances between segms[starts + i] and points[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [startp, endp]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of segms in batch_idx, then do nothing
|
||||
if (i < (ends - starts)) {
|
||||
const float3 v0 = segms_f3[(starts + i) * 2 + 0];
|
||||
const float3 v1 = segms_f3[(starts + i) * 2 + 1];
|
||||
|
||||
// Compute the distances between segms[starts + i] and points[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [startp, endp].
|
||||
// Here each thread will reduce over (endp-startp) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + j];
|
||||
|
||||
float dist = PointLine3DistanceForward(p_f3, v0, v1);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_segms[starts + i] = min_idxs[0];
|
||||
dist_segms[starts + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
AT_ASSERTM(segms_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor dists = torch::zeros({S,}, segms.options());
|
||||
torch::Tensor idxs = torch::zeros({S,}, segms_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_segms, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
EdgePointForwardKernel<<<blocks, threads, shared_size>>>(
|
||||
points.data_ptr<float>(),
|
||||
points_first_idx.data_ptr<int64_t>(),
|
||||
segms.data_ptr<float>(),
|
||||
segms_first_idx.data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
S);
|
||||
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
__global__ void EdgePointBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const int64_t* __restrict__ idx_segms, // (S,)
|
||||
const float* __restrict__ grad_dists, // (S,)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t s = tid; s < S; s += stride) {
|
||||
const float3 v0 = segms_f3[s * 2 + 0];
|
||||
const float3 v1 = segms_f3[s * 2 + 1];
|
||||
|
||||
const int64_t pidx = idx_segms[s];
|
||||
|
||||
const float3 p_f3 = points_f3[pidx];
|
||||
|
||||
const float grad_dist = grad_dists[s];
|
||||
|
||||
const auto grads = PointLine3DistanceBackward(p_f3, v0, v1, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_v0 = thrust::get<1>(grads);
|
||||
const float3 grad_v1 = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
AT_ASSERTM(idx_segms.size(0) == S);
|
||||
AT_ASSERTM(grad_dists.size(0) == S);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
|
||||
// clang-format on
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
EdgePointBackwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
segms.data_ptr<float>(),
|
||||
idx_segms.data_ptr<int64_t>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
S);
|
||||
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointEdgeArrayForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
float* __restrict__ dists, // (P, S)
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
float3 a = segms_f3[s * 2 + 0];
|
||||
float3 b = segms_f3[s * 2 + 1];
|
||||
|
||||
float3 point = points_f3[p];
|
||||
float dist = PointLine3DistanceForward(point, a, b);
|
||||
dists[p * S + s] = dist;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
|
||||
torch::Tensor dists = torch::zeros({P, S}, points.options());
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointEdgeArrayForwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
segms.data_ptr<float>(),
|
||||
dists.data_ptr<float>(),
|
||||
P,
|
||||
S);
|
||||
|
||||
return dists;
|
||||
}
|
||||
|
||||
__global__ void PointEdgeArrayBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ segms, // (S, 2, 3)
|
||||
const float* __restrict__ grad_dists, // (P, S)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_segms, // (S, 2, 3)
|
||||
const size_t P,
|
||||
const size_t S) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
const float3 a = segms_f3[s * 2 + 0];
|
||||
const float3 b = segms_f3[s * 2 + 1];
|
||||
|
||||
const float3 point = points_f3[p];
|
||||
const float grad_dist = grad_dists[p * S + s];
|
||||
const auto grads = PointLine3DistanceBackward(point, a, b, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_a = thrust::get<1>(grads);
|
||||
const float3 grad_b = thrust::get<2>(grads);
|
||||
|
||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 0, grad_a.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 1, grad_a.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 0 * 3 + 2, grad_a.z);
|
||||
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 0, grad_b.x);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 1, grad_b.y);
|
||||
atomicAdd(grad_segms + s * 2 * 3 + 1 * 3 + 2, grad_b.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t S = segms.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(segms.size(1) == 2) && (segms.size(2) == 3),
|
||||
"segms must be of shape Sx2x3");
|
||||
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == S));
|
||||
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
torch::Tensor grad_segms = torch::zeros({S, 2, 3}, segms.options());
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointEdgeArrayBackwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
segms.data_ptr<float>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_segms.data_ptr<float>(),
|
||||
P,
|
||||
S);
|
||||
|
||||
return std::make_tuple(grad_points, grad_segms);
|
||||
}
|
||||
274
pytorch3d/csrc/point_mesh/point_mesh_edge.h
Normal file
274
pytorch3d/csrc/point_mesh/point_mesh_edge.h
Normal file
@@ -0,0 +1,274 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to the closest
|
||||
// mesh edge belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
|
||||
// segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
|
||||
// index for each example in the batch
|
||||
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
|
||||
// the maximum number of points in the batch and is used to set
|
||||
// the grid dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P,), where dists[p] is the squared euclidean
|
||||
// distance of points[p] to the closest edge in the same example in the
|
||||
// batch.
|
||||
// idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
|
||||
// edge in the batch.
|
||||
// So, dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1]),
|
||||
// where d(u, v0, v1) is the distance of u from the segment spanned by
|
||||
// (v0, v1).
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_points) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointEdgeDistanceForwardCuda(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// Backward pass for PointEdgeDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// idx_points: LongTensor of shape (P,) containing the indices
|
||||
// of the closest edge in the example in the batch.
|
||||
// This is computed by the forward pass.
|
||||
// grad_dists: FloatTensor of shape (P,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * EdgePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each edge segment to the closest
|
||||
// point belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th edge
|
||||
// segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
// segms_first_idx: LongTensor of shape (N,) indicating the first edge
|
||||
// index for each example in the batch
|
||||
// max_segms: Scalar equal to max(S_i) for i in [0, N - 1] containing
|
||||
// the maximum number of edges in the batch and is used to set
|
||||
// the block dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (S,), where dists[s] is the squared
|
||||
// euclidean distance of s-th edge to the closest point in the
|
||||
// corresponding example in the batch.
|
||||
// idxs: LongTensor of shape (S,), where idxs[s] is the index of the closest
|
||||
// point in the example in the batch.
|
||||
// So, dists[s] = d(points[idxs[s]], segms[s, 0], segms[s, 1]), where
|
||||
// d(u, v0, v1) is the distance of u from the segment spanned by (v0, v1)
|
||||
//
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& segms_first_idx,
|
||||
const int64_t max_segms) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return EdgePointDistanceForwardCuda(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// Backward pass for EdgePointDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// idx_segms: LongTensor of shape (S,) containing the indices
|
||||
// of the closest point in the example in the batch.
|
||||
// This is computed by the forward pass
|
||||
// grad_dists: FloatTensor of shape (S,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& idx_segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointEdgeArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to each edge
|
||||
// segment in segms.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3) of edge segments. The s-th
|
||||
// edge segment is spanned by (segms[s, 0], segms[s, 1])
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P, S), where dists[p, s] is the squared
|
||||
// euclidean distance of points[p] to the segment spanned by
|
||||
// (segms[s, 0], segms[s, 1])
|
||||
//
|
||||
// For pointcloud and meshes of batch size N, this function requires N
|
||||
// computations. The memory occupied is O(NPS) which can become quite large.
|
||||
// For example, a medium sized batch with N = 32 with P = 10000 and S = 5000
|
||||
// will require for the forward pass 5.8G of memory to store dists.
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms);
|
||||
#endif
|
||||
|
||||
torch::Tensor PointEdgeArrayDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointEdgeArrayDistanceForwardCuda(points, segms);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// Backward pass for PointEdgeArrayDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// segms: FloatTensor of shape (S, 2, 3)
|
||||
// grad_dists: FloatTensor of shape (P, S)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_segms: FloatTensor of shape (S, 2, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& segms,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
574
pytorch3d/csrc/point_mesh/point_mesh_face.cu
Normal file
574
pytorch3d/csrc/point_mesh/point_mesh_face.cu
Normal file
@@ -0,0 +1,574 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "utils/float_math.cuh"
|
||||
#include "utils/geometry_utils.cuh"
|
||||
#include "utils/warp_reduce.cuh"
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointFaceForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const int64_t* __restrict__ tris_first_idx, // (B,)
|
||||
float* __restrict__ dist_points, // (P,)
|
||||
int64_t* __restrict__ idx_points, // (P,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t T) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch_idx
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for faces in batch_idx
|
||||
const int64_t startt = tris_first_idx[batch_idx];
|
||||
const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x; // thread index
|
||||
|
||||
// Each block will compute one element of the output idx_points[startp + i],
|
||||
// dist_points[startp + i]. Within the block we will use threads to compute
|
||||
// the distances between points[startp + i] and tris[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [startt, endt]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of points in batch_idx, then do nothing
|
||||
if (i < (endp - startp)) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + i];
|
||||
|
||||
// Compute the distances between points[startp + i] and tris[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [startt, endt].
|
||||
// Here each thread will reduce over (endt-startt) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (endt - startt); j += blockDim.x) {
|
||||
const float3 v0 = tris_f3[(startt + j) * 3 + 0];
|
||||
const float3 v1 = tris_f3[(startt + j) * 3 + 1];
|
||||
const float3 v2 = tris_f3[(startt + j) * 3 + 2];
|
||||
float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (startt + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_points[startp + i] = min_idxs[0];
|
||||
dist_points[startp + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& tris_first_idx,
|
||||
const int64_t max_points) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
AT_ASSERTM(tris_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor dists = torch::zeros({P,}, points.options());
|
||||
torch::Tensor idxs = torch::zeros({P,}, points_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_points, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
PointFaceForwardKernel<<<blocks, threads, shared_size>>>(
|
||||
points.data_ptr<float>(),
|
||||
points_first_idx.data_ptr<int64_t>(),
|
||||
tris.data_ptr<float>(),
|
||||
tris_first_idx.data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
T);
|
||||
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
__global__ void PointFaceBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const int64_t* __restrict__ idx_points, // (P,)
|
||||
const float* __restrict__ grad_dists, // (P,)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_tris, // (T, 3, 3)
|
||||
const size_t P) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* tris_f3 = (float3*)tris;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t p = tid; p < P; p += stride) {
|
||||
const float3 p_f3 = points_f3[p];
|
||||
|
||||
const int64_t tidx = idx_points[p];
|
||||
const float3 v0 = tris_f3[tidx * 3 + 0];
|
||||
const float3 v1 = tris_f3[tidx * 3 + 1];
|
||||
const float3 v2 = tris_f3[tidx * 3 + 2];
|
||||
|
||||
const float grad_dist = grad_dists[p];
|
||||
|
||||
const auto grads =
|
||||
PointTriangle3DistanceBackward(p_f3, v0, v1, v2, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_v0 = thrust::get<1>(grads);
|
||||
const float3 grad_v1 = thrust::get<2>(grads);
|
||||
const float3 grad_v2 = thrust::get<3>(grads);
|
||||
|
||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 0, grad_v2.x);
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 1, grad_v2.y);
|
||||
atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 2, grad_v2.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
AT_ASSERTM(idx_points.size(0) == P);
|
||||
AT_ASSERTM(grad_dists.size(0) == P);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
|
||||
// clang-format on
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
PointFaceBackwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
idx_points.data_ptr<int64_t>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_tris.data_ptr<float>(),
|
||||
P);
|
||||
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * FacePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void FacePointForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const int64_t* __restrict__ tris_first_idx, // (B,)
|
||||
float* __restrict__ dist_tris, // (T,)
|
||||
int64_t* __restrict__ idx_tris, // (T,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t T) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch_idx
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for tris in batch_idx
|
||||
const int64_t startt = tris_first_idx[batch_idx];
|
||||
const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x;
|
||||
|
||||
// Each block will compute one element of the output idx_tris[startt + i],
|
||||
// dist_tris[startt + i]. Within the block we will use threads to compute
|
||||
// the distances between tris[startt + i] and points[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [startp, endp]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of tris in batch_idx, then do nothing
|
||||
if (i < (endt - startt)) {
|
||||
const float3 v0 = tris_f3[(startt + i) * 3 + 0];
|
||||
const float3 v1 = tris_f3[(startt + i) * 3 + 1];
|
||||
const float3 v2 = tris_f3[(startt + i) * 3 + 2];
|
||||
|
||||
// Compute the distances between tris[startt + i] and points[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [startp, endp].
|
||||
// Here each thread will reduce over (endp-startp) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + j];
|
||||
|
||||
float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_tris[startt + i] = min_idxs[0];
|
||||
dist_tris[startt + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& tris_first_idx,
|
||||
const int64_t max_tris) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
AT_ASSERTM(tris_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor dists = torch::zeros({T,}, tris.options());
|
||||
torch::Tensor idxs = torch::zeros({T,}, tris_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_tris, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
FacePointForwardKernel<<<blocks, threads, shared_size>>>(
|
||||
points.data_ptr<float>(),
|
||||
points_first_idx.data_ptr<int64_t>(),
|
||||
tris.data_ptr<float>(),
|
||||
tris_first_idx.data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
T);
|
||||
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
__global__ void FacePointBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const int64_t* __restrict__ idx_tris, // (T,)
|
||||
const float* __restrict__ grad_dists, // (T,)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_tris, // (T, 3, 3)
|
||||
const size_t T) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* tris_f3 = (float3*)tris;
|
||||
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
for (size_t t = tid; t < T; t += stride) {
|
||||
const float3 v0 = tris_f3[t * 3 + 0];
|
||||
const float3 v1 = tris_f3[t * 3 + 1];
|
||||
const float3 v2 = tris_f3[t * 3 + 2];
|
||||
|
||||
const int64_t pidx = idx_tris[t];
|
||||
|
||||
const float3 p_f3 = points_f3[pidx];
|
||||
|
||||
const float grad_dist = grad_dists[t];
|
||||
|
||||
const auto grads =
|
||||
PointTriangle3DistanceBackward(p_f3, v0, v1, v2, grad_dist);
|
||||
const float3 grad_point = thrust::get<0>(grads);
|
||||
const float3 grad_v0 = thrust::get<1>(grads);
|
||||
const float3 grad_v1 = thrust::get<2>(grads);
|
||||
const float3 grad_v2 = thrust::get<3>(grads);
|
||||
|
||||
atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
|
||||
atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
|
||||
atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 0, grad_v2.x);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 1, grad_v2.y);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 2, grad_v2.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& idx_tris,
|
||||
const torch::Tensor& grad_dists) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
AT_ASSERTM(idx_tris.size(0) == T);
|
||||
AT_ASSERTM(grad_dists.size(0) == T);
|
||||
|
||||
// clang-format off
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
|
||||
// clang-format on
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
|
||||
FacePointBackwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
idx_tris.data_ptr<int64_t>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_tris.data_ptr<float>(),
|
||||
T);
|
||||
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointFaceArrayForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
float* __restrict__ dists, // (P, T)
|
||||
const size_t P,
|
||||
const size_t T) {
|
||||
const float3* points_f3 = (float3*)points;
|
||||
const float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
||||
const int t = t_i / P; // segment index.
|
||||
const int p = t_i % P; // point index
|
||||
const float3 v0 = tris_f3[t * 3 + 0];
|
||||
const float3 v1 = tris_f3[t * 3 + 1];
|
||||
const float3 v2 = tris_f3[t * 3 + 2];
|
||||
|
||||
const float3 point = points_f3[p];
|
||||
float dist = PointTriangle3DistanceForward(point, v0, v1, v2);
|
||||
dists[p * T + t] = dist;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor PointFaceArrayDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
|
||||
torch::Tensor dists = torch::zeros({P, T}, points.options());
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointFaceArrayForwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
dists.data_ptr<float>(),
|
||||
P,
|
||||
T);
|
||||
|
||||
return dists;
|
||||
}
|
||||
|
||||
__global__ void PointFaceArrayBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const float* __restrict__ grad_dists, // (P, T)
|
||||
float* __restrict__ grad_points, // (P, 3)
|
||||
float* __restrict__ grad_tris, // (T, 3, 3)
|
||||
const size_t P,
|
||||
const size_t T) {
|
||||
const float3* points_f3 = (float3*)points;
|
||||
const float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
||||
const int t = t_i / P; // triangle index.
|
||||
const int p = t_i % P; // point index
|
||||
const float3 v0 = tris_f3[t * 3 + 0];
|
||||
const float3 v1 = tris_f3[t * 3 + 1];
|
||||
const float3 v2 = tris_f3[t * 3 + 2];
|
||||
|
||||
const float3 point = points_f3[p];
|
||||
|
||||
const float grad_dist = grad_dists[p * T + t];
|
||||
const auto grad =
|
||||
PointTriangle3DistanceBackward(point, v0, v1, v2, grad_dist);
|
||||
|
||||
const float3 grad_point = thrust::get<0>(grad);
|
||||
const float3 grad_v0 = thrust::get<1>(grad);
|
||||
const float3 grad_v1 = thrust::get<2>(grad);
|
||||
const float3 grad_v2 = thrust::get<3>(grad);
|
||||
|
||||
atomicAdd(grad_points + 3 * p + 0, grad_point.x);
|
||||
atomicAdd(grad_points + 3 * p + 1, grad_point.y);
|
||||
atomicAdd(grad_points + 3 * p + 2, grad_point.z);
|
||||
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 0, grad_v0.x);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 1, grad_v0.y);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 2, grad_v0.z);
|
||||
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 0, grad_v1.x);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 1, grad_v1.y);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 2, grad_v1.z);
|
||||
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 0, grad_v2.x);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 1, grad_v2.y);
|
||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 2, grad_v2.z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& grad_dists) {
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
|
||||
AT_ASSERTM(points.size(1) == 3, "points must be of shape Px3");
|
||||
AT_ASSERTM(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
AT_ASSERTM((grad_dists.size(0) == P) && (grad_dists.size(1) == T));
|
||||
|
||||
torch::Tensor grad_points = torch::zeros({P, 3}, points.options());
|
||||
torch::Tensor grad_tris = torch::zeros({T, 3, 3}, tris.options());
|
||||
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 64;
|
||||
|
||||
PointFaceArrayBackwardKernel<<<blocks, threads>>>(
|
||||
points.data_ptr<float>(),
|
||||
tris.data_ptr<float>(),
|
||||
grad_dists.data_ptr<float>(),
|
||||
grad_points.data_ptr<float>(),
|
||||
grad_tris.data_ptr<float>(),
|
||||
P,
|
||||
T);
|
||||
|
||||
return std::make_tuple(grad_points, grad_tris);
|
||||
}
|
||||
276
pytorch3d/csrc/point_mesh/point_mesh_face.h
Normal file
276
pytorch3d/csrc/point_mesh/point_mesh_face.h
Normal file
@@ -0,0 +1,276 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to it closest
|
||||
// triangular face belonging to the corresponding mesh example in the batch of
|
||||
// size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
|
||||
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
|
||||
// tris_first_idx: LongTensor of shape (N,) indicating the first face
|
||||
// index for each example in the batch
|
||||
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
|
||||
// the maximum number of points in the batch and is used to set
|
||||
// the block dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P,), where dists[p] is the minimum
|
||||
// squared euclidean distance of points[p] to the faces in the same
|
||||
// example in the batch.
|
||||
// idxs: LongTensor of shape (P,), where idxs[p] is the index of the closest
|
||||
// face in the batch.
|
||||
// So, dists[p] = d(points[p], tris[idxs[p], 0], tris[idxs[p], 1],
|
||||
// tris[idxs[p], 2]) where d(u, v0, v1, v2) is the distance of u from the
|
||||
// face spanned by (v0, v1, v2)
|
||||
//
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& tris_first_idx,
|
||||
const int64_t max_points);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& tris_first_idx,
|
||||
const int64_t max_points) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointFaceDistanceForwardCuda(
|
||||
points, points_first_idx, tris, tris_first_idx, max_points);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// Backward pass for PointFaceDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// tris: FloatTensor of shape (T, 3, 3)
|
||||
// idx_points: LongTensor of shape (P,) containing the indices
|
||||
// of the closest face in the example in the batch.
|
||||
// This is computed by the forward pass
|
||||
// grad_dists: FloatTensor of shape (P,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_tris: FloatTensor of shape (T, 3, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& idx_points,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * FacePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each triangular face to its
|
||||
// closest point belonging to the corresponding example in the batch of size N.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// points_first_idx: LongTensor of shape (N,) indicating the first point
|
||||
// index for each example in the batch
|
||||
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
|
||||
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
|
||||
// tris_first_idx: LongTensor of shape (N,) indicating the first face
|
||||
// index for each example in the batch
|
||||
// max_tris: Scalar equal to max(T_i) for i in [0, N - 1] containing
|
||||
// the maximum number of faces in the batch and is used to set
|
||||
// the block dimensions in the CUDA implementation.
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (T,), where dists[t] is the minimum squared
|
||||
// euclidean distance of t-th triangular face from the closest point in
|
||||
// the batch.
|
||||
// idxs: LongTensor of shape (T,), where idxs[t] is the index of the closest
|
||||
// point in the batch.
|
||||
// So, dists[t] = d(points[idxs[t]], tris[t, 0], tris[t, 1], tris[t, 2])
|
||||
// where d(u, v0, v1, v2) is the distance of u from the triangular face
|
||||
// spanned by (v0, v1, v2)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& tris_first_idx,
|
||||
const int64_t max_tros);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& points_first_idx,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& tris_first_idx,
|
||||
const int64_t max_tris) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return FacePointDistanceForwardCuda(
|
||||
points, points_first_idx, tris, tris_first_idx, max_tris);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// Backward pass for FacePointDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// tris: FloatTensor of shape (T, 3, 3)
|
||||
// idx_tris: LongTensor of shape (T,) containing the indices
|
||||
// of the closest point in the example in the batch.
|
||||
// This is computed by the forward pass
|
||||
// grad_dists: FloatTensor of shape (T,)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_tris: FloatTensor of shape (T, 3, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& idx_tris,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& idx_tris,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// * PointFaceArrayDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
// Computes the squared euclidean distance of each p in points to each
|
||||
// triangular face spanned by (v0, v1, v2) in tris.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
|
||||
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
|
||||
//
|
||||
// Returns:
|
||||
// dists: FloatTensor of shape (P, T), where dists[p, t] is the squared
|
||||
// euclidean distance of points[p] to the face spanned by (v0, v1, v2)
|
||||
// where v0 = tris[t, 0], v1 = tris[t, 1] and v2 = tris[t, 2]
|
||||
//
|
||||
// For pointcloud and meshes of batch size N, this function requires N
|
||||
// computations. The memory occupied is O(NPT) which can become quite large.
|
||||
// For example, a medium sized batch with N = 32 with P = 10000 and T = 5000
|
||||
// will require for the forward pass 5.8G of memory to store dists.
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
torch::Tensor PointFaceArrayDistanceForwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris);
|
||||
#endif
|
||||
|
||||
torch::Tensor PointFaceArrayDistanceForward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointFaceArrayDistanceForwardCuda(points, tris);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
|
||||
// Backward pass for PointFaceArrayDistance.
|
||||
//
|
||||
// Args:
|
||||
// points: FloatTensor of shape (P, 3)
|
||||
// tris: FloatTensor of shape (T, 3, 3)
|
||||
// grad_dists: FloatTensor of shape (P, T)
|
||||
//
|
||||
// Returns:
|
||||
// grad_points: FloatTensor of shape (P, 3)
|
||||
// grad_tris: FloatTensor of shape (T, 3, 3)
|
||||
//
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& grad_dists);
|
||||
#endif
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& tris,
|
||||
const torch::Tensor& grad_dists) {
|
||||
if (points.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("No CPU implementation.");
|
||||
}
|
||||
Reference in New Issue
Block a user