mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add CPU implementation for nearest neighbor
Summary: Adds a CPU implementation for `pytorch3d.ops.nn_points_idx`. Also renames the associated C++ and CUDA functions to follow the `AllCaps` naming used in other C++ / CUDA code.
Reviewed By: gkioxari
Differential Revision: D19670491
fbshipit-source-id: 1b6409404025bf05e6a93f5d847e35afc9062f05
This commit is contained in:
parent
25c2f34096
commit
e290f87ca9
@ -11,7 +11,7 @@
|
|||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("face_areas_normals", &face_areas_normals);
|
m.def("face_areas_normals", &face_areas_normals);
|
||||||
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
|
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
|
||||||
m.def("nn_points_idx", &nn_points_idx);
|
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||||
m.def("gather_scatter", &gather_scatter);
|
m.def("gather_scatter", &gather_scatter);
|
||||||
m.def("rasterize_points", &RasterizePoints);
|
m.def("rasterize_points", &RasterizePoints);
|
||||||
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
#include <float.h>
|
#include <float.h>
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__device__ void warp_reduce(
|
__device__ void WarpReduce(
|
||||||
volatile scalar_t* min_dists,
|
volatile scalar_t* min_dists,
|
||||||
volatile long* min_idxs,
|
volatile long* min_idxs,
|
||||||
const size_t tid) {
|
const size_t tid) {
|
||||||
@ -54,7 +54,7 @@ __device__ void warp_reduce(
|
|||||||
// is aligned.
|
// is aligned.
|
||||||
//
|
//
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void nearest_neighbor_kernel(
|
__global__ void NearestNeighborKernel(
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
long* __restrict__ idx,
|
long* __restrict__ idx,
|
||||||
@ -123,7 +123,7 @@ __global__ void nearest_neighbor_kernel(
|
|||||||
// Unroll the last 6 iterations of the loop since they will happen
|
// Unroll the last 6 iterations of the loop since they will happen
|
||||||
// synchronized within a single warp.
|
// synchronized within a single warp.
|
||||||
if (tid < 32)
|
if (tid < 32)
|
||||||
warp_reduce<scalar_t>(min_dists, min_idxs, tid);
|
WarpReduce<scalar_t>(min_dists, min_idxs, tid);
|
||||||
|
|
||||||
// Finally thread 0 writes the result to the output buffer.
|
// Finally thread 0 writes the result to the output buffer.
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
@ -144,7 +144,7 @@ __global__ void nearest_neighbor_kernel(
|
|||||||
// P2: Number of points in points2.
|
// P2: Number of points in points2.
|
||||||
//
|
//
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void nearest_neighbor_kernel_D3(
|
__global__ void NearestNeighborKernelD3(
|
||||||
const scalar_t* __restrict__ points1,
|
const scalar_t* __restrict__ points1,
|
||||||
const scalar_t* __restrict__ points2,
|
const scalar_t* __restrict__ points2,
|
||||||
long* __restrict__ idx,
|
long* __restrict__ idx,
|
||||||
@ -204,7 +204,7 @@ __global__ void nearest_neighbor_kernel_D3(
|
|||||||
// Unroll the last 6 iterations of the loop since they will happen
|
// Unroll the last 6 iterations of the loop since they will happen
|
||||||
// synchronized within a single warp.
|
// synchronized within a single warp.
|
||||||
if (tid < 32)
|
if (tid < 32)
|
||||||
warp_reduce<scalar_t>(min_dists, min_idxs, tid);
|
WarpReduce<scalar_t>(min_dists, min_idxs, tid);
|
||||||
|
|
||||||
// Finally thread 0 writes the result to the output buffer.
|
// Finally thread 0 writes the result to the output buffer.
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
@ -212,7 +212,7 @@ __global__ void nearest_neighbor_kernel_D3(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
|
at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
|
||||||
const auto N = p1.size(0);
|
const auto N = p1.size(0);
|
||||||
const auto P1 = p1.size(1);
|
const auto P1 = p1.size(1);
|
||||||
const auto P2 = p2.size(1);
|
const auto P2 = p2.size(1);
|
||||||
@ -231,7 +231,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
|
|||||||
AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
|
||||||
size_t shared_size = threads * sizeof(size_t) +
|
size_t shared_size = threads * sizeof(size_t) +
|
||||||
threads * sizeof(long);
|
threads * sizeof(long);
|
||||||
nearest_neighbor_kernel_D3<scalar_t>
|
NearestNeighborKernelD3<scalar_t>
|
||||||
<<<blocks, threads, shared_size>>>(
|
<<<blocks, threads, shared_size>>>(
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
p2.data_ptr<scalar_t>(),
|
p2.data_ptr<scalar_t>(),
|
||||||
@ -249,7 +249,7 @@ at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
|
|||||||
size_t D_2 = D + (D % 2);
|
size_t D_2 = D + (D % 2);
|
||||||
size_t shared_size = (D_2 + threads) * sizeof(size_t);
|
size_t shared_size = (D_2 + threads) * sizeof(size_t);
|
||||||
shared_size += threads * sizeof(long);
|
shared_size += threads * sizeof(long);
|
||||||
nearest_neighbor_kernel<scalar_t><<<blocks, threads, shared_size>>>(
|
NearestNeighborKernel<scalar_t><<<blocks, threads, shared_size>>>(
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
p2.data_ptr<scalar_t>(),
|
p2.data_ptr<scalar_t>(),
|
||||||
idx.data_ptr<long>(),
|
idx.data_ptr<long>(),
|
||||||
|
@ -19,19 +19,22 @@
|
|||||||
// to p1[n, i] in the cloud p2[n] is p2[n, j].
|
// to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// CPU implementation.
|
||||||
|
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2);
|
||||||
|
|
||||||
// Cuda implementation.
|
// Cuda implementation.
|
||||||
at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2);
|
at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2);
|
||||||
|
|
||||||
// Implementation which is exposed.
|
// Implementation which is exposed.
|
||||||
at::Tensor nn_points_idx(at::Tensor p1, at::Tensor p2) {
|
at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {
|
||||||
if (p1.type().is_cuda() && p2.type().is_cuda()) {
|
if (p1.type().is_cuda() && p2.type().is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
CHECK_CONTIGUOUS_CUDA(p1);
|
CHECK_CONTIGUOUS_CUDA(p1);
|
||||||
CHECK_CONTIGUOUS_CUDA(p2);
|
CHECK_CONTIGUOUS_CUDA(p2);
|
||||||
return nn_points_idx_cuda(p1, p2);
|
return NearestNeighborIdxCuda(p1, p2);
|
||||||
#else
|
#else
|
||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
AT_ERROR("Not implemented on the CPU.");
|
return NearestNeighborIdxCpu(p1, p2);
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
// Brute-force nearest neighbor search on the CPU.
//
// For every point p1[n][i] this scans all points p2[n][j] and records the
// index j with the smallest squared Euclidean distance.
//
// Args:
//    p1: CPU tensor with three dimensions, indexed as p1[n][i][d].
//    p2: CPU tensor with three dimensions, indexed as p2[n][j][d]. It must
//        have the same floating-point dtype as p1 and the same sizes along
//        dimensions 0 and 2.
//
// Returns:
//    int64 tensor `out` of shape (p1.size(0), p1.size(1)) where out[n][i] is
//    the index of the point in p2[n] closest to p1[n][i]. On ties the lowest
//    index wins. If p2.size(1) is 0, every entry is -1.
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) {
  // The accessors below read host memory directly, so reject device tensors
  // instead of dereferencing a device pointer.
  TORCH_CHECK(
      !p1.is_cuda() && !p2.is_cuda(),
      "NearestNeighborIdxCpu expects both p1 and p2 to be CPU tensors.");
  TORCH_CHECK(
      p1.dim() == 3 && p2.dim() == 3,
      "p1 and p2 must both have 3 dimensions.");
  TORCH_CHECK(
      p1.scalar_type() == p2.scalar_type(),
      "p1 and p2 must have the same dtype.");
  // Without this check a smaller p2 would be read out of bounds.
  TORCH_CHECK(
      p1.size(0) == p2.size(0) && p1.size(2) == p2.size(2),
      "p1 and p2 must agree on batch size (dim 0) and point dimension (dim 2).");

  // Tensor sizes are 64-bit; keep them that way to avoid silent truncation.
  const int64_t N = p1.size(0);
  const int64_t P1 = p1.size(1);
  const int64_t D = p1.size(2);
  const int64_t P2 = p2.size(1);

  auto long_opts = p1.options().dtype(torch::kInt64);
  torch::Tensor out = torch::empty({N, P1}, long_opts);
  auto out_a = out.accessor<int64_t, 2>();

  // Dispatch over float and double so the CPU path accepts the same dtypes
  // as the CUDA path.
  AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "NearestNeighborIdxCpu", ([&] {
    auto p1_a = p1.accessor<scalar_t, 3>();
    auto p2_a = p2.accessor<scalar_t, 3>();

    for (int64_t n = 0; n < N; ++n) {
      for (int64_t i1 = 0; i1 < P1; ++i1) {
        scalar_t min_dist = 0;
        // min_idx < 0 means "no candidate seen yet".
        int64_t min_idx = -1;
        for (int64_t i2 = 0; i2 < P2; ++i2) {
          scalar_t dist = 0;
          for (int64_t d = 0; d < D; ++d) {
            const scalar_t diff = p1_a[n][i1][d] - p2_a[n][i2][d];
            dist += diff * diff;
          }
          // Strict '<' keeps the first (lowest) index on ties.
          if (min_idx < 0 || dist < min_dist) {
            min_dist = dist;
            min_idx = i2;
          }
        }
        out_a[n][i1] = min_idx;
      }
    }
  }));

  return out;
}
|
@ -27,6 +27,13 @@ def bm_nn_points() -> None:
|
|||||||
warmup_iters=1,
|
warmup_iters=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
benchmark(
|
||||||
|
TestNearestNeighborPoints.bm_nn_points_cpu_with_init,
|
||||||
|
"NN_CPU",
|
||||||
|
kwargs_list,
|
||||||
|
warmup_iters=1,
|
||||||
|
)
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
benchmark(
|
benchmark(
|
||||||
TestNearestNeighborPoints.bm_nn_points_cuda_with_init,
|
TestNearestNeighborPoints.bm_nn_points_cuda_with_init,
|
||||||
|
@ -21,11 +21,7 @@ class TestNearestNeighborPoints(unittest.TestCase):
|
|||||||
idx = dists2.argmin(2)
|
idx = dists2.argmin(2)
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
def test_nn_cuda(self):
|
def _test_nn_helper(self, device):
|
||||||
"""
|
|
||||||
Test cuda output vs naive python implementation.
|
|
||||||
"""
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
for D in [3, 4]:
|
for D in [3, 4]:
|
||||||
for N in [1, 4]:
|
for N in [1, 4]:
|
||||||
for P1 in [1, 8, 64, 128]:
|
for P1 in [1, 8, 64, 128]:
|
||||||
@ -43,16 +39,32 @@ class TestNearestNeighborPoints(unittest.TestCase):
|
|||||||
self.assertTrue(idx1.size(1) == P1)
|
self.assertTrue(idx1.size(1) == P1)
|
||||||
self.assertTrue(torch.all(idx1 == idx2))
|
self.assertTrue(torch.all(idx1 == idx2))
|
||||||
|
|
||||||
def test_nn_cuda_error(self):
|
def test_nn_cuda(self):
|
||||||
"""
|
"""
|
||||||
Check that nn_points_idx throws an error if cpu tensors
|
Test cuda output vs naive python implementation.
|
||||||
are given as input.
|
|
||||||
"""
|
"""
|
||||||
x = torch.randn(1, 1, 3)
|
device = torch.device('cuda:0')
|
||||||
y = torch.randn(1, 1, 3)
|
self._test_nn_helper(device)
|
||||||
with self.assertRaises(Exception) as err:
|
|
||||||
_C.nn_points_idx(x, y)
|
def test_nn_cpu(self):
|
||||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
"""
|
||||||
|
Test cpu output vs naive python implementation
|
||||||
|
"""
|
||||||
|
device = torch.device('cpu')
|
||||||
|
self._test_nn_helper(device)
|
||||||
|
|
||||||
|
@staticmethod
def bm_nn_points_cpu_with_init(
    N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
):
    """
    Build a zero-argument callable for benchmarking nn_points_idx on the CPU.

    Two random point clouds of shape (N, P1, D) and (N, P2, D) are created
    once here; the returned callable only runs the nearest neighbor lookup
    on them.
    """
    cpu = torch.device('cpu')
    # Keep the creation order: the first cloud is sampled before the second.
    cloud1 = torch.randn(N, P1, D, device=cpu)
    cloud2 = torch.randn(N, P2, D, device=cpu)

    def nn_cpu():
        first = cloud1.contiguous()
        second = cloud2.contiguous()
        _C.nn_points_idx(first, second)

    return nn_cpu
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bm_nn_points_cuda_with_init(
|
def bm_nn_points_cuda_with_init(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user