mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-26 08:06:00 +08:00
Summary: ### Overview The current C++ code for `pytorch3d.ops.ball_query()` performs floating point multiplication for every coordinate of every pair of points (up until the maximum number of neighbor points is reached). This PR modifies the code (for both CPU and CUDA versions) to implement idea presented [here](https://stackoverflow.com/a/3939525): a `D`-cube around the `D`-ball is first constructed, and any point pairs falling outside the cube are skipped, without explicitly computing the squared distances. This change is especially useful for when the dimension `D` and the number of points `P2` are large and the radius is much smaller than the overall volume of space occupied by the point clouds; as much as **~2.5x speedup** (CPU case; ~1.8x speedup in CUDA case) is observed when `D = 10` and `radius = 0.01`. In all benchmark cases, points were uniform randomly distributed inside a unit `D`-cube. The benchmark code used was different from `tests/benchmarks/bm_ball_query.py` (only the forward part is benchmarked, larger input sizes were used) and is stored in `tests/benchmarks/bm_ball_query_large.py`. 
### Average time comparisons <img width="360" height="270" alt="cpu-03-0 01-avg" src="https://github.com/user-attachments/assets/6cc79893-7921-44af-9366-1766c3caf142" /> <img width="360" height="270" alt="cuda-03-0 01-avg" src="https://github.com/user-attachments/assets/5151647d-0273-40a3-aac6-8b9399ede18a" /> <img width="360" height="270" alt="cpu-03-0 10-avg" src="https://github.com/user-attachments/assets/a87bc150-a5eb-47cd-a4ba-83c2ec81edaf" /> <img width="360" height="270" alt="cuda-03-0 10-avg" src="https://github.com/user-attachments/assets/e3699a9f-dfd3-4dd3-b3c9-619296186d43" /> <img width="360" height="270" alt="cpu-10-0 01-avg" src="https://github.com/user-attachments/assets/5ec8c32d-8e4d-4ced-a94e-1b816b1cb0f8" /> <img width="360" height="270" alt="cuda-10-0 01-avg" src="https://github.com/user-attachments/assets/168a3dfc-777a-4fb3-8023-1ac8c13985b8" /> <img width="360" height="270" alt="cpu-10-0 10-avg" src="https://github.com/user-attachments/assets/43a57fd6-1e01-4c5e-87a9-8ef604ef5fa0" /> <img width="360" height="270" alt="cuda-10-0 10-avg" src="https://github.com/user-attachments/assets/a7c7cc69-f273-493e-95b8-3ba2bb2e32da" /> ### Peak time comparisons <img width="360" height="270" alt="cpu-03-0 01-peak" src="https://github.com/user-attachments/assets/5bbbea3f-ef9b-490d-ab0d-ce551711d74f" /> <img width="360" height="270" alt="cuda-03-0 01-peak" src="https://github.com/user-attachments/assets/30b5ab9b-45cb-4057-b69f-bda6e76bd1dc" /> <img width="360" height="270" alt="cpu-03-0 10-peak" src="https://github.com/user-attachments/assets/db69c333-e5ac-4305-8a86-a26a8a9fe80d" /> <img width="360" height="270" alt="cuda-03-0 10-peak" src="https://github.com/user-attachments/assets/82549656-1f12-409e-8160-dd4c4c9d14f7" /> <img width="360" height="270" alt="cpu-10-0 01-peak" src="https://github.com/user-attachments/assets/d0be8ef1-535e-47bc-b773-b87fad625bf0" /> <img width="360" height="270" alt="cuda-10-0 01-peak" 
src="https://github.com/user-attachments/assets/e308e66e-ae30-400f-8ad2-015517f6e1af" /> <img width="360" height="270" alt="cpu-10-0 10-peak" src="https://github.com/user-attachments/assets/c9b5bf59-9cc2-465c-ad5d-d4e23bdd138a" /> <img width="360" height="270" alt="cuda-10-0 10-peak" src="https://github.com/user-attachments/assets/311354d4-b488-400c-a1dc-c85a21917aa9" /> ### Full benchmark logs [benchmark-before-change.txt](https://github.com/user-attachments/files/22978300/benchmark-before-change.txt) [benchmark-after-change.txt](https://github.com/user-attachments/files/22978299/benchmark-after-change.txt) Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2006 Reviewed By: shapovalov Differential Revision: D85356394 Pulled By: bottler fbshipit-source-id: 9b3ce5fc87bb73d4323cc5b4190fc38ae42f41b2
147 lines
5.1 KiB
CUDA C++
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD-style license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <math.h>
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
|
|
// Work scheduling:
// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
// call (1+(P1-1)/blocksize) chunks_per_cloud
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).

// Ball-query kernel: for each valid point i of p1[n], scan p2[n] in order and
// record up to K neighbors whose squared distance is strictly less than
// radius2, writing their indices into idxs[n][i][:] and squared distances
// into dists[n][i][:]. Unfilled slots keep whatever the host pre-filled
// (the launcher initializes idxs to -1 and dists to 0).
//
// Args:
//   p1, p2: (N, P1, D) and (N, P2, D) point clouds (same scalar type).
//   lengths1, lengths2: (N,) actual number of valid points per cloud, used to
//     support heterogeneous batches.
//   idxs: (N, P1, K) output neighbor indices (int64).
//   dists: (N, P1, K) output squared distances.
//   K: maximum neighbors to collect per query point.
//   radius: ball radius; radius2 is its square (precomputed on the host so the
//     kernel avoids recomputing it per thread).
//   skip_points_outside_cube: when true, first reject any pair whose
//     axis-aligned bounding-cube test fails (some |coordinate diff| > radius),
//     skipping the full squared-distance accumulation for far-away points.
template <typename scalar_t>
__global__ void BallQueryKernel(
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths1,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths2,
    at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
    const int64_t K,
    const float radius,
    const float radius2,
    const bool skip_points_outside_cube) {
  const int64_t N = p1.size(0);
  const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
  const int64_t chunks_to_do = N * chunks_per_cloud;
  const int D = p1.size(2);

  // Grid-stride loop over chunks: block b handles chunks b, b+gridDim.x, ...
  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
    const int64_t n = chunk / chunks_per_cloud; // batch_index
    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
    int64_t i = start_point + threadIdx.x;

    // Check if point is valid in heterogeneous tensor
    if (i >= lengths1[n]) {
      continue;
    }

    // Iterate over points in p2 until desired count is reached or
    // all points have been considered
    for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
      if (skip_points_outside_cube) {
        bool is_within_radius = true;
        // Filter when any one coordinate is already outside the radius:
        // if the pair fails the D-cube test it cannot be inside the D-ball,
        // so the full distance computation below can be skipped.
        for (int d = 0; is_within_radius && d < D; ++d) {
          scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
          is_within_radius = (abs_diff <= radius);
        }
        if (!is_within_radius) {
          continue;
        }
      }

      // Else, calculate the distance between the points and compare
      scalar_t dist2 = 0.0;
      for (int d = 0; d < D; ++d) {
        scalar_t diff = p1[n][i][d] - p2[n][j][d];
        dist2 += (diff * diff);
      }

      if (dist2 < radius2) {
        // If the point is within the radius
        // Set the value of the index to the point index
        idxs[n][i][count] = j;
        dists[n][i][count] = dist2;

        // increment the number of selected samples for the point i
        ++count;
      }
    }
  }
}
|
|
|
|
// Host launcher for BallQueryKernel.
//
// For every point in p1, finds up to K points of p2 within `radius`
// (strictly closer than radius in squared distance) and returns:
//   idxs:  (N, P1, K) int64 neighbor indices into p2, -1 for unfilled slots.
//   dists: (N, P1, K) squared distances, 0 for unfilled slots.
//
// lengths1/lengths2 give the valid point count per cloud so heterogeneous
// batches are supported. `skip_points_outside_cube` enables the bounding-cube
// pre-filter in the kernel.
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
    const at::Tensor& p1, // (N, P1, 3)
    const at::Tensor& p2, // (N, P2, 3)
    const at::Tensor& lengths1, // (N,)
    const at::Tensor& lengths2, // (N,)
    int K,
    float radius,
    bool skip_points_outside_cube) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
  at::CheckedFrom c = "BallQueryCuda";
  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
  at::checkAllSameType(c, {p1_t, p2_t});

  // Set the device for the kernel launch based on the device of p1
  at::cuda::CUDAGuard device_guard(p1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(
      p2.size(2) == p1.size(2), "Point sets must have the same last dimension");

  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int64_t K_64 = K;
  // Precompute the squared radius once on the host so the kernel compares
  // squared distances directly.
  const float radius2 = radius * radius;

  // Output tensor with indices of neighbors for each point in p1;
  // -1 marks slots that receive no neighbor.
  auto long_dtype = lengths1.options().dtype(at::kLong);
  auto idxs = at::full({N, P1, K}, -1, long_dtype);
  auto dists = at::zeros({N, P1, K}, p1.options());

  // Nothing to do for empty inputs; skip the kernel launch entirely.
  if (idxs.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(idxs, dists);
  }

  const size_t blocks = 256;
  const size_t threads = 256;

  AT_DISPATCH_FLOATING_TYPES(
      p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
        // BUGFIX: the accessors must be templated on scalar_t, not float.
        // Hard-coding float made the double-precision dispatch arm fail at
        // runtime (accessor dtype check) and defeated the dispatch macro.
        BallQueryKernel<<<blocks, threads, 0, stream>>>(
            p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
            dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            K_64,
            radius,
            radius2,
            skip_points_outside_cube);
      }));

  // Surface any launch-configuration error from the asynchronous launch.
  AT_CUDA_CHECK(cudaGetLastError());

  return std::make_tuple(idxs, dists);
}
|