knn autograd

Summary:
Adds a knn backward pass that returns `grad_pts1` and `grad_pts2`. Adds `knn_gather` to gather the K nearest neighbors in `pts2`.
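A minimal usage sketch of the new API (shapes and call signatures mirror the new tests below; the tensor values are illustrative):
```python
import torch
from pytorch3d.ops.knn import knn_gather, knn_points

N, P1, P2, D, K = 4, 16, 12, 3, 8
pts1 = torch.rand(N, P1, D, requires_grad=True)
pts2 = torch.rand(N, P2, D, requires_grad=True)

out = knn_points(pts1, pts2, K=K)  # out.dists, out.idx: (N, P1, K)
nn = knn_gather(pts2, out.idx)     # (N, P1, K, D): neighbor coords in pts2

out.dists.sum().backward()         # backward now reaches both point clouds
assert pts1.grad is not None and pts2.grad is not None
```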

The benchmark tests include the backward pass and were run on an M40 GPU. The benchmark names encode `N_P1_P2_D_K` and the device.
```
Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_SQUARE_32_256_128_3_24_cpu              39558           43485             13
KNN_SQUARE_32_256_128_3_24_cuda:0            1080            1404            463
KNN_SQUARE_32_256_512_3_24_cpu              81950           85781              7
KNN_SQUARE_32_256_512_3_24_cuda:0            1519            1641            330
--------------------------------------------------------------------------------

Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_RAGGED_32_256_128_3_24_cpu              13798           14650             37
KNN_RAGGED_32_256_128_3_24_cuda:0            1576            1713            318
KNN_RAGGED_32_256_512_3_24_cpu              31255           32210             16
KNN_RAGGED_32_256_512_3_24_cuda:0            2024            2162            248
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20945556

fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
Author:    Georgia Gkioxari
Date:      2020-04-14 17:20:16 -07:00
Committed: Facebook GitHub Bot
Parent:    487d4d6607
Commit:    b2b0c5a442

8 changed files with 545 additions and 365 deletions

tests/bm_knn.py

```diff
@@ -2,180 +2,25 @@
 from itertools import product
-import torch
 from fvcore.common.benchmark import benchmark
-from pytorch3d import _C
-from pytorch3d.ops.knn import _knn_points_idx_naive
+from test_knn import TestKNN
 def bm_knn() -> None:
     """ Entry point for the benchmark """
-    benchmark_knn_cpu()
-    benchmark_knn_cuda_vs_naive()
-    benchmark_knn_cuda_versions()
-    benchmark_knn_cuda_versions_ragged()
+    backends = ["cpu", "cuda:0"]
-def benchmark_knn_cuda_versions() -> None:
-    # Compare our different KNN implementations,
-    # and also compare against our existing 1-NN
-    Ns = [1, 2]
-    Ps = [4096, 16384]
+    kwargs_list = []
+    Ns = [32]
+    P1s = [256]
+    P2s = [128, 512]
     Ds = [3]
-    Ks = [1, 4, 16, 64]
-    versions = [0, 1, 2, 3]
-    knn_kwargs, nn_kwargs = [], []
-    for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
-        if version == 2 and K > 32:
-            continue
-        if version == 3 and K > 4:
-            continue
-        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
-    for N, P, D in product(Ns, Ps, Ds):
-        nn_kwargs.append({"N": N, "D": D, "P": P})
-    benchmark(knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1)
-    benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
+    Ks = [24]
+    test_cases = product(Ns, P1s, P2s, Ds, Ks, backends)
+    for case in test_cases:
+        N, P1, P2, D, K, b = case
+        kwargs_list.append({"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b})
+    benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1)
-def benchmark_knn_cuda_versions_ragged() -> None:
-    # Compare our different KNN implementations,
-    # and also compare against our existing 1-NN
-    Ns = [8]
-    Ps = [4096, 16384]
-    Ds = [3]
-    Ks = [1, 4, 16, 64]
-    versions = [0, 1, 2, 3]
-    knn_kwargs = []
-    for N, P, D, K, version in product(Ns, Ps, Ds, Ks, versions):
-        if version == 2 and K > 32:
-            continue
-        if version == 3 and K > 4:
-            continue
-        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
-    benchmark(knn_cuda_with_init, "KNN_CUDA_COMPARISON", knn_kwargs, warmup_iters=1)
-    benchmark(knn_cuda_ragged, "KNN_CUDA_RAGGED", knn_kwargs, warmup_iters=1)
-def benchmark_knn_cuda_vs_naive() -> None:
-    # Compare against naive pytorch version of KNN
-    Ns = [1, 2, 4]
-    Ps = [1024, 4096, 16384, 65536]
-    Ds = [3]
-    Ks = [1, 2, 4, 8, 16]
-    knn_kwargs, naive_kwargs = [], []
-    for N, P, D, K in product(Ns, Ps, Ds, Ks):
-        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
-        if P <= 4096:
-            naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
-    benchmark(
-        knn_python_cuda_with_init, "KNN_CUDA_PYTHON", naive_kwargs, warmup_iters=1
-    )
-    benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)
-def benchmark_knn_cpu() -> None:
-    Ns = [1, 2]
-    Ps = [256, 512]
-    Ds = [3]
-    Ks = [1, 2, 4]
-    knn_kwargs, nn_kwargs = [], []
-    for N, P, D, K in product(Ns, Ps, Ds, Ks):
-        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
-    for N, P, D in product(Ns, Ps, Ds):
-        nn_kwargs.append({"N": N, "D": D, "P": P})
-    benchmark(knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1)
-    benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
-    benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)
-def knn_cuda_with_init(N, D, P, K, v=-1):
-    device = torch.device("cuda:0")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
-    torch.cuda.synchronize()
-    def knn():
-        _C.knn_points_idx(x, y, lengths, lengths, K, v)
-        torch.cuda.synchronize()
-    return knn
-def knn_cuda_ragged(N, D, P, K, v=-1):
-    device = torch.device("cuda:0")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    lengths1 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
-    lengths2 = torch.randint(P, size=(N,), device=device, dtype=torch.int64)
-    torch.cuda.synchronize()
-    def knn():
-        _C.knn_points_idx(x, y, lengths1, lengths2, K, v)
-        torch.cuda.synchronize()
-    return knn
-def knn_cpu_with_init(N, D, P, K):
-    device = torch.device("cpu")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
-    def knn():
-        _C.knn_points_idx(x, y, lengths, lengths, K, -1)
-    return knn
-def knn_python_cuda_with_init(N, D, P, K):
-    device = torch.device("cuda")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
-    torch.cuda.synchronize()
-    def knn():
-        _knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
-        torch.cuda.synchronize()
-    return knn
-def knn_python_cpu_with_init(N, D, P, K):
-    device = torch.device("cpu")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    lengths = torch.full((N,), P, dtype=torch.int64, device=device)
-    def knn():
-        _knn_points_idx_naive(x, y, K=K, lengths1=lengths, lengths2=lengths)
-    return knn
-def nn_cuda_with_init(N, D, P):
-    device = torch.device("cuda")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    torch.cuda.synchronize()
-    def knn():
-        _C.nn_points_idx(x, y)
-        torch.cuda.synchronize()
-    return knn
-def nn_cpu_with_init(N, D, P):
-    device = torch.device("cpu")
-    x = torch.randn(N, P, D, device=device)
-    y = torch.randn(N, P, D, device=device)
-    def knn():
-        _C.nn_points_idx(x, y)
-    return knn
+    benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1)
```
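For intuition about what the new backward computes, here is a hedged pure-PyTorch sketch (the helper `knn_backward_naive` is illustrative, not part of this diff). For squared Euclidean distances, the gradient w.r.t. `pts1[n, i]` of `||pts1[n, i] - pts2[n, j]||^2` is `2 * (pts1[n, i] - pts2[n, j])`, and the negated term is scatter-added into `grad_pts2` at `j = idx[n, i, k]`:
```python
import torch

def knn_backward_naive(pts1, pts2, idx, grad_dists):
    # pts1: (N, P1, D), pts2: (N, P2, D); idx, grad_dists: (N, P1, K)
    N, P1, K = grad_dists.shape
    grad_pts1 = torch.zeros_like(pts1)
    grad_pts2 = torch.zeros_like(pts2)
    for n in range(N):
        for i in range(P1):
            for k in range(K):
                j = idx[n, i, k]
                # chain rule through dists[n, i, k] = ||pts1[n, i] - pts2[n, j]||^2
                g = 2.0 * grad_dists[n, i, k] * (pts1[n, i] - pts2[n, j])
                grad_pts1[n, i] += g
                grad_pts2[n, j] -= g
    return grad_pts1, grad_pts2
```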

tests/test_knn.py

```diff
@@ -4,116 +4,187 @@ import unittest
 from itertools import product
 import torch
-from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
+from common_testing import TestCaseMixin
+from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
-class TestKNN(unittest.TestCase):
+class TestKNN(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(1)
-    def _check_knn_result(self, out1, out2, sorted):
-        # When sorted=True, points should be sorted by distance and should
-        # match between implementations. When sorted=False we we only want to
-        # check that we got the same set of indices, so we sort the indices by
-        # index value.
-        idx1, dist1 = out1
-        idx2, dist2 = out2
-        if not sorted:
-            idx1 = idx1.sort(dim=2).values
-            idx2 = idx2.sort(dim=2).values
-            dist1 = dist1.sort(dim=2).values
-            dist2 = dist2.sort(dim=2).values
-        if not torch.all(idx1 == idx2):
-            print(idx1)
-            print(idx2)
-        self.assertTrue(torch.all(idx1 == idx2))
-        self.assertTrue(torch.allclose(dist1, dist2))
+    @staticmethod
+    def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
+        """
+        Naive PyTorch implementation of K-Nearest Neighbors.
+        Returns always sorted results
+        """
+        N, P1, D = p1.shape
+        _N, P2, _D = p2.shape
-    def test_knn_vs_python_cpu_square(self):
-        """ Test CPU output vs PyTorch implementation """
-        device = torch.device("cpu")
-        Ns = [1, 4]
-        Ds = [2, 3]
-        P1s = [1, 10, 101]
-        P2s = [10, 101]
-        Ks = [1, 3, 10]
-        sorts = [True, False]
-        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
-        for N, D, P1, P2, K, sort in product(*factors):
-            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
-            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
-            x = torch.randn(N, P1, D, device=device)
-            y = torch.randn(N, P2, D, device=device)
-            out1 = _knn_points_idx_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
-            )
-            out2 = knn_points_idx(
-                x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
-            )
-            self._check_knn_result(out1, out2, sort)
+        assert N == _N and D == _D
-    def test_knn_vs_python_cuda_square(self):
-        """ Test CUDA output vs PyTorch implementation """
-        device = torch.device("cuda")
+        if lengths1 is None:
+            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
+        if lengths2 is None:
+            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
+        dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
+        idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device)
+        for n in range(N):
+            num1 = lengths1[n].item()
+            num2 = lengths2[n].item()
+            pp1 = p1[n, :num1].view(num1, 1, D)
+            pp2 = p2[n, :num2].view(1, num2, D)
+            diff = pp1 - pp2
+            diff = (diff * diff).sum(2)
+            num2 = min(num2, K)
+            for i in range(num1):
+                dd = diff[i]
+                srt_dd, srt_idx = dd.sort()
+                dists[n, i, :num2] = srt_dd[:num2]
+                idx[n, i, :num2] = srt_idx[:num2]
+        return _KNN(dists=dists, idx=idx, knn=None)
+    def _knn_vs_python_square_helper(self, device):
         Ns = [1, 4]
-        Ds = [2, 3, 8]
-        P1s = [1, 8, 64, 128, 1001]
-        P2s = [32, 128, 513]
+        Ds = [3, 5, 8]
+        P1s = [8, 24]
+        P2s = [8, 16, 32]
         Ks = [1, 3, 10]
-        sorts = [True, False]
         versions = [0, 1, 2, 3]
-        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
-        for N, D, P1, P2, K, sort in product(*factors):
-            x = torch.randn(N, P1, D, device=device)
-            y = torch.randn(N, P2, D, device=device)
-            out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
+        factors = [Ns, Ds, P1s, P2s, Ks]
+        for N, D, P1, P2, K in product(*factors):
             for version in versions:
                 if version == 3 and K > 4:
                     continue
-                out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
-                self._check_knn_result(out1, out2, sort)
+                x = torch.randn(N, P1, D, device=device, requires_grad=True)
+                x_cuda = x.clone().detach()
+                x_cuda.requires_grad_(True)
+                y = torch.randn(N, P2, D, device=device, requires_grad=True)
+                y_cuda = y.clone().detach()
+                y_cuda.requires_grad_(True)
+                # forward
+                out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
+                out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
+                self.assertClose(out1[0], out2[0])
+                self.assertTrue(torch.all(out1[1] == out2[1]))
+                # backward
+                grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
+                loss1 = (out1.dists * grad_dist).sum()
+                loss1.backward()
+                loss2 = (out2.dists * grad_dist).sum()
+                loss2.backward()
+                self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
+                self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
-    def test_knn_vs_python_cpu_ragged(self):
+    def test_knn_vs_python_square_cpu(self):
         device = torch.device("cpu")
-        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
-        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
-        N = 4
-        D = 3
-        Ks = [1, 9, 10, 11, 101]
-        sorts = [False, True]
-        factors = [Ks, sorts]
-        for K, sort in product(*factors):
-            x = torch.randn(N, lengths1.max(), D, device=device)
-            y = torch.randn(N, lengths2.max(), D, device=device)
-            out1 = _knn_points_idx_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
-            )
-            out2 = knn_points_idx(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
-            )
-            self._check_knn_result(out1, out2, sort)
+        self._knn_vs_python_square_helper(device)
-    def test_knn_vs_python_cuda_ragged(self):
-        device = torch.device("cuda")
-        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
-        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
-        N = 4
-        D = 3
-        Ks = [1, 9, 10, 11, 101]
-        sorts = [True, False]
-        versions = [0, 1, 2, 3]
-        factors = [Ks, sorts]
-        for K, sort in product(*factors):
-            x = torch.randn(N, lengths1.max(), D, device=device)
-            y = torch.randn(N, lengths2.max(), D, device=device)
-            out1 = _knn_points_idx_naive(
+    def test_knn_vs_python_square_cuda(self):
+        device = torch.device("cuda:0")
+        self._knn_vs_python_square_helper(device)
+    def _knn_vs_python_ragged_helper(self, device):
+        Ns = [1, 4]
+        Ds = [3, 5, 8]
+        P1s = [8, 24]
+        P2s = [8, 16, 32]
+        Ks = [1, 3, 10]
+        factors = [Ns, Ds, P1s, P2s, Ks]
+        for N, D, P1, P2, K in product(*factors):
+            x = torch.rand((N, P1, D), device=device, requires_grad=True)
+            y = torch.rand((N, P2, D), device=device, requires_grad=True)
+            lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
+            lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
+            x_csrc = x.clone().detach()
+            x_csrc.requires_grad_(True)
+            y_csrc = y.clone().detach()
+            y_csrc.requires_grad_(True)
+            # forward
+            out1 = self._knn_points_naive(
                 x, y, lengths1=lengths1, lengths2=lengths2, K=K
             )
-            for version in versions:
-                if version == 3 and K > 4:
-                    continue
-                out2 = knn_points_idx(
-                    x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
-                )
-                self._check_knn_result(out1, out2, sort)
+            out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
+            self.assertClose(out1[0], out2[0])
+            self.assertTrue(torch.all(out1[1] == out2[1]))
+            # backward
+            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
+            loss1 = (out1.dists * grad_dist).sum()
+            loss1.backward()
+            loss2 = (out2.dists * grad_dist).sum()
+            loss2.backward()
+            self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
+            self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
+    def test_knn_vs_python_ragged_cpu(self):
+        device = torch.device("cpu")
+        self._knn_vs_python_ragged_helper(device)
+    def test_knn_vs_python_ragged_cuda(self):
+        device = torch.device("cuda:0")
+        self._knn_vs_python_ragged_helper(device)
+    def test_knn_gather(self):
+        device = torch.device("cuda:0")
+        N, P1, P2, K, D = 4, 16, 12, 8, 3
+        x = torch.rand((N, P1, D), device=device)
+        y = torch.rand((N, P2, D), device=device)
+        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
+        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
+        out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
+        y_nn = knn_gather(y, out.idx, lengths2)
+        for n in range(N):
+            for p1 in range(P1):
+                for k in range(K):
+                    if k < lengths2[n]:
+                        self.assertClose(y_nn[n, p1, k], y[n, out.idx[n, p1, k]])
+                    else:
+                        self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))
+    @staticmethod
+    def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
+        device = torch.device(device)
+        pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
+        pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
+        grad_dists = torch.randn(N, P1, K, device=device)
+        torch.cuda.synchronize()
+        def output():
+            out = knn_points(pts1, pts2, K=K)
+            loss = (out.dists * grad_dists).sum()
+            loss.backward()
+            torch.cuda.synchronize()
+        return output
+    @staticmethod
+    def knn_ragged(N: int, P1: int, P2: int, D: int, K: int, device: str):
+        device = torch.device(device)
+        pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
+        pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
+        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
+        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
+        grad_dists = torch.randn(N, P1, K, device=device)
+        torch.cuda.synchronize()
+        def output():
+            out = knn_points(pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K)
+            loss = (out.dists * grad_dists).sum()
+            loss.backward()
+            torch.cuda.synchronize()
+        return output
```
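For reference, a hedged sketch of the `knn_gather` semantics that `test_knn_gather` above checks (`knn_gather_naive` is our illustrative name, not the library implementation): gather the neighbor coordinates by `idx`, then zero the slots `k >= lengths[n]`, since a ragged cloud with fewer than K valid points pads its trailing neighbors.
```python
import torch

def knn_gather_naive(x, idx, lengths=None):
    # x: (N, P2, D) point clouds; idx: (N, P1, K) neighbor indices.
    # Returns (N, P1, K, D) with out[n, i, k] = x[n, idx[n, i, k]].
    N, P2, D = x.shape
    _N, P1, K = idx.shape
    if lengths is None:
        lengths = torch.full((N,), P2, dtype=torch.int64, device=x.device)
    idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, D)  # (N, P1, K, D)
    x_nn = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
    # zero out padded neighbor slots k >= lengths[n]
    mask = torch.arange(K, device=x.device)[None, None, :] >= lengths[:, None, None]
    return x_nn.masked_fill(mask[:, :, :, None], 0.0)
```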