remove nearest_neighbors

Summary: knn is more general and faster than the nearest_neighbor code, so remove the latter.

Reviewed By: gkioxari

Differential Revision: D20816424

fbshipit-source-id: 75d6c44d17180752d0c9859814bbdf7892558158
This commit is contained in:
Nikhila Ravi 2020-04-15 20:49:16 -07:00 committed by Facebook GitHub Bot
parent 790eb8c402
commit 3794f6753f
6 changed files with 0 additions and 217 deletions

View File

@ -7,7 +7,6 @@
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "knn/knn.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_edge.h"
#include "point_mesh/point_mesh_face.h"
@ -21,7 +20,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("padded_to_packed", &PaddedToPacked);
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);

View File

@ -1,38 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
auto long_opts = p1.options().dtype(torch::kInt64);
torch::Tensor out = torch::empty({N, P1}, long_opts);
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto out_a = out.accessor<int64_t, 2>();
for (int n = 0; n < N; ++n) {
for (int i1 = 0; i1 < P1; ++i1) {
// TODO: support other floating-point types?
float min_dist = -1;
int64_t min_idx = -1;
for (int i2 = 0; i2 < P2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
dist += diff * diff;
}
if (min_dist == -1 || dist < min_dist) {
min_dist = dist;
min_idx = i2;
}
}
out_a[n][i1] = min_idx;
}
}
return out;
}

View File

@ -5,7 +5,6 @@ from .cubify import cubify
from .graph_conv import GraphConv
from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals
from .nearest_neighbor_points import nn_points_idx
from .packed_to_padded import packed_to_padded, padded_to_packed
from .points_alignment import corresponding_points_alignment
from .sample_points_from_meshes import sample_points_from_meshes

View File

@ -1,43 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d import _C
def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor:
"""
Compute the coordinates of nearest neighbors in pointcloud p2 to points in p1.
Args:
p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
containing P1 points of dimension D.
p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
containing P2 points of dimension D.
p2_normals: [optional] FloatTensor of shape (N, P2, D) giving
normals for p2. Default: None.
Returns:
3-element tuple containing
- **p1_nn_points**: FloatTensor of shape (N, P1, D) where
p1_neighbors[n, i] is the point in p2[n] which is
the nearest neighbor to p1[n, i].
- **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of
the neighbors.
- **p1_nn_normals**: Normal vectors for each point in p1_neighbors;
only returned if p2_normals is passed
else return [].
"""
N, P1, D = p1.shape
with torch.no_grad():
p1_nn_idx = _C.nn_points_idx(p1.contiguous(), p2.contiguous()) # (N, P1)
p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
p1_nn_points = p2.gather(1, p1_nn_idx_expanded)
if p2_normals is None:
p1_nn_normals = []
else:
if p2_normals.shape != p2.shape:
raise ValueError("p2_normals has incorrect shape.")
p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded)
return p1_nn_points, p1_nn_idx, p1_nn_normals

View File

@ -1,42 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from test_nearest_neighbor_points import TestNearestNeighborPoints
def bm_nn_points() -> None:
kwargs_list = []
N = [1, 4, 32]
D = [3, 4]
P1 = [1, 128]
P2 = [32, 128]
test_cases = product(N, D, P1, P2)
for case in test_cases:
n, d, p1, p2 = case
kwargs_list.append({"N": n, "D": d, "P1": p1, "P2": p2})
benchmark(
TestNearestNeighborPoints.bm_nn_points_python_with_init,
"NN_PYTHON",
kwargs_list,
warmup_iters=1,
)
benchmark(
TestNearestNeighborPoints.bm_nn_points_cpu_with_init,
"NN_CPU",
kwargs_list,
warmup_iters=1,
)
if torch.cuda.is_available():
benchmark(
TestNearestNeighborPoints.bm_nn_points_cuda_with_init,
"NN_CUDA",
kwargs_list,
warmup_iters=1,
)

View File

@ -1,91 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d import _C
class TestNearestNeighborPoints(unittest.TestCase):
@staticmethod
def nn_points_idx_naive(x, y):
"""
PyTorch implementation of nn_points_idx function.
"""
N, P1, D = x.shape
_N, P2, _D = y.shape
assert N == _N and D == _D
diffs = x.view(N, P1, 1, D) - y.view(N, 1, P2, D)
dists2 = (diffs * diffs).sum(3)
idx = dists2.argmin(2)
return idx
def _test_nn_helper(self, device):
for D in [3, 4]:
for N in [1, 4]:
for P1 in [1, 8, 64, 128]:
for P2 in [32, 128]:
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
# _C.nn_points_idx should dispatch
# to the cpp or cuda versions of the function
# depending on the input type.
idx1 = _C.nn_points_idx(x, y)
idx2 = TestNearestNeighborPoints.nn_points_idx_naive(x, y)
self.assertTrue(idx1.size(1) == P1)
self.assertTrue(torch.all(idx1 == idx2))
def test_nn_cuda(self):
"""
Test cuda output vs naive python implementation.
"""
device = torch.device("cuda:0")
self._test_nn_helper(device)
def test_nn_cpu(self):
"""
Test cpu output vs naive python implementation
"""
device = torch.device("cpu")
self._test_nn_helper(device)
@staticmethod
def bm_nn_points_cpu_with_init(
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
):
device = torch.device("cpu")
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
def nn_cpu():
_C.nn_points_idx(x.contiguous(), y.contiguous())
return nn_cpu
@staticmethod
def bm_nn_points_cuda_with_init(
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
):
device = torch.device("cuda:0")
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
torch.cuda.synchronize()
def nn_cpp():
_C.nn_points_idx(x.contiguous(), y.contiguous())
torch.cuda.synchronize()
return nn_cpp
@staticmethod
def bm_nn_points_python_with_init(
N: int = 4, D: int = 4, P1: int = 128, P2: int = 128
):
x = torch.randn(N, P1, D)
y = torch.randn(N, P2, D)
def nn_python():
TestNearestNeighborPoints.nn_points_idx_naive(x, y)
return nn_python