Consolidate point mesh forward kernels

Summary: This diff creates the generic MeshForwardKernel (named `DistanceForwardKernel` in the code), which can handle distance calculations between points and edges or faces, in either direction. For now it replaces only the point_mesh_face code.

Reviewed By: gkioxari

Differential Revision: D24543316

fbshipit-source-id: 302707d7cec2d77a899738adf40481035c240da8
This commit is contained in:
Dave Schnizlein 2020-11-10 09:32:33 -08:00 committed by Facebook GitHub Bot
parent 194b29fb6c
commit c41aff23f0

View File

@ -15,18 +15,23 @@
// * PointFaceDistance * // * PointFaceDistance *
// **************************************************************************** // ****************************************************************************
__global__ void PointFaceForwardKernel( __global__ void DistanceForwardKernel(
const float* __restrict__ points, // (P, 3) const float* __restrict__ objects, // (O * oD * 3)
const int64_t* __restrict__ points_first_idx, // (B,) const size_t objects_size, // O
const float* __restrict__ tris, // (T, 3, 3) const size_t objects_dim, // oD
const int64_t* __restrict__ tris_first_idx, // (B,) const float* __restrict__ targets, // (T * tD * 3)
float* __restrict__ dist_points, // (P,) const size_t targets_size, // T
int64_t* __restrict__ idx_points, // (P,) const size_t targets_dim, // tD
const size_t B, const int64_t* __restrict__ objects_first_idx, // (B,)
const size_t P, const int64_t* __restrict__ targets_first_idx, // (B,)
const size_t T) { const size_t batch_size, // B
float3* points_f3 = (float3*)points; float* __restrict__ dist_objects, // (O,)
float3* tris_f3 = (float3*)tris; int64_t* __restrict__ idx_objects) { // (O,)
// This kernel is used interchangeably to compute bi-directional distances
// between points and triangles/lines. The direction of the distance computed,
// i.e. point to triangle/line or triangle/line to point, depends on the order
// of the input arguments and is inferred based on their shape. The shape is
// used to distinguish between triangles and lines
// Single shared memory buffer which is split and cast to different types. // Single shared memory buffer which is split and cast to different types.
extern __shared__ char shared_buf[]; extern __shared__ char shared_buf[];
@ -35,39 +40,59 @@ __global__ void PointFaceForwardKernel(
const size_t batch_idx = blockIdx.y; // index of batch element. const size_t batch_idx = blockIdx.y; // index of batch element.
// start and end for points in batch_idx // start and end for objects in batch_idx
const int64_t startp = points_first_idx[batch_idx]; const int64_t starto = objects_first_idx[batch_idx];
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P; const int64_t endo = batch_idx + 1 < batch_size
? objects_first_idx[batch_idx + 1]
: objects_size;
// start and end for faces in batch_idx // start and end for targets in batch_idx
const int64_t startt = tris_first_idx[batch_idx]; const int64_t startt = targets_first_idx[batch_idx];
const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T; const int64_t endt = batch_idx + 1 < batch_size
? targets_first_idx[batch_idx + 1]
: targets_size;
const size_t i = blockIdx.x; // index of point within batch element. const size_t i = blockIdx.x; // index within batch element.
const size_t tid = threadIdx.x; // thread index const size_t tid = threadIdx.x; // thread index
// Each block will compute one element of the output idx_points[startp + i], // Set references to points/face based on which of objects/targets refer to
// dist_points[startp + i]. Within the block we will use threads to compute // points/faces
// the distances between points[startp + i] and tris[j] for all j belonging float3* points_f3 = objects_dim == 1 ? (float3*)objects : (float3*)targets;
// in the same batch as i, i.e. j in [startt, endt]. Then use a block float3* face_f3 = objects_dim == 1 ? (float3*)targets : (float3*)objects;
// reduction to take an argmin of the distances. // Distinguishes whether we're computing distance against triangle vs edge
bool isTriangle = objects_dim == 3 || targets_dim == 3;
// If i exceeds the number of points in batch_idx, then do nothing // Each block will compute one element of the output idx_objects[starto + i],
if (i < (endp - startp)) { // dist_objects[starto + i]. Within the block we will use threads to compute
// Retrieve (startp + i) point // the distances between objects[starto + i] and targets[j] for all j
const float3 p_f3 = points_f3[startp + i]; // belonging in the same batch as i, i.e. j in [startt, endt]. Then use a
// block reduction to take an argmin of the distances.
// Compute the distances between points[startp + i] and tris[j] for // If i exceeds the number of objects in batch_idx, then do nothing
if (i < (endo - starto)) {
// Compute the distances between objects[starto + i] and targets[j] for
// all j belonging in the same batch as i, i.e. j in [startt, endt]. // all j belonging in the same batch as i, i.e. j in [startt, endt].
// Here each thread will reduce over (endt-startt) / blockDim.x in serial, // Here each thread will reduce over (endt-startt) / blockDim.x in serial,
// and store its result to shared memory // and store its result to shared memory
float min_dist = FLT_MAX; float min_dist = FLT_MAX;
size_t min_idx = 0; size_t min_idx = 0;
for (size_t j = tid; j < (endt - startt); j += blockDim.x) { for (size_t j = tid; j < (endt - startt); j += blockDim.x) {
const float3 v0 = tris_f3[(startt + j) * 3 + 0]; size_t point_idx = objects_dim == 1 ? starto + i : startt + j;
const float3 v1 = tris_f3[(startt + j) * 3 + 1]; size_t face_idx = objects_dim == 1 ? (startt + j) * targets_dim
const float3 v2 = tris_f3[(startt + j) * 3 + 2]; : (starto + i) * objects_dim;
float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2);
float dist;
if (isTriangle) {
dist = PointTriangle3DistanceForward(
points_f3[point_idx],
face_f3[face_idx],
face_f3[face_idx + 1],
face_f3[face_idx + 2]);
} else {
dist = PointLine3DistanceForward(
points_f3[point_idx], face_f3[face_idx], face_f3[face_idx + 1]);
}
min_dist = (j == tid) ? dist : min_dist; min_dist = (j == tid) ? dist : min_dist;
min_idx = (dist <= min_dist) ? (startt + j) : min_idx; min_idx = (dist <= min_dist) ? (startt + j) : min_idx;
min_dist = (dist <= min_dist) ? dist : min_dist; min_dist = (dist <= min_dist) ? dist : min_dist;
@ -94,45 +119,61 @@ __global__ void PointFaceForwardKernel(
// Finally thread 0 writes the result to the output buffer. // Finally thread 0 writes the result to the output buffer.
if (tid == 0) { if (tid == 0) {
idx_points[startp + i] = min_idxs[0]; idx_objects[starto + i] = min_idxs[0];
dist_points[startp + i] = min_dists[0]; dist_objects[starto + i] = min_dists[0];
} }
} }
} }
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda( std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
const at::Tensor& points, const at::Tensor& objects,
const at::Tensor& points_first_idx, const size_t objects_dim,
const at::Tensor& tris, const at::Tensor& objects_first_idx,
const at::Tensor& tris_first_idx, const at::Tensor& targets,
const int64_t max_points) { const size_t targets_dim,
const at::Tensor& targets_first_idx,
const int64_t max_objects) {
// Check inputs are on the same device // Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, at::TensorArg objects_t{objects, "objects", 1},
points_first_idx_t{points_first_idx, "points_first_idx", 2}, objects_first_idx_t{objects_first_idx, "objects_first_idx", 2},
tris_t{tris, "tris", 3}, targets_t{targets, "targets", 3},
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4}; targets_first_idx_t{targets_first_idx, "targets_first_idx", 4};
at::CheckedFrom c = "PointFaceDistanceForwardCuda"; at::CheckedFrom c = "DistanceForwardCuda";
at::checkAllSameGPU( at::checkAllSameGPU(
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t}); c, {objects_t, objects_first_idx_t, targets_t, targets_first_idx_t});
at::checkAllSameType(c, {points_t, tris_t}); at::checkAllSameType(c, {objects_t, targets_t});
// Set the device for the kernel launch based on the device of the input // Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device()); at::cuda::CUDAGuard device_guard(objects.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0); const int64_t objects_size = objects.size(0);
const int64_t T = tris.size(0); const int64_t targets_size = targets.size(0);
const int64_t B = points_first_idx.size(0); const int64_t batch_size = objects_first_idx.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); TORCH_CHECK(targets_first_idx.size(0) == batch_size);
TORCH_CHECK( if (objects_dim == 1) {
(tris.size(1) == 3) && (tris.size(2) == 3), TORCH_CHECK(
"tris must be of shape Tx3x3"); targets_dim >= 2 && targets_dim <= 3,
TORCH_CHECK(tris_first_idx.size(0) == B); "either object or target must be edge or face");
TORCH_CHECK(objects.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
targets.size(2) == 3,
"face must be of shape Tx3x3, lines must be of shape Tx2x3");
} else {
TORCH_CHECK(targets_dim == 1, "either object or target must be point");
TORCH_CHECK(
objects_dim >= 2 && objects_dim <= 3,
"either object or target must be edge or face");
TORCH_CHECK(targets.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
objects.size(2) == 3,
"face must be of shape Tx3x3, lines must be of shape Tx2x3");
}
// clang-format off // clang-format off
at::Tensor dists = at::zeros({P,}, points.options()); at::Tensor dists = at::zeros({objects_size,}, objects.options());
at::Tensor idxs = at::zeros({P,}, points_first_idx.options()); at::Tensor idxs = at::zeros({objects_size,}, objects_first_idx.options());
// clang-format on // clang-format on
if (dists.numel() == 0) { if (dists.numel() == 0) {
@ -141,24 +182,36 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
} }
const int threads = 128; const int threads = 128;
const dim3 blocks(max_points, B); const dim3 blocks(max_objects, batch_size);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
PointFaceForwardKernel<<<blocks, threads, shared_size, stream>>>( DistanceForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.contiguous().data_ptr<float>(), objects.contiguous().data_ptr<float>(),
points_first_idx.contiguous().data_ptr<int64_t>(), objects_size,
tris.contiguous().data_ptr<float>(), objects_dim,
tris_first_idx.contiguous().data_ptr<int64_t>(), targets.contiguous().data_ptr<float>(),
targets_size,
targets_dim,
objects_first_idx.contiguous().data_ptr<int64_t>(),
targets_first_idx.contiguous().data_ptr<int64_t>(),
batch_size,
dists.data_ptr<float>(), dists.data_ptr<float>(),
idxs.data_ptr<int64_t>(), idxs.data_ptr<int64_t>());
B,
P,
T);
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs); return std::make_tuple(dists, idxs);
} }
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& points_first_idx,
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_points) {
return DistanceForwardCuda(
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
}
__global__ void PointFaceBackwardKernel( __global__ void PointFaceBackwardKernel(
const float* __restrict__ points, // (P, 3) const float* __restrict__ points, // (P, 3)
const float* __restrict__ tris, // (T, 3, 3) const float* __restrict__ tris, // (T, 3, 3)
@ -265,149 +318,14 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
// * FacePointDistance * // * FacePointDistance *
// **************************************************************************** // ****************************************************************************
__global__ void FacePointForwardKernel(
const float* __restrict__ points, // (P, 3)
const int64_t* __restrict__ points_first_idx, // (B,)
const float* __restrict__ tris, // (T, 3, 3)
const int64_t* __restrict__ tris_first_idx, // (B,)
float* __restrict__ dist_tris, // (T,)
int64_t* __restrict__ idx_tris, // (T,)
const size_t B,
const size_t P,
const size_t T) {
float3* points_f3 = (float3*)points;
float3* tris_f3 = (float3*)tris;
// Single shared memory buffer which is split and cast to different types.
extern __shared__ char shared_buf[];
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
const size_t batch_idx = blockIdx.y; // index of batch element.
// start and end for points in batch_idx
const int64_t startp = points_first_idx[batch_idx];
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
// start and end for tris in batch_idx
const int64_t startt = tris_first_idx[batch_idx];
const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T;
const size_t i = blockIdx.x; // index of point within batch element.
const size_t tid = threadIdx.x;
// Each block will compute one element of the output idx_tris[startt + i],
// dist_tris[startt + i]. Within the block we will use threads to compute
// the distances between tris[startt + i] and points[j] for all j belonging
// in the same batch as i, i.e. j in [startp, endp]. Then use a block
// reduction to take an argmin of the distances.
// If i exceeds the number of tris in batch_idx, then do nothing
if (i < (endt - startt)) {
const float3 v0 = tris_f3[(startt + i) * 3 + 0];
const float3 v1 = tris_f3[(startt + i) * 3 + 1];
const float3 v2 = tris_f3[(startt + i) * 3 + 2];
// Compute the distances between tris[startt + i] and points[j] for
// all j belonging in the same batch as i, i.e. j in [startp, endp].
// Here each thread will reduce over (endp-startp) / blockDim.x in serial,
// and store its result to shared memory
float min_dist = FLT_MAX;
size_t min_idx = 0;
for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
// Retrieve (startp + i) point
const float3 p_f3 = points_f3[startp + j];
float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2);
min_dist = (j == tid) ? dist : min_dist;
min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
min_dist = (dist <= min_dist) ? dist : min_dist;
}
min_dists[tid] = min_dist;
min_idxs[tid] = min_idx;
__syncthreads();
// Perform reduction in shared memory.
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
if (tid < s) {
if (min_dists[tid] > min_dists[tid + s]) {
min_dists[tid] = min_dists[tid + s];
min_idxs[tid] = min_idxs[tid + s];
}
}
__syncthreads();
}
// Unroll the last 6 iterations of the loop since they will happen
// synchronized within a single warp.
if (tid < 32)
WarpReduce<float>(min_dists, min_idxs, tid);
// Finally thread 0 writes the result to the output buffer.
if (tid == 0) {
idx_tris[startt + i] = min_idxs[0];
dist_tris[startt + i] = min_dists[0];
}
}
}
std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda( std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
const at::Tensor& points, const at::Tensor& points,
const at::Tensor& points_first_idx, const at::Tensor& points_first_idx,
const at::Tensor& tris, const at::Tensor& tris,
const at::Tensor& tris_first_idx, const at::Tensor& tris_first_idx,
const int64_t max_tris) { const int64_t max_tris) {
// Check inputs are on the same device return DistanceForwardCuda(
at::TensorArg points_t{points, "points", 1}, tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris);
points_first_idx_t{points_first_idx, "points_first_idx", 2},
tris_t{tris, "tris", 3},
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
at::CheckedFrom c = "FacePointDistanceForwardCuda";
at::checkAllSameGPU(
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
at::checkAllSameType(c, {points_t, tris_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t P = points.size(0);
const int64_t T = tris.size(0);
const int64_t B = points_first_idx.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
TORCH_CHECK(
(tris.size(1) == 3) && (tris.size(2) == 3),
"tris must be of shape Tx3x3");
TORCH_CHECK(tris_first_idx.size(0) == B);
// clang-format off
at::Tensor dists = at::zeros({T,}, tris.options());
at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
// clang-format on
if (dists.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
}
const int threads = 128;
const dim3 blocks(max_tris, B);
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
FacePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
points.contiguous().data_ptr<float>(),
points_first_idx.contiguous().data_ptr<int64_t>(),
tris.contiguous().data_ptr<float>(),
tris_first_idx.contiguous().data_ptr<int64_t>(),
dists.data_ptr<float>(),
idxs.data_ptr<int64_t>(),
B,
P,
T);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
} }
__global__ void FacePointBackwardKernel( __global__ void FacePointBackwardKernel(