mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Consolidate point mesh forward kernels
Summary: This diff creates the generic MeshForwardKernel which can handle distance calculations between point, edge and faces in either direction. Replaces only point_mesh_face code for now. Reviewed By: gkioxari Differential Revision: D24543316 fbshipit-source-id: 302707d7cec2d77a899738adf40481035c240da8
This commit is contained in:
parent
194b29fb6c
commit
c41aff23f0
@ -15,18 +15,23 @@
|
||||
// * PointFaceDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void PointFaceForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const int64_t* __restrict__ tris_first_idx, // (B,)
|
||||
float* __restrict__ dist_points, // (P,)
|
||||
int64_t* __restrict__ idx_points, // (P,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t T) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* tris_f3 = (float3*)tris;
|
||||
__global__ void DistanceForwardKernel(
|
||||
const float* __restrict__ objects, // (O * oD * 3)
|
||||
const size_t objects_size, // O
|
||||
const size_t objects_dim, // oD
|
||||
const float* __restrict__ targets, // (T * tD * 3)
|
||||
const size_t targets_size, // T
|
||||
const size_t targets_dim, // tD
|
||||
const int64_t* __restrict__ objects_first_idx, // (B,)
|
||||
const int64_t* __restrict__ targets_first_idx, // (B,)
|
||||
const size_t batch_size, // B
|
||||
float* __restrict__ dist_objects, // (O,)
|
||||
int64_t* __restrict__ idx_objects) { // (O,)
|
||||
// This kernel is used interchangeably to compute bi-directional distances
|
||||
// between points and triangles/lines. The direction of the distance computed,
|
||||
// i.e. point to triangle/line or triangle/line to point, depends on the order
|
||||
// of the input arguments and is inferred based on their shape. The shape is
|
||||
// used to distinguish between triangles and lines
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
@ -35,39 +40,59 @@ __global__ void PointFaceForwardKernel(
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch_idx
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
// start and end for objects in batch_idx
|
||||
const int64_t starto = objects_first_idx[batch_idx];
|
||||
const int64_t endo = batch_idx + 1 < batch_size
|
||||
? objects_first_idx[batch_idx + 1]
|
||||
: objects_size;
|
||||
|
||||
// start and end for faces in batch_idx
|
||||
const int64_t startt = tris_first_idx[batch_idx];
|
||||
const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T;
|
||||
// start and end for targets in batch_idx
|
||||
const int64_t startt = targets_first_idx[batch_idx];
|
||||
const int64_t endt = batch_idx + 1 < batch_size
|
||||
? targets_first_idx[batch_idx + 1]
|
||||
: targets_size;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t i = blockIdx.x; // index within batch element.
|
||||
const size_t tid = threadIdx.x; // thread index
|
||||
|
||||
// Each block will compute one element of the output idx_points[startp + i],
|
||||
// dist_points[startp + i]. Within the block we will use threads to compute
|
||||
// the distances between points[startp + i] and tris[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [startt, endt]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
// Set references to points/face based on which of objects/targets refer to
|
||||
// points/faces
|
||||
float3* points_f3 = objects_dim == 1 ? (float3*)objects : (float3*)targets;
|
||||
float3* face_f3 = objects_dim == 1 ? (float3*)targets : (float3*)objects;
|
||||
// Distinguishes whether we're computing distance against triangle vs edge
|
||||
bool isTriangle = objects_dim == 3 || targets_dim == 3;
|
||||
|
||||
// If i exceeds the number of points in batch_idx, then do nothing
|
||||
if (i < (endp - startp)) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + i];
|
||||
// Each block will compute one element of the output idx_objects[starto + i],
|
||||
// dist_objects[starto + i]. Within the block we will use threads to compute
|
||||
// the distances between objects[starto + i] and targets[j] for all j
|
||||
// belonging in the same batch as i, i.e. j in [startt, endt]. Then use a
|
||||
// block reduction to take an argmin of the distances.
|
||||
|
||||
// Compute the distances between points[startp + i] and tris[j] for
|
||||
// If i exceeds the number of objects in batch_idx, then do nothing
|
||||
if (i < (endo - starto)) {
|
||||
// Compute the distances between objects[starto + i] and targets[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [startt, endt].
|
||||
// Here each thread will reduce over (endt-startt) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (endt - startt); j += blockDim.x) {
|
||||
const float3 v0 = tris_f3[(startt + j) * 3 + 0];
|
||||
const float3 v1 = tris_f3[(startt + j) * 3 + 1];
|
||||
const float3 v2 = tris_f3[(startt + j) * 3 + 2];
|
||||
float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2);
|
||||
size_t point_idx = objects_dim == 1 ? starto + i : startt + j;
|
||||
size_t face_idx = objects_dim == 1 ? (startt + j) * targets_dim
|
||||
: (starto + i) * objects_dim;
|
||||
|
||||
float dist;
|
||||
if (isTriangle) {
|
||||
dist = PointTriangle3DistanceForward(
|
||||
points_f3[point_idx],
|
||||
face_f3[face_idx],
|
||||
face_f3[face_idx + 1],
|
||||
face_f3[face_idx + 2]);
|
||||
} else {
|
||||
dist = PointLine3DistanceForward(
|
||||
points_f3[point_idx], face_f3[face_idx], face_f3[face_idx + 1]);
|
||||
}
|
||||
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (startt + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
@ -94,45 +119,61 @@ __global__ void PointFaceForwardKernel(
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_points[startp + i] = min_idxs[0];
|
||||
dist_points[startp + i] = min_dists[0];
|
||||
idx_objects[starto + i] = min_idxs[0];
|
||||
dist_objects[starto + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& tris,
|
||||
const at::Tensor& tris_first_idx,
|
||||
const int64_t max_points) {
|
||||
std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
|
||||
const at::Tensor& objects,
|
||||
const size_t objects_dim,
|
||||
const at::Tensor& objects_first_idx,
|
||||
const at::Tensor& targets,
|
||||
const size_t targets_dim,
|
||||
const at::Tensor& targets_first_idx,
|
||||
const int64_t max_objects) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
points_first_idx_t{points_first_idx, "points_first_idx", 2},
|
||||
tris_t{tris, "tris", 3},
|
||||
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
|
||||
at::CheckedFrom c = "PointFaceDistanceForwardCuda";
|
||||
at::TensorArg objects_t{objects, "objects", 1},
|
||||
objects_first_idx_t{objects_first_idx, "objects_first_idx", 2},
|
||||
targets_t{targets, "targets", 3},
|
||||
targets_first_idx_t{targets_first_idx, "targets_first_idx", 4};
|
||||
at::CheckedFrom c = "DistanceForwardCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
|
||||
at::checkAllSameType(c, {points_t, tris_t});
|
||||
c, {objects_t, objects_first_idx_t, targets_t, targets_first_idx_t});
|
||||
at::checkAllSameType(c, {objects_t, targets_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
at::cuda::CUDAGuard device_guard(objects.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
const int64_t objects_size = objects.size(0);
|
||||
const int64_t targets_size = targets.size(0);
|
||||
const int64_t batch_size = objects_first_idx.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
TORCH_CHECK(tris_first_idx.size(0) == B);
|
||||
TORCH_CHECK(targets_first_idx.size(0) == batch_size);
|
||||
if (objects_dim == 1) {
|
||||
TORCH_CHECK(
|
||||
targets_dim >= 2 && targets_dim <= 3,
|
||||
"either object or target must be edge or face");
|
||||
TORCH_CHECK(objects.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
targets.size(2) == 3,
|
||||
"face must be of shape Tx3x3, lines must be of shape Tx2x3");
|
||||
} else {
|
||||
TORCH_CHECK(targets_dim == 1, "either object or target must be point");
|
||||
TORCH_CHECK(
|
||||
objects_dim >= 2 && objects_dim <= 3,
|
||||
"either object or target must be edge or face");
|
||||
TORCH_CHECK(targets.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
objects.size(2) == 3,
|
||||
"face must be of shape Tx3x3, lines must be of shape Tx2x3");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
at::Tensor dists = at::zeros({P,}, points.options());
|
||||
at::Tensor idxs = at::zeros({P,}, points_first_idx.options());
|
||||
at::Tensor dists = at::zeros({objects_size,}, objects.options());
|
||||
at::Tensor idxs = at::zeros({objects_size,}, objects_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
@ -141,24 +182,36 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
|
||||
}
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_points, B);
|
||||
const dim3 blocks(max_objects, batch_size);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
PointFaceForwardKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
points_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
tris.contiguous().data_ptr<float>(),
|
||||
tris_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
DistanceForwardKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
objects.contiguous().data_ptr<float>(),
|
||||
objects_size,
|
||||
objects_dim,
|
||||
targets.contiguous().data_ptr<float>(),
|
||||
targets_size,
|
||||
targets_dim,
|
||||
objects_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
targets_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
batch_size,
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
T);
|
||||
idxs.data_ptr<int64_t>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& tris,
|
||||
const at::Tensor& tris_first_idx,
|
||||
const int64_t max_points) {
|
||||
return DistanceForwardCuda(
|
||||
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
|
||||
}
|
||||
|
||||
__global__ void PointFaceBackwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
@ -265,149 +318,14 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
|
||||
// * FacePointDistance *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void FacePointForwardKernel(
|
||||
const float* __restrict__ points, // (P, 3)
|
||||
const int64_t* __restrict__ points_first_idx, // (B,)
|
||||
const float* __restrict__ tris, // (T, 3, 3)
|
||||
const int64_t* __restrict__ tris_first_idx, // (B,)
|
||||
float* __restrict__ dist_tris, // (T,)
|
||||
int64_t* __restrict__ idx_tris, // (T,)
|
||||
const size_t B,
|
||||
const size_t P,
|
||||
const size_t T) {
|
||||
float3* points_f3 = (float3*)points;
|
||||
float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Single shared memory buffer which is split and cast to different types.
|
||||
extern __shared__ char shared_buf[];
|
||||
float* min_dists = (float*)shared_buf; // float[NUM_THREADS]
|
||||
int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]
|
||||
|
||||
const size_t batch_idx = blockIdx.y; // index of batch element.
|
||||
|
||||
// start and end for points in batch_idx
|
||||
const int64_t startp = points_first_idx[batch_idx];
|
||||
const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P;
|
||||
|
||||
// start and end for tris in batch_idx
|
||||
const int64_t startt = tris_first_idx[batch_idx];
|
||||
const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T;
|
||||
|
||||
const size_t i = blockIdx.x; // index of point within batch element.
|
||||
const size_t tid = threadIdx.x;
|
||||
|
||||
// Each block will compute one element of the output idx_tris[startt + i],
|
||||
// dist_tris[startt + i]. Within the block we will use threads to compute
|
||||
// the distances between tris[startt + i] and points[j] for all j belonging
|
||||
// in the same batch as i, i.e. j in [startp, endp]. Then use a block
|
||||
// reduction to take an argmin of the distances.
|
||||
|
||||
// If i exceeds the number of tris in batch_idx, then do nothing
|
||||
if (i < (endt - startt)) {
|
||||
const float3 v0 = tris_f3[(startt + i) * 3 + 0];
|
||||
const float3 v1 = tris_f3[(startt + i) * 3 + 1];
|
||||
const float3 v2 = tris_f3[(startt + i) * 3 + 2];
|
||||
|
||||
// Compute the distances between tris[startt + i] and points[j] for
|
||||
// all j belonging in the same batch as i, i.e. j in [startp, endp].
|
||||
// Here each thread will reduce over (endp-startp) / blockDim.x in serial,
|
||||
// and store its result to shared memory
|
||||
float min_dist = FLT_MAX;
|
||||
size_t min_idx = 0;
|
||||
for (size_t j = tid; j < (endp - startp); j += blockDim.x) {
|
||||
// Retrieve (startp + i) point
|
||||
const float3 p_f3 = points_f3[startp + j];
|
||||
|
||||
float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2);
|
||||
min_dist = (j == tid) ? dist : min_dist;
|
||||
min_idx = (dist <= min_dist) ? (startp + j) : min_idx;
|
||||
min_dist = (dist <= min_dist) ? dist : min_dist;
|
||||
}
|
||||
min_dists[tid] = min_dist;
|
||||
min_idxs[tid] = min_idx;
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
min_idxs[tid] = min_idxs[tid + s];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Unroll the last 6 iterations of the loop since they will happen
|
||||
// synchronized within a single warp.
|
||||
if (tid < 32)
|
||||
WarpReduce<float>(min_dists, min_idxs, tid);
|
||||
|
||||
// Finally thread 0 writes the result to the output buffer.
|
||||
if (tid == 0) {
|
||||
idx_tris[startt + i] = min_idxs[0];
|
||||
dist_tris[startt + i] = min_dists[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
|
||||
const at::Tensor& points,
|
||||
const at::Tensor& points_first_idx,
|
||||
const at::Tensor& tris,
|
||||
const at::Tensor& tris_first_idx,
|
||||
const int64_t max_tris) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
points_first_idx_t{points_first_idx, "points_first_idx", 2},
|
||||
tris_t{tris, "tris", 3},
|
||||
tris_first_idx_t{tris_first_idx, "tris_first_idx", 4};
|
||||
at::CheckedFrom c = "FacePointDistanceForwardCuda";
|
||||
at::checkAllSameGPU(
|
||||
c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t});
|
||||
at::checkAllSameType(c, {points_t, tris_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t P = points.size(0);
|
||||
const int64_t T = tris.size(0);
|
||||
const int64_t B = points_first_idx.size(0);
|
||||
|
||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
||||
TORCH_CHECK(
|
||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
||||
"tris must be of shape Tx3x3");
|
||||
TORCH_CHECK(tris_first_idx.size(0) == B);
|
||||
|
||||
// clang-format off
|
||||
at::Tensor dists = at::zeros({T,}, tris.options());
|
||||
at::Tensor idxs = at::zeros({T,}, tris_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
if (dists.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
}
|
||||
|
||||
const int threads = 128;
|
||||
const dim3 blocks(max_tris, B);
|
||||
size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
|
||||
|
||||
FacePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
points_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
tris.contiguous().data_ptr<float>(),
|
||||
tris_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
dists.data_ptr<float>(),
|
||||
idxs.data_ptr<int64_t>(),
|
||||
B,
|
||||
P,
|
||||
T);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(dists, idxs);
|
||||
return DistanceForwardCuda(
|
||||
tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris);
|
||||
}
|
||||
|
||||
__global__ void FacePointBackwardKernel(
|
||||
|
Loading…
x
Reference in New Issue
Block a user