diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.cu b/pytorch3d/csrc/point_mesh/point_mesh_face.cu index d43cfe7b..ec3e9f89 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_face.cu +++ b/pytorch3d/csrc/point_mesh/point_mesh_face.cu @@ -15,18 +15,23 @@ // * PointFaceDistance * // **************************************************************************** -__global__ void PointFaceForwardKernel( - const float* __restrict__ points, // (P, 3) - const int64_t* __restrict__ points_first_idx, // (B,) - const float* __restrict__ tris, // (T, 3, 3) - const int64_t* __restrict__ tris_first_idx, // (B,) - float* __restrict__ dist_points, // (P,) - int64_t* __restrict__ idx_points, // (P,) - const size_t B, - const size_t P, - const size_t T) { - float3* points_f3 = (float3*)points; - float3* tris_f3 = (float3*)tris; +__global__ void DistanceForwardKernel( + const float* __restrict__ objects, // (O * oD * 3) + const size_t objects_size, // O + const size_t objects_dim, // oD + const float* __restrict__ targets, // (T * tD * 3) + const size_t targets_size, // T + const size_t targets_dim, // tD + const int64_t* __restrict__ objects_first_idx, // (B,) + const int64_t* __restrict__ targets_first_idx, // (B,) + const size_t batch_size, // B + float* __restrict__ dist_objects, // (O,) + int64_t* __restrict__ idx_objects) { // (O,) + // This kernel is used interchangeably to compute bi-directional distances + // between points and triangles/lines. The direction of the distance computed, + // i.e. point to triangle/line or triangle/line to point, depends on the order + // of the input arguments and is inferred based on their shape. The shape is + // used to distinguish between triangles and lines // Single shared memory buffer which is split and cast to different types. extern __shared__ char shared_buf[]; @@ -35,39 +40,59 @@ __global__ void PointFaceForwardKernel( const size_t batch_idx = blockIdx.y; // index of batch element. - // start and end for points in batch_idx - const int64_t startp = points_first_idx[batch_idx]; - const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P; + // start and end for objects in batch_idx + const int64_t starto = objects_first_idx[batch_idx]; + const int64_t endo = batch_idx + 1 < batch_size + ? objects_first_idx[batch_idx + 1] + : objects_size; - // start and end for faces in batch_idx - const int64_t startt = tris_first_idx[batch_idx]; - const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T; + // start and end for targets in batch_idx + const int64_t startt = targets_first_idx[batch_idx]; + const int64_t endt = batch_idx + 1 < batch_size + ? targets_first_idx[batch_idx + 1] + : targets_size; - const size_t i = blockIdx.x; // index of point within batch element. + const size_t i = blockIdx.x; // index within batch element. const size_t tid = threadIdx.x; // thread index - // Each block will compute one element of the output idx_points[startp + i], - // dist_points[startp + i]. Within the block we will use threads to compute - // the distances between points[startp + i] and tris[j] for all j belonging - // in the same batch as i, i.e. j in [startt, endt]. Then use a block - // reduction to take an argmin of the distances. + // Set references to points/face based on which of objects/targets refer to + // points/faces + float3* points_f3 = objects_dim == 1 ? (float3*)objects : (float3*)targets; + float3* face_f3 = objects_dim == 1 ? (float3*)targets : (float3*)objects; + // Distinguishes whether we're computing distance against triangle vs edge + bool isTriangle = objects_dim == 3 || targets_dim == 3; - // If i exceeds the number of points in batch_idx, then do nothing - if (i < (endp - startp)) { - // Retrieve (startp + i) point - const float3 p_f3 = points_f3[startp + i]; + // Each block will compute one element of the output idx_objects[starto + i], + // dist_objects[starto + i]. Within the block we will use threads to compute + // the distances between objects[starto + i] and targets[j] for all j + // belonging in the same batch as i, i.e. j in [startt, endt]. Then use a + // block reduction to take an argmin of the distances. - // Compute the distances between points[startp + i] and tris[j] for + // If i exceeds the number of objects in batch_idx, then do nothing + if (i < (endo - starto)) { + // Compute the distances between objects[starto + i] and targets[j] for // all j belonging in the same batch as i, i.e. j in [startt, endt]. // Here each thread will reduce over (endt-startt) / blockDim.x in serial, // and store its result to shared memory float min_dist = FLT_MAX; size_t min_idx = 0; for (size_t j = tid; j < (endt - startt); j += blockDim.x) { - const float3 v0 = tris_f3[(startt + j) * 3 + 0]; - const float3 v1 = tris_f3[(startt + j) * 3 + 1]; - const float3 v2 = tris_f3[(startt + j) * 3 + 2]; - float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2); + size_t point_idx = objects_dim == 1 ? starto + i : startt + j; + size_t face_idx = objects_dim == 1 ? (startt + j) * targets_dim + : (starto + i) * objects_dim; + + float dist; + if (isTriangle) { + dist = PointTriangle3DistanceForward( + points_f3[point_idx], + face_f3[face_idx], + face_f3[face_idx + 1], + face_f3[face_idx + 2]); + } else { + dist = PointLine3DistanceForward( + points_f3[point_idx], face_f3[face_idx], face_f3[face_idx + 1]); + } + min_dist = (j == tid) ? dist : min_dist; min_idx = (dist <= min_dist) ? (startt + j) : min_idx; min_dist = (dist <= min_dist) ? dist : min_dist; @@ -94,45 +119,61 @@ __global__ void PointFaceForwardKernel( // Finally thread 0 writes the result to the output buffer. if (tid == 0) { - idx_points[startp + i] = min_idxs[0]; - dist_points[startp + i] = min_dists[0]; + idx_objects[starto + i] = min_idxs[0]; + dist_objects[starto + i] = min_dists[0]; } } } -std::tuple PointFaceDistanceForwardCuda( - const at::Tensor& points, - const at::Tensor& points_first_idx, - const at::Tensor& tris, - const at::Tensor& tris_first_idx, - const int64_t max_points) { +std::tuple DistanceForwardCuda( + const at::Tensor& objects, + const size_t objects_dim, + const at::Tensor& objects_first_idx, + const at::Tensor& targets, + const size_t targets_dim, + const at::Tensor& targets_first_idx, + const int64_t max_objects) { // Check inputs are on the same device - at::TensorArg points_t{points, "points", 1}, - points_first_idx_t{points_first_idx, "points_first_idx", 2}, - tris_t{tris, "tris", 3}, - tris_first_idx_t{tris_first_idx, "tris_first_idx", 4}; - at::CheckedFrom c = "PointFaceDistanceForwardCuda"; + at::TensorArg objects_t{objects, "objects", 1}, + objects_first_idx_t{objects_first_idx, "objects_first_idx", 2}, + targets_t{targets, "targets", 3}, + targets_first_idx_t{targets_first_idx, "targets_first_idx", 4}; + at::CheckedFrom c = "DistanceForwardCuda"; at::checkAllSameGPU( - c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t}); - at::checkAllSameType(c, {points_t, tris_t}); + c, {objects_t, objects_first_idx_t, targets_t, targets_first_idx_t}); + at::checkAllSameType(c, {objects_t, targets_t}); // Set the device for the kernel launch based on the device of the input - at::cuda::CUDAGuard device_guard(points.device()); + at::cuda::CUDAGuard device_guard(objects.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const int64_t P = points.size(0); - const int64_t T = tris.size(0); - const int64_t B = points_first_idx.size(0); + const int64_t objects_size = objects.size(0); + const int64_t targets_size = targets.size(0); + const int64_t batch_size = objects_first_idx.size(0); - TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); - TORCH_CHECK( - (tris.size(1) == 3) && (tris.size(2) == 3), - "tris must be of shape Tx3x3"); - TORCH_CHECK(tris_first_idx.size(0) == B); + TORCH_CHECK(targets_first_idx.size(0) == batch_size); + if (objects_dim == 1) { + TORCH_CHECK( + targets_dim >= 2 && targets_dim <= 3, + "either object or target must be edge or face"); + TORCH_CHECK(objects.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( + targets.size(2) == 3, + "face must be of shape Tx3x3, lines must be of shape Tx2x3"); + } else { + TORCH_CHECK(targets_dim == 1, "either object or target must be point"); + TORCH_CHECK( + objects_dim >= 2 && objects_dim <= 3, + "either object or target must be edge or face"); + TORCH_CHECK(targets.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( + objects.size(2) == 3, + "face must be of shape Tx3x3, lines must be of shape Tx2x3"); + } // clang-format off - at::Tensor dists = at::zeros({P,}, points.options()); - at::Tensor idxs = at::zeros({P,}, points_first_idx.options()); + at::Tensor dists = at::zeros({objects_size,}, objects.options()); + at::Tensor idxs = at::zeros({objects_size,}, objects_first_idx.options()); // clang-format on if (dists.numel() == 0) { @@ -141,24 +182,36 @@ std::tuple PointFaceDistanceForwardCuda( } const int threads = 128; - const dim3 blocks(max_points, B); + const dim3 blocks(max_objects, batch_size); size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); - PointFaceForwardKernel<<>>( - points.contiguous().data_ptr(), - points_first_idx.contiguous().data_ptr(), - tris.contiguous().data_ptr(), - tris_first_idx.contiguous().data_ptr(), + DistanceForwardKernel<<>>( + objects.contiguous().data_ptr(), + objects_size, + objects_dim, + targets.contiguous().data_ptr(), + targets_size, + targets_dim, + objects_first_idx.contiguous().data_ptr(), + targets_first_idx.contiguous().data_ptr(), + batch_size, dists.data_ptr(), - idxs.data_ptr(), - B, - P, - T); + idxs.data_ptr()); AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(dists, idxs); } +std::tuple PointFaceDistanceForwardCuda( + const at::Tensor& points, + const at::Tensor& points_first_idx, + const at::Tensor& tris, + const at::Tensor& tris_first_idx, + const int64_t max_points) { + return DistanceForwardCuda( + points, 1, points_first_idx, tris, 3, tris_first_idx, max_points); +} + __global__ void PointFaceBackwardKernel( const float* __restrict__ points, // (P, 3) const float* __restrict__ tris, // (T, 3, 3) @@ -265,149 +318,14 @@ std::tuple PointFaceDistanceBackwardCuda( // * FacePointDistance * // **************************************************************************** -__global__ void FacePointForwardKernel( - const float* __restrict__ points, // (P, 3) - const int64_t* __restrict__ points_first_idx, // (B,) - const float* __restrict__ tris, // (T, 3, 3) - const int64_t* __restrict__ tris_first_idx, // (B,) - float* __restrict__ dist_tris, // (T,) - int64_t* __restrict__ idx_tris, // (T,) - const size_t B, - const size_t P, - const size_t T) { - float3* points_f3 = (float3*)points; - float3* tris_f3 = (float3*)tris; - - // Single shared memory buffer which is split and cast to different types. - extern __shared__ char shared_buf[]; - float* min_dists = (float*)shared_buf; // float[NUM_THREADS] - int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS] - - const size_t batch_idx = blockIdx.y; // index of batch element. - - // start and end for points in batch_idx - const int64_t startp = points_first_idx[batch_idx]; - const int64_t endp = batch_idx + 1 < B ? points_first_idx[batch_idx + 1] : P; - - // start and end for tris in batch_idx - const int64_t startt = tris_first_idx[batch_idx]; - const int64_t endt = batch_idx + 1 < B ? tris_first_idx[batch_idx + 1] : T; - - const size_t i = blockIdx.x; // index of point within batch element. - const size_t tid = threadIdx.x; - - // Each block will compute one element of the output idx_tris[startt + i], - // dist_tris[startt + i]. Within the block we will use threads to compute - // the distances between tris[startt + i] and points[j] for all j belonging - // in the same batch as i, i.e. j in [startp, endp]. Then use a block - // reduction to take an argmin of the distances. - - // If i exceeds the number of tris in batch_idx, then do nothing - if (i < (endt - startt)) { - const float3 v0 = tris_f3[(startt + i) * 3 + 0]; - const float3 v1 = tris_f3[(startt + i) * 3 + 1]; - const float3 v2 = tris_f3[(startt + i) * 3 + 2]; - - // Compute the distances between tris[startt + i] and points[j] for - // all j belonging in the same batch as i, i.e. j in [startp, endp]. - // Here each thread will reduce over (endp-startp) / blockDim.x in serial, - // and store its result to shared memory - float min_dist = FLT_MAX; - size_t min_idx = 0; - for (size_t j = tid; j < (endp - startp); j += blockDim.x) { - // Retrieve (startp + i) point - const float3 p_f3 = points_f3[startp + j]; - - float dist = PointTriangle3DistanceForward(p_f3, v0, v1, v2); - min_dist = (j == tid) ? dist : min_dist; - min_idx = (dist <= min_dist) ? (startp + j) : min_idx; - min_dist = (dist <= min_dist) ? dist : min_dist; - } - min_dists[tid] = min_dist; - min_idxs[tid] = min_idx; - __syncthreads(); - - // Perform reduction in shared memory. - for (int s = blockDim.x / 2; s > 32; s >>= 1) { - if (tid < s) { - if (min_dists[tid] > min_dists[tid + s]) { - min_dists[tid] = min_dists[tid + s]; - min_idxs[tid] = min_idxs[tid + s]; - } - } - __syncthreads(); - } - - // Unroll the last 6 iterations of the loop since they will happen - // synchronized within a single warp. - if (tid < 32) - WarpReduce(min_dists, min_idxs, tid); - - // Finally thread 0 writes the result to the output buffer. - if (tid == 0) { - idx_tris[startt + i] = min_idxs[0]; - dist_tris[startt + i] = min_dists[0]; - } - } -} - std::tuple FacePointDistanceForwardCuda( const at::Tensor& points, const at::Tensor& points_first_idx, const at::Tensor& tris, const at::Tensor& tris_first_idx, const int64_t max_tris) { - // Check inputs are on the same device - at::TensorArg points_t{points, "points", 1}, - points_first_idx_t{points_first_idx, "points_first_idx", 2}, - tris_t{tris, "tris", 3}, - tris_first_idx_t{tris_first_idx, "tris_first_idx", 4}; - at::CheckedFrom c = "FacePointDistanceForwardCuda"; - at::checkAllSameGPU( - c, {points_t, points_first_idx_t, tris_t, tris_first_idx_t}); - at::checkAllSameType(c, {points_t, tris_t}); - - // Set the device for the kernel launch based on the device of the input - at::cuda::CUDAGuard device_guard(points.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int64_t P = points.size(0); - const int64_t T = tris.size(0); - const int64_t B = points_first_idx.size(0); - - TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); - TORCH_CHECK( - (tris.size(1) == 3) && (tris.size(2) == 3), - "tris must be of shape Tx3x3"); - TORCH_CHECK(tris_first_idx.size(0) == B); - - // clang-format off - at::Tensor dists = at::zeros({T,}, tris.options()); - at::Tensor idxs = at::zeros({T,}, tris_first_idx.options()); - // clang-format on - - if (dists.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(dists, idxs); - } - - const int threads = 128; - const dim3 blocks(max_tris, B); - size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t); - - FacePointForwardKernel<<>>( - points.contiguous().data_ptr(), - points_first_idx.contiguous().data_ptr(), - tris.contiguous().data_ptr(), - tris_first_idx.contiguous().data_ptr(), - dists.data_ptr(), - idxs.data_ptr(), - B, - P, - T); - - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(dists, idxs); + return DistanceForwardCuda( + tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris); } __global__ void FacePointBackwardKernel(