mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Expose knn_check_version in python
Summary: We have multiple KNN CUDA implementations. From python, users can currently request a particular implementation via the `version` flag, but they have no way of knowing which implementations can be used for a given problem. This diff exposes a function `pytorch3d._C.knn_check_version(version, D, K)` that returns whether a particular version can be used. Reviewed By: nikhilaravi Differential Revision: D21162573 fbshipit-source-id: 6061960bdcecba454fd920b00036f4e9ff3fdbc0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e38cbfe5e3
commit
9f31a4fd46
@@ -155,6 +155,26 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
||||
else:
|
||||
self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))
|
||||
|
||||
def test_knn_check_version(self):
|
||||
try:
|
||||
from pytorch3d._C import knn_check_version
|
||||
except ImportError:
|
||||
# knn_check_version will only be defined if we compiled with CUDA support
|
||||
return
|
||||
for D in range(-10, 10):
|
||||
for K in range(-10, 20):
|
||||
v0 = True
|
||||
v1 = 1 <= D <= 32
|
||||
v2 = 1 <= D <= 8 and 1 <= K <= 32
|
||||
v3 = 1 <= D <= 8 and 1 <= K <= 4
|
||||
all_expected = [v0, v1, v2, v3]
|
||||
for version in range(-10, 10):
|
||||
actual = knn_check_version(version, D, K)
|
||||
expected = False
|
||||
if 0 <= version < len(all_expected):
|
||||
expected = all_expected[version]
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@staticmethod
|
||||
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
|
||||
device = torch.device(device)
|
||||
|
||||
Reference in New Issue
Block a user