From 9f31a4fd46e54d9eb87fd537bd648ee4005e2a01 Mon Sep 17 00:00:00 2001 From: Justin Johnson Date: Wed, 22 Apr 2020 14:28:19 -0700 Subject: [PATCH] Expose knn_check_version in python Summary: We have multiple KNN CUDA implementations. From python, users can currently request a particular implementation via the `version` flag, but they have no way of knowing which implementations can be used for a given problem. This diff exposes a function `pytorch3d._C.knn_check_version(version, D, K)` that returns whether a particular version can be used. Reviewed By: nikhilaravi Differential Revision: D21162573 fbshipit-source-id: 6061960bdcecba454fd920b00036f4e9ff3fdbc0 --- pytorch3d/csrc/ext.cpp | 3 +++ pytorch3d/csrc/knn/knn.cu | 8 ++++---- pytorch3d/csrc/knn/knn.h | 12 ++++++++++++ tests/test_knn.py | 20 ++++++++++++++++++++ 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 2502eff8..21d32fb7 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -18,6 +18,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals_backward", &FaceAreasNormalsBackward); m.def("packed_to_padded", &PackedToPadded); m.def("padded_to_packed", &PaddedToPacked); +#ifdef WITH_CUDA + m.def("knn_check_version", &KnnCheckVersion); +#endif m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_backward", &KNearestNeighborBackward); m.def("gather_scatter", &gather_scatter); diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu index 20250d7d..b045788c 100644 --- a/pytorch3d/csrc/knn/knn.cu +++ b/pytorch3d/csrc/knn/knn.cu @@ -267,7 +267,7 @@ bool InBounds(const int64_t min, const int64_t x, const int64_t max) { return min <= x && x <= max; } -bool CheckVersion(int version, const int64_t D, const int64_t K) { +bool KnnCheckVersion(int version, const int64_t D, const int64_t K) { if (version == 0) { return true; } else if (version == 1) { @@ -282,7 +282,7 @@ bool CheckVersion(int version, const int64_t D, const int64_t K) { int ChooseVersion(const int64_t D, const int64_t K) { for (int version = 3; version >= 1; version--) { - if (CheckVersion(version, D, K)) { + if (KnnCheckVersion(version, D, K)) { return version; } } @@ -309,7 +309,7 @@ std::tuple KNearestNeighborIdxCuda( if (version < 0) { version = ChooseVersion(D, K); - } else if (!CheckVersion(version, D, K)) { + } else if (!KnnCheckVersion(version, D, K)) { int new_version = ChooseVersion(D, K); std::cout << "WARNING: Requested KNN version " << version << " is not compatible with D = " << D << "; K = " << K @@ -321,7 +321,7 @@ std::tuple KNearestNeighborIdxCuda( // gave us. But we can check once more to be sure; however this time // assert fail since failing at this point means we have a bug in our version // selection or checking code. - AT_ASSERTM(CheckVersion(version, D, K), "Invalid version"); + AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version"); const size_t threads = 256; const size_t blocks = 256; diff --git a/pytorch3d/csrc/knn/knn.h b/pytorch3d/csrc/knn/knn.h index 321da6cd..9a4b42f6 100644 --- a/pytorch3d/csrc/knn/knn.h +++ b/pytorch3d/csrc/knn/knn.h @@ -128,3 +128,15 @@ std::tuple KNearestNeighborBackward( return KNearestNeighborBackwardCpu( p1, p2, lengths1, lengths2, idxs, grad_dists); } + +// Utility to check whether a KNN version can be used. +// +// Args: +// version: Integer in the range 0 <= version <= 3 indicating one of our +// KNN implementations. +// D: Number of dimensions for the input and query point clouds +// K: Number of neighbors to be found +// +// Returns: +// Whether the indicated KNN version can be used. +bool KnnCheckVersion(int version, const int64_t D, const int64_t K); diff --git a/tests/test_knn.py b/tests/test_knn.py index d39df6f0..112d1cc9 100644 --- a/tests/test_knn.py +++ b/tests/test_knn.py @@ -155,6 +155,26 @@ class TestKNN(TestCaseMixin, unittest.TestCase): else: self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0)) + def test_knn_check_version(self): + try: + from pytorch3d._C import knn_check_version + except ImportError: + # knn_check_version will only be defined if we compiled with CUDA support + return + for D in range(-10, 10): + for K in range(-10, 20): + v0 = True + v1 = 1 <= D <= 32 + v2 = 1 <= D <= 8 and 1 <= K <= 32 + v3 = 1 <= D <= 8 and 1 <= K <= 4 + all_expected = [v0, v1, v2, v3] + for version in range(-10, 10): + actual = knn_check_version(version, D, K) + expected = False + if 0 <= version < len(all_expected): + expected = all_expected[version] + self.assertEqual(actual, expected) + @staticmethod def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str): device = torch.device(device)