Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-20 14:20:38 +08:00
knn autograd
Summary: Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2. The BM tests include the backward pass and are run on an M40.

```
Benchmark                              Avg Time(μs)  Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
KNN_SQUARE_32_256_128_3_24_cpu                39558          43485          13
KNN_SQUARE_32_256_128_3_24_cuda:0              1080           1404         463
KNN_SQUARE_32_256_512_3_24_cpu                81950          85781           7
KNN_SQUARE_32_256_512_3_24_cuda:0              1519           1641         330
--------------------------------------------------------------------------------

Benchmark                              Avg Time(μs)  Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
KNN_RAGGED_32_256_128_3_24_cpu                13798          14650          37
KNN_RAGGED_32_256_128_3_24_cuda:0              1576           1713         318
KNN_RAGGED_32_256_512_3_24_cpu                31255          32210          16
KNN_RAGGED_32_256_512_3_24_cuda:0              2024           2162         248
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20945556

fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
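For orientation, the API this commit lands can be exercised end to end roughly as follows. This is an editor's sketch based on the tests in the diff below; the shapes, device, and K are arbitrary:

```
import torch
from pytorch3d.ops.knn import knn_gather, knn_points

# Two batches of 3D point clouds; gradients will flow into both.
p1 = torch.randn(2, 128, 3, device="cuda:0", requires_grad=True)
p2 = torch.randn(2, 256, 3, device="cuda:0", requires_grad=True)

# For each point in p1, the K nearest neighbors in p2 (squared L2).
out = knn_points(p1, p2, K=8)    # out.dists: (2, 128, 8), out.idx: (2, 128, 8)
p2_nn = knn_gather(p2, out.idx)  # (2, 128, 8, 3): coordinates of the neighbors

# The new backward produces grad_pts1 and grad_pts2.
out.dists.sum().backward()
assert p1.grad is not None and p2.grad is not None
```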
Committed by: Facebook GitHub Bot
Parent: 487d4d6607
Commit: b2b0c5a442
@@ -4,116 +4,187 @@ import unittest
 from itertools import product
 
 import torch
 
-from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
+from common_testing import TestCaseMixin
+from pytorch3d.ops.knn import _KNN, knn_gather, knn_points
 
 
-class TestKNN(unittest.TestCase):
+class TestKNN(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
         torch.manual_seed(1)
 
-    def _check_knn_result(self, out1, out2, sorted):
-        # When sorted=True, points should be sorted by distance and should
-        # match between implementations. When sorted=False we only want to
-        # check that we got the same set of indices, so we sort the indices by
-        # index value.
-        idx1, dist1 = out1
-        idx2, dist2 = out2
-        if not sorted:
-            idx1 = idx1.sort(dim=2).values
-            idx2 = idx2.sort(dim=2).values
-            dist1 = dist1.sort(dim=2).values
-            dist2 = dist2.sort(dim=2).values
-        if not torch.all(idx1 == idx2):
-            print(idx1)
-            print(idx2)
-        self.assertTrue(torch.all(idx1 == idx2))
-        self.assertTrue(torch.allclose(dist1, dist2))
+    @staticmethod
+    def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
+        """
+        Naive PyTorch implementation of K-Nearest Neighbors.
+        Always returns sorted results.
+        """
+        N, P1, D = p1.shape
+        _N, P2, _D = p2.shape
-    def test_knn_vs_python_cpu_square(self):
-        """ Test CPU output vs PyTorch implementation """
-        device = torch.device("cpu")
-        Ns = [1, 4]
-        Ds = [2, 3]
-        P1s = [1, 10, 101]
-        P2s = [10, 101]
-        Ks = [1, 3, 10]
-        sorts = [True, False]
-        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
-        for N, D, P1, P2, K, sort in product(*factors):
-            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
-            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
-            x = torch.randn(N, P1, D, device=device)
-            y = torch.randn(N, P2, D, device=device)
-            out1 = _knn_points_idx_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
-            )
-            out2 = knn_points_idx(
-                x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
-            )
-            self._check_knn_result(out1, out2, sort)
+
+        assert N == _N and D == _D
+
-    def test_knn_vs_python_cuda_square(self):
-        """ Test CUDA output vs PyTorch implementation """
-        device = torch.device("cuda")
+        if lengths1 is None:
+            lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
+        if lengths2 is None:
+            lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
+
+        dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
+        idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device)
+
+        for n in range(N):
+            num1 = lengths1[n].item()
+            num2 = lengths2[n].item()
+            pp1 = p1[n, :num1].view(num1, 1, D)
+            pp2 = p2[n, :num2].view(1, num2, D)
+            diff = pp1 - pp2
+            diff = (diff * diff).sum(2)
+            num2 = min(num2, K)
+            for i in range(num1):
+                dd = diff[i]
+                srt_dd, srt_idx = dd.sort()
+
+                dists[n, i, :num2] = srt_dd[:num2]
+                idx[n, i, :num2] = srt_idx[:num2]
+
+        return _KNN(dists=dists, idx=idx, knn=None)
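The per-sample loop above is the reference implementation that the C++/CUDA kernels are checked against. For intuition, the square (no-lengths) case is equivalent to this vectorized formulation; an editor's illustration, where `knn_naive_square` is a hypothetical name, not part of the commit:

```
import torch

def knn_naive_square(p1, p2, K):
    # p1: (N, P1, D), p2: (N, P2, D) -> sorted squared L2 dists and indices.
    diff = p1[:, :, None] - p2[:, None]    # (N, P1, P2, D)
    dist2 = (diff * diff).sum(-1)          # (N, P1, P2)
    dists, idx = dist2.topk(K, dim=2, largest=False, sorted=True)
    return dists, idx                      # each (N, P1, K)
```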
+
+    def _knn_vs_python_square_helper(self, device):
         Ns = [1, 4]
-        Ds = [2, 3, 8]
-        P1s = [1, 8, 64, 128, 1001]
-        P2s = [32, 128, 513]
+        Ds = [3, 5, 8]
+        P1s = [8, 24]
+        P2s = [8, 16, 32]
         Ks = [1, 3, 10]
-        sorts = [True, False]
         versions = [0, 1, 2, 3]
-        factors = [Ns, Ds, P1s, P2s, Ks, sorts]
-        for N, D, P1, P2, K, sort in product(*factors):
-            x = torch.randn(N, P1, D, device=device)
-            y = torch.randn(N, P2, D, device=device)
-            out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
+        factors = [Ns, Ds, P1s, P2s, Ks]
+        for N, D, P1, P2, K in product(*factors):
             for version in versions:
                 if version == 3 and K > 4:
                     continue
-                out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
-                self._check_knn_result(out1, out2, sort)
+                x = torch.randn(N, P1, D, device=device, requires_grad=True)
+                x_cuda = x.clone().detach()
+                x_cuda.requires_grad_(True)
+                y = torch.randn(N, P2, D, device=device, requires_grad=True)
+                y_cuda = y.clone().detach()
+                y_cuda.requires_grad_(True)
 
-    def test_knn_vs_python_cpu_ragged(self):
+                # forward
+                out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
+                out2 = knn_points(x_cuda, y_cuda, K=K, version=version)
+                self.assertClose(out1[0], out2[0])
+                self.assertTrue(torch.all(out1[1] == out2[1]))
+
+                # backward
+                grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
+                loss1 = (out1.dists * grad_dist).sum()
+                loss1.backward()
+                loss2 = (out2.dists * grad_dist).sum()
+                loss2.backward()
+
+                self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
+                self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
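The backward halves of these checks exercise the new `grad_pts1`/`grad_pts2` path. Since `dists[n, i, k]` is the squared distance `||p1[n, i] - p2[n, idx[n, i, k]]||^2`, the gradient of `sum(dists)` w.r.t. `p1` is just `2 * (p1 - neighbor)` summed over the K neighbors, which can be sanity-checked with `knn_gather`. An editor's sketch, not part of the commit:

```
import torch
from pytorch3d.ops.knn import knn_gather, knn_points

p1 = torch.randn(2, 16, 3, requires_grad=True)
p2 = torch.randn(2, 32, 3)
out = knn_points(p1, p2, K=4)
out.dists.sum().backward()

# Analytic gradient: sum over k of 2 * (p1 - k-th nearest neighbor).
nn = knn_gather(p2, out.idx)  # (2, 16, 4, 3)
expected = 2 * (p1.detach()[:, :, None] - nn).sum(dim=2)
assert torch.allclose(p1.grad, expected, atol=1e-5)
```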
+    def test_knn_vs_python_square_cpu(self):
         device = torch.device("cpu")
-        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
-        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
-        N = 4
-        D = 3
-        Ks = [1, 9, 10, 11, 101]
-        sorts = [False, True]
-        factors = [Ks, sorts]
-        for K, sort in product(*factors):
-            x = torch.randn(N, lengths1.max(), D, device=device)
-            y = torch.randn(N, lengths2.max(), D, device=device)
-            out1 = _knn_points_idx_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
-            )
-            out2 = knn_points_idx(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
-            )
-            self._check_knn_result(out1, out2, sort)
+        self._knn_vs_python_square_helper(device)
 
-    def test_knn_vs_python_cuda_ragged(self):
-        device = torch.device("cuda")
-        lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
-        lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
-        N = 4
-        D = 3
-        Ks = [1, 9, 10, 11, 101]
-        sorts = [True, False]
-        versions = [0, 1, 2, 3]
-        factors = [Ks, sorts]
-        for K, sort in product(*factors):
-            x = torch.randn(N, lengths1.max(), D, device=device)
-            y = torch.randn(N, lengths2.max(), D, device=device)
-            out1 = _knn_points_idx_naive(
+    def test_knn_vs_python_square_cuda(self):
+        device = torch.device("cuda:0")
+        self._knn_vs_python_square_helper(device)
+
+    def _knn_vs_python_ragged_helper(self, device):
+        Ns = [1, 4]
+        Ds = [3, 5, 8]
+        P1s = [8, 24]
+        P2s = [8, 16, 32]
+        Ks = [1, 3, 10]
+        factors = [Ns, Ds, P1s, P2s, Ks]
+        for N, D, P1, P2, K in product(*factors):
+            x = torch.rand((N, P1, D), device=device, requires_grad=True)
+            y = torch.rand((N, P2, D), device=device, requires_grad=True)
+            lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
+            lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
+
+            x_csrc = x.clone().detach()
+            x_csrc.requires_grad_(True)
+            y_csrc = y.clone().detach()
+            y_csrc.requires_grad_(True)
+
+            # forward
+            out1 = self._knn_points_naive(
                 x, y, lengths1=lengths1, lengths2=lengths2, K=K
             )
-            for version in versions:
-                if version == 3 and K > 4:
-                    continue
-                out2 = knn_points_idx(
-                    x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
-                )
-                self._check_knn_result(out1, out2, sort)
+            out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
+            self.assertClose(out1[0], out2[0])
+            self.assertTrue(torch.all(out1[1] == out2[1]))
+
+            # backward
+            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
+            loss1 = (out1.dists * grad_dist).sum()
+            loss1.backward()
+            loss2 = (out2.dists * grad_dist).sum()
+            loss2.backward()
+
+            self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
+            self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
+
+    def test_knn_vs_python_ragged_cpu(self):
+        device = torch.device("cpu")
+        self._knn_vs_python_ragged_helper(device)
+
+    def test_knn_vs_python_ragged_cuda(self):
+        device = torch.device("cuda:0")
+        self._knn_vs_python_ragged_helper(device)
+
+    def test_knn_gather(self):
+        device = torch.device("cuda:0")
+        N, P1, P2, K, D = 4, 16, 12, 8, 3
+        x = torch.rand((N, P1, D), device=device)
+        y = torch.rand((N, P2, D), device=device)
+        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
+        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
+
+        out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
+        y_nn = knn_gather(y, out.idx, lengths2)
+
+        for n in range(N):
+            for p1 in range(P1):
+                for k in range(K):
+                    if k < lengths2[n]:
+                        self.assertClose(y_nn[n, p1, k], y[n, out.idx[n, p1, k]])
+                    else:
+                        self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))
+
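For fully padded inputs (no `lengths`), `knn_gather` behaves like a shaped `torch.gather` along the points dimension. A minimal sketch of that equivalence, as an editor's illustration:

```
import torch
from pytorch3d.ops.knn import knn_gather, knn_points

N, P1, P2, K, D = 4, 16, 12, 8, 3
x, y = torch.rand(N, P1, D), torch.rand(N, P2, D)
idx = knn_points(x, y, K=K).idx                    # (N, P1, K)

# Expand idx and y so gather along dim 1 picks neighbor coordinates.
idx_exp = idx[:, :, :, None].expand(-1, -1, -1, D)  # (N, P1, K, D)
y_exp = y[:, :, None].expand(-1, -1, K, -1)         # (N, P2, K, D)
y_nn = y_exp.gather(1, idx_exp)                     # (N, P1, K, D)

assert torch.equal(y_nn, knn_gather(y, idx))
```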
+    @staticmethod
+    def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
+        device = torch.device(device)
+        pts1 = torch.randn(N, P1, D, device=device, requires_grad=True)
+        pts2 = torch.randn(N, P2, D, device=device, requires_grad=True)
+        grad_dists = torch.randn(N, P1, K, device=device)
+        torch.cuda.synchronize()
+
+        def output():
+            out = knn_points(pts1, pts2, K=K)
+            loss = (out.dists * grad_dists).sum()
+            loss.backward()
+            torch.cuda.synchronize()
+
+        return output
+
+    @staticmethod
+    def knn_ragged(N: int, P1: int, P2: int, D: int, K: int, device: str):
+        device = torch.device(device)
+        pts1 = torch.rand((N, P1, D), device=device, requires_grad=True)
+        pts2 = torch.rand((N, P2, D), device=device, requires_grad=True)
+        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
+        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)
+        grad_dists = torch.randn(N, P1, K, device=device)
+        torch.cuda.synchronize()
+
+        def output():
+            out = knn_points(pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K)
+            loss = (out.dists * grad_dists).sum()
+            loss.backward()
+            torch.cuda.synchronize()
+
+        return output
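The benchmark names in the summary appear to line up with the `knn_square(N, P1, P2, D, K, device)` signature, e.g. KNN_SQUARE_32_256_128_3_24_cuda:0 would be N=32, P1=256, P2=128, D=3, K=24. A hypothetical driver for one configuration (the real harness lives in the repo's benchmark utilities, not shown in this diff):

```
# One timed iteration: KNN forward + backward, as in the table above.
f = TestKNN.knn_square(N=32, P1=256, P2=128, D=3, K=24, device="cuda:0")
f()
```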