pytorch3d/tests/test_knn.py
Justin Johnson 9f31a4fd46 Expose knn_check_version in python
Summary:
We have multiple KNN CUDA implementations. From python, users can currently request a particular implementation via the `version` flag, but they have no way of knowing which implementations can be used for a given problem.

This diff exposes a function `pytorch3d._C.knn_check_version(version, D, K)` that returns whether a particular version can be used.

Reviewed By: nikhilaravi

Differential Revision: D21162573

fbshipit-source-id: 6061960bdcecba454fd920b00036f4e9ff3fdbc0
2020-04-22 14:30:52 -07:00

211 lines
7.9 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
class TestKNN(TestCaseMixin, unittest.TestCase):
    """
    Tests for `knn_points` and `knn_gather`.

    The compiled implementations are compared against a naive PyTorch
    reference (`_knn_points_naive`). The `knn_square` / `knn_ragged` static
    methods build closures that run one forward + backward pass and are meant
    to be driven by an external benchmark harness.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the randomly generated point clouds are reproducible.
        torch.manual_seed(1)

    @staticmethod
    def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> _KNN:
        """
        Naive PyTorch implementation of K-Nearest Neighbors.
        Returns always sorted results.

        Args:
            p1: Tensor of shape (N, P1, D) holding the query points.
            p2: Tensor of shape (N, P2, D) holding the points to search in.
            lengths1: Tensor of shape (N,) giving the number of valid points
                in each cloud of p1, or None to treat all P1 points as valid.
            lengths2: Same as lengths1, but for p2.
            K: Number of neighbors to return for each query point.

        Returns:
            A `_KNN` tuple with
            - dists: float32 tensor of shape (N, P1, K) with the squared
              distances to the neighbors, in ascending order.
            - idx: int64 tensor of shape (N, P1, K) with the indices of the
              neighbors in p2.
            - knn: always None.
            Entries belonging to padded query points, or to neighbor slots
            beyond the number of valid points in p2, are left at zero.
        """
        N, P1, D = p1.shape
        _N, P2, _D = p2.shape

        assert N == _N and D == _D

        if lengths1 is None:
            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
        if lengths2 is None:
            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

        dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
        idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device)

        for n in range(N):
            num1 = lengths1[n].item()
            num2 = lengths2[n].item()
            # Broadcast to all (num1, num2) pairs of valid points.
            pp1 = p1[n, :num1].view(num1, 1, D)
            pp2 = p2[n, :num2].view(1, num2, D)
            diff = pp1 - pp2
            sq_dists = (diff * diff).sum(2)
            # A cloud with fewer than K valid points yields fewer neighbors.
            num_nbrs = min(num2, K)
            for i in range(num1):
                srt_dd, srt_idx = sq_dists[i].sort()
                dists[n, i, :num_nbrs] = srt_dd[:num_nbrs]
                idx[n, i, :num_nbrs] = srt_idx[:num_nbrs]

        return _KNN(dists=dists, idx=idx, knn=None)

    def _knn_vs_python_square_helper(self, device):
        """
        Compare `knn_points` with the naive reference on unpadded inputs.

        Every implementation version is exercised; both the forward outputs
        and the gradients with respect to the two point clouds are checked.
        """
        Ns = [1, 4]
        Ds = [3, 5, 8]
        P1s = [8, 24]
        P2s = [8, 16, 32]
        Ks = [1, 3, 10]
        versions = [0, 1, 2, 3]
        factors = [Ns, Ds, P1s, P2s, Ks]
        for N, D, P1, P2, K in product(*factors):
            for version in versions:
                # Version 3 only supports K <= 4 (see test_knn_check_version).
                if version == 3 and K > 4:
                    continue
                x = torch.randn(N, P1, D, device=device, requires_grad=True)
                x_cuda = x.clone().detach()
                x_cuda.requires_grad_(True)
                y = torch.randn(N, P2, D, device=device, requires_grad=True)
                y_cuda = y.clone().detach()
                y_cuda.requires_grad_(True)

                # forward
                out1 = self._knn_points_naive(
                    x, y, lengths1=None, lengths2=None, K=K
                )
                out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
                self.assertClose(out1.dists, out2.dists)
                self.assertTrue(torch.all(out1.idx == out2.idx))

                # backward
                grad_dist = torch.ones(
                    (N, P1, K), dtype=torch.float32, device=device
                )
                loss1 = (out1.dists * grad_dist).sum()
                loss1.backward()
                loss2 = (out2.dists * grad_dist).sum()
                loss2.backward()

                self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
                self.assertClose(y_cuda.grad, y.grad, atol=5e-6)

    def test_knn_vs_python_square_cpu(self):
        device = torch.device("cpu")
        self._knn_vs_python_square_helper(device)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_knn_vs_python_square_cuda(self):
        device = torch.device("cuda:0")
        self._knn_vs_python_square_helper(device)

    def _knn_vs_python_ragged_helper(self, device):
        """
        Compare `knn_points` with the naive reference on padded inputs.

        Each cloud gets a random number of valid points; both the forward
        outputs and the gradients with respect to the inputs are checked.
        """
        Ns = [1, 4]
        Ds = [3, 5, 8]
        P1s = [8, 24]
        P2s = [8, 16, 32]
        Ks = [1, 3, 10]
        factors = [Ns, Ds, P1s, P2s, Ks]
        for N, D, P1, P2, K in product(*factors):
            x = torch.rand((N, P1, D), device=device, requires_grad=True)
            y = torch.rand((N, P2, D), device=device, requires_grad=True)
            # `high` is exclusive, so every cloud has at least one padded point.
            lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
            lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

            x_csrc = x.clone().detach()
            x_csrc.requires_grad_(True)
            y_csrc = y.clone().detach()
            y_csrc.requires_grad_(True)

            # forward
            out1 = self._knn_points_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K
            )
            out2 = knn_points(
                x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K
            )
            self.assertClose(out1.dists, out2.dists)
            self.assertTrue(torch.all(out1.idx == out2.idx))

            # backward
            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            loss1 = (out1.dists * grad_dist).sum()
            loss1.backward()
            loss2 = (out2.dists * grad_dist).sum()
            loss2.backward()

            self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
            self.assertClose(y_csrc.grad, y.grad, atol=5e-6)

    def test_knn_vs_python_ragged_cpu(self):
        device = torch.device("cpu")
        self._knn_vs_python_ragged_helper(device)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_knn_vs_python_ragged_cuda(self):
        device = torch.device("cuda:0")
        self._knn_vs_python_ragged_helper(device)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_knn_gather(self):
        """
        Check that `knn_gather` picks the points addressed by the KNN indices
        and zero-fills neighbor slots beyond the valid length of each cloud.
        """
        device = torch.device("cuda:0")
        N, P1, P2, K, D = 4, 16, 12, 8, 3
        x = torch.rand((N, P1, D), device=device)
        y = torch.rand((N, P2, D), device=device)
        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

        out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
        y_nn = knn_gather(y, out.idx, lengths2)

        for n in range(N):
            for p1 in range(P1):
                for k in range(K):
                    if k < lengths2[n]:
                        self.assertClose(y_nn[n, p1, k], y[n, out.idx[n, p1, k]])
                    else:
                        self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))

    def test_knn_check_version(self):
        """
        Check `knn_check_version(version, D, K)` against the expected support
        matrix of every implementation, including out-of-range arguments.
        """
        try:
            from pytorch3d._C import knn_check_version
        except ImportError:
            # knn_check_version will only be defined if we compiled with CUDA
            # support; report the test as skipped instead of silently passing.
            self.skipTest("knn_check_version requires a build with CUDA support")
        for D in range(-10, 10):
            for K in range(-10, 20):
                v0 = True
                v1 = 1 <= D <= 32
                v2 = 1 <= D <= 8 and 1 <= K <= 32
                v3 = 1 <= D <= 8 and 1 <= K <= 4
                all_expected = [v0, v1, v2, v3]
                for version in range(-10, 10):
                    actual = knn_check_version(version, D, K)
                    # Unknown version numbers are never usable.
                    expected = False
                    if 0 <= version < len(all_expected):
                        expected = all_expected[version]
                    self.assertEqual(
                        actual,
                        expected,
                        "version=%d, D=%d, K=%d" % (version, D, K),
                    )

    @staticmethod
    def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
        """
        Build a benchmark closure for `knn_points` on unpadded inputs.

        Args:
            N, P1, P2, D: Sizes of the random clouds, (N, P1, D) and (N, P2, D).
            K: Number of neighbors to query.
            device: Device string on which the inputs are created.

        Returns:
            A zero-argument callable that runs one forward and backward pass.
        """
        device = torch.device(device)
        pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
        pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
        grad_dists = torch.randn(N, P1, K, device=device)
        # Synchronizing is only meaningful (and only possible) for CUDA runs.
        is_cuda = device.type == "cuda"
        if is_cuda:
            torch.cuda.synchronize()

        def output():
            out = knn_points(pts1, pts2, K=K)
            loss = (out.dists * grad_dists).sum()
            loss.backward()
            if is_cuda:
                torch.cuda.synchronize()

        return output

    @staticmethod
    def knn_ragged(N: int, P1: int, P2: int, D: int, K: int, device: str):
        """
        Build a benchmark closure for `knn_points` on padded inputs.

        Args:
            N, P1, P2, D: Sizes of the random clouds, (N, P1, D) and (N, P2, D).
            K: Number of neighbors to query.
            device: Device string on which the inputs are created.

        Returns:
            A zero-argument callable that runs one forward and backward pass
            using randomly drawn per-cloud lengths.
        """
        device = torch.device(device)
        pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
        pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
        grad_dists = torch.randn(N, P1, K, device=device)
        # Synchronizing is only meaningful (and only possible) for CUDA runs.
        is_cuda = device.type == "cuda"
        if is_cuda:
            torch.cuda.synchronize()

        def output():
            out = knn_points(pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K)
            loss = (out.dists * grad_dists).sum()
            loss.backward()
            if is_cuda:
                torch.cuda.synchronize()

        return output