mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
remove nearest_neighbors
Summary: knn is more general and faster than the nearest_neighbor code, so remove the latter. Reviewed By: gkioxari Differential Revision: D20816424 fbshipit-source-id: 75d6c44d17180752d0c9859814bbdf7892558158
This commit is contained in:
parent
790eb8c402
commit
3794f6753f
@ -7,7 +7,6 @@
|
|||||||
#include "face_areas_normals/face_areas_normals.h"
|
#include "face_areas_normals/face_areas_normals.h"
|
||||||
#include "gather_scatter/gather_scatter.h"
|
#include "gather_scatter/gather_scatter.h"
|
||||||
#include "knn/knn.h"
|
#include "knn/knn.h"
|
||||||
#include "nearest_neighbor_points/nearest_neighbor_points.h"
|
|
||||||
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
||||||
#include "point_mesh/point_mesh_edge.h"
|
#include "point_mesh/point_mesh_edge.h"
|
||||||
#include "point_mesh/point_mesh_face.h"
|
#include "point_mesh/point_mesh_face.h"
|
||||||
@ -21,7 +20,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("padded_to_packed", &PaddedToPacked);
|
m.def("padded_to_packed", &PaddedToPacked);
|
||||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||||
m.def("knn_points_backward", &KNearestNeighborBackward);
|
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
|
||||||
m.def("gather_scatter", &gather_scatter);
|
m.def("gather_scatter", &gather_scatter);
|
||||||
m.def("rasterize_points", &RasterizePoints);
|
m.def("rasterize_points", &RasterizePoints);
|
||||||
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) {
|
|
||||||
const int N = p1.size(0);
|
|
||||||
const int P1 = p1.size(1);
|
|
||||||
const int D = p1.size(2);
|
|
||||||
const int P2 = p2.size(1);
|
|
||||||
|
|
||||||
auto long_opts = p1.options().dtype(torch::kInt64);
|
|
||||||
torch::Tensor out = torch::empty({N, P1}, long_opts);
|
|
||||||
|
|
||||||
auto p1_a = p1.accessor<float, 3>();
|
|
||||||
auto p2_a = p2.accessor<float, 3>();
|
|
||||||
auto out_a = out.accessor<int64_t, 2>();
|
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
|
||||||
for (int i1 = 0; i1 < P1; ++i1) {
|
|
||||||
// TODO: support other floating-point types?
|
|
||||||
float min_dist = -1;
|
|
||||||
int64_t min_idx = -1;
|
|
||||||
for (int i2 = 0; i2 < P2; ++i2) {
|
|
||||||
float dist = 0;
|
|
||||||
for (int d = 0; d < D; ++d) {
|
|
||||||
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
|
||||||
dist += diff * diff;
|
|
||||||
}
|
|
||||||
if (min_dist == -1 || dist < min_dist) {
|
|
||||||
min_dist = dist;
|
|
||||||
min_idx = i2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out_a[n][i1] = min_idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
@ -5,7 +5,6 @@ from .cubify import cubify
|
|||||||
from .graph_conv import GraphConv
|
from .graph_conv import GraphConv
|
||||||
from .knn import knn_gather, knn_points
|
from .knn import knn_gather, knn_points
|
||||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
from .nearest_neighbor_points import nn_points_idx
|
|
||||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||||
from .points_alignment import corresponding_points_alignment
|
from .points_alignment import corresponding_points_alignment
|
||||||
from .sample_points_from_meshes import sample_points_from_meshes
|
from .sample_points_from_meshes import sample_points_from_meshes
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pytorch3d import _C
|
|
||||||
|
|
||||||
|
|
||||||
def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute the coordinates of nearest neighbors in pointcloud p2 to points in p1.
|
|
||||||
Args:
|
|
||||||
p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
|
||||||
containing P1 points of dimension D.
|
|
||||||
p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
|
||||||
containing P2 points of dimension D.
|
|
||||||
p2_normals: [optional] FloatTensor of shape (N, P2, D) giving
|
|
||||||
normals for p2. Default: None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
3-element tuple containing
|
|
||||||
|
|
||||||
- **p1_nn_points**: FloatTensor of shape (N, P1, D) where
|
|
||||||
p1_neighbors[n, i] is the point in p2[n] which is
|
|
||||||
the nearest neighbor to p1[n, i].
|
|
||||||
- **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of
|
|
||||||
the neighbors.
|
|
||||||
- **p1_nn_normals**: Normal vectors for each point in p1_neighbors;
|
|
||||||
only returned if p2_normals is passed
|
|
||||||
else return [].
|
|
||||||
"""
|
|
||||||
N, P1, D = p1.shape
|
|
||||||
with torch.no_grad():
|
|
||||||
p1_nn_idx = _C.nn_points_idx(p1.contiguous(), p2.contiguous()) # (N, P1)
|
|
||||||
p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
|
|
||||||
p1_nn_points = p2.gather(1, p1_nn_idx_expanded)
|
|
||||||
if p2_normals is None:
|
|
||||||
p1_nn_normals = []
|
|
||||||
else:
|
|
||||||
if p2_normals.shape != p2.shape:
|
|
||||||
raise ValueError("p2_normals has incorrect shape.")
|
|
||||||
p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded)
|
|
||||||
|
|
||||||
return p1_nn_points, p1_nn_idx, p1_nn_normals
|
|
@ -1,42 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
||||||
|
|
||||||
from itertools import product
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from fvcore.common.benchmark import benchmark
|
|
||||||
from test_nearest_neighbor_points import TestNearestNeighborPoints
|
|
||||||
|
|
||||||
|
|
||||||
def bm_nn_points() -> None:
|
|
||||||
kwargs_list = []
|
|
||||||
|
|
||||||
N = [1, 4, 32]
|
|
||||||
D = [3, 4]
|
|
||||||
P1 = [1, 128]
|
|
||||||
P2 = [32, 128]
|
|
||||||
test_cases = product(N, D, P1, P2)
|
|
||||||
for case in test_cases:
|
|
||||||
n, d, p1, p2 = case
|
|
||||||
kwargs_list.append({"N": n, "D": d, "P1": p1, "P2": p2})
|
|
||||||
|
|
||||||
benchmark(
|
|
||||||
TestNearestNeighborPoints.bm_nn_points_python_with_init,
|
|
||||||
"NN_PYTHON",
|
|
||||||
kwargs_list,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
benchmark(
|
|
||||||
TestNearestNeighborPoints.bm_nn_points_cpu_with_init,
|
|
||||||
"NN_CPU",
|
|
||||||
kwargs_list,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
benchmark(
|
|
||||||
TestNearestNeighborPoints.bm_nn_points_cuda_with_init,
|
|
||||||
"NN_CUDA",
|
|
||||||
kwargs_list,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
@ -1,91 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pytorch3d import _C
|
|
||||||
|
|
||||||
|
|
||||||
class TestNearestNeighborPoints(unittest.TestCase):
|
|
||||||
@staticmethod
|
|
||||||
def nn_points_idx_naive(x, y):
|
|
||||||
"""
|
|
||||||
PyTorch implementation of nn_points_idx function.
|
|
||||||
"""
|
|
||||||
N, P1, D = x.shape
|
|
||||||
_N, P2, _D = y.shape
|
|
||||||
assert N == _N and D == _D
|
|
||||||
diffs = x.view(N, P1, 1, D) - y.view(N, 1, P2, D)
|
|
||||||
dists2 = (diffs * diffs).sum(3)
|
|
||||||
idx = dists2.argmin(2)
|
|
||||||
return idx
|
|
||||||
|
|
||||||
def _test_nn_helper(self, device):
|
|
||||||
for D in [3, 4]:
|
|
||||||
for N in [1, 4]:
|
|
||||||
for P1 in [1, 8, 64, 128]:
|
|
||||||
for P2 in [32, 128]:
|
|
||||||
x = torch.randn(N, P1, D, device=device)
|
|
||||||
y = torch.randn(N, P2, D, device=device)
|
|
||||||
|
|
||||||
# _C.nn_points_idx should dispatch
|
|
||||||
# to the cpp or cuda versions of the function
|
|
||||||
# depending on the input type.
|
|
||||||
idx1 = _C.nn_points_idx(x, y)
|
|
||||||
idx2 = TestNearestNeighborPoints.nn_points_idx_naive(x, y)
|
|
||||||
self.assertTrue(idx1.size(1) == P1)
|
|
||||||
self.assertTrue(torch.all(idx1 == idx2))
|
|
||||||
|
|
||||||
def test_nn_cuda(self):
|
|
||||||
"""
|
|
||||||
Test cuda output vs naive python implementation.
|
|
||||||
"""
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
self._test_nn_helper(device)
|
|
||||||
|
|
||||||
def test_nn_cpu(self):
|
|
||||||
"""
|
|
||||||
Test cpu output vs naive python implementation
|
|
||||||
"""
|
|
||||||
device = torch.device("cpu")
|
|
||||||
self._test_nn_helper(device)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def bm_nn_points_cpu_with_init(
|
|
||||||
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
|
|
||||||
):
|
|
||||||
device = torch.device("cpu")
|
|
||||||
x = torch.randn(N, P1, D, device=device)
|
|
||||||
y = torch.randn(N, P2, D, device=device)
|
|
||||||
|
|
||||||
def nn_cpu():
|
|
||||||
_C.nn_points_idx(x.contiguous(), y.contiguous())
|
|
||||||
|
|
||||||
return nn_cpu
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def bm_nn_points_cuda_with_init(
|
|
||||||
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
|
|
||||||
):
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
x = torch.randn(N, P1, D, device=device)
|
|
||||||
y = torch.randn(N, P2, D, device=device)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
def nn_cpp():
|
|
||||||
_C.nn_points_idx(x.contiguous(), y.contiguous())
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return nn_cpp
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def bm_nn_points_python_with_init(
|
|
||||||
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
|
|
||||||
):
|
|
||||||
x = torch.randn(N, P1, D)
|
|
||||||
y = torch.randn(N, P2, D)
|
|
||||||
|
|
||||||
def nn_python():
|
|
||||||
TestNearestNeighborPoints.nn_points_idx_naive(x, y)
|
|
||||||
|
|
||||||
return nn_python
|
|
Loading…
x
Reference in New Issue
Block a user