pytorch3d/tests/test_knn.py
Jeremy Reizenstein 01b5f7b228 heterogenous KNN
Summary: Interface and working implementation of ragged KNN. Benchmarks (which aren't ragged) haven't slowed. New benchmark shows that ragged is faster than non-ragged of the same shape.

Reviewed By: jcjohnson

Differential Revision: D20696507

fbshipit-source-id: 21b80f71343a3475c8d3ee0ce2680f92f0fae4de
2020-04-07 01:47:37 -07:00

120 lines
4.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
import torch
from pytorch3d.ops.knn import _knn_points_idx_naive, knn_points_idx
class TestKNN(unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
def _check_knn_result(self, out1, out2, sorted):
# When sorted=True, points should be sorted by distance and should
# match between implementations. When sorted=False we we only want to
# check that we got the same set of indices, so we sort the indices by
# index value.
idx1, dist1 = out1
idx2, dist2 = out2
if not sorted:
idx1 = idx1.sort(dim=2).values
idx2 = idx2.sort(dim=2).values
dist1 = dist1.sort(dim=2).values
dist2 = dist2.sort(dim=2).values
if not torch.all(idx1 == idx2):
print(idx1)
print(idx2)
self.assertTrue(torch.all(idx1 == idx2))
self.assertTrue(torch.allclose(dist1, dist2))
def test_knn_vs_python_cpu_square(self):
""" Test CPU output vs PyTorch implementation """
device = torch.device("cpu")
Ns = [1, 4]
Ds = [2, 3]
P1s = [1, 10, 101]
P2s = [10, 101]
Ks = [1, 3, 10]
sorts = [True, False]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
out2 = knn_points_idx(
x, y, K=K, lengths1=lengths1, lengths2=lengths2, sorted=sort
)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda_square(self):
""" Test CUDA output vs PyTorch implementation """
device = torch.device("cuda")
Ns = [1, 4]
Ds = [2, 3, 8]
P1s = [1, 8, 64, 128, 1001]
P2s = [32, 128, 513]
Ks = [1, 3, 10]
sorts = [True, False]
versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks, sorts]
for N, D, P1, P2, K, sort in product(*factors):
x = torch.randn(N, P1, D, device=device)
y = torch.randn(N, P2, D, device=device)
out1 = _knn_points_idx_naive(x, y, lengths1=None, lengths2=None, K=K)
for version in versions:
if version == 3 and K > 4:
continue
out2 = knn_points_idx(x, y, K=K, sorted=sort, version=version)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cpu_ragged(self):
device = torch.device("cpu")
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
N = 4
D = 3
Ks = [1, 9, 10, 11, 101]
sorts = [False, True]
factors = [Ks, sorts]
for K, sort in product(*factors):
x = torch.randn(N, lengths1.max(), D, device=device)
y = torch.randn(N, lengths2.max(), D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
out2 = knn_points_idx(
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
)
self._check_knn_result(out1, out2, sort)
def test_knn_vs_python_cuda_ragged(self):
device = torch.device("cuda")
lengths1 = torch.tensor([10, 100, 10, 100], device=device, dtype=torch.int64)
lengths2 = torch.tensor([10, 10, 100, 100], device=device, dtype=torch.int64)
N = 4
D = 3
Ks = [1, 9, 10, 11, 101]
sorts = [True, False]
versions = [0, 1, 2, 3]
factors = [Ks, sorts]
for K, sort in product(*factors):
x = torch.randn(N, lengths1.max(), D, device=device)
y = torch.randn(N, lengths2.max(), D, device=device)
out1 = _knn_points_idx_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
)
for version in versions:
if version == 3 and K > 4:
continue
out2 = knn_points_idx(
x, y, lengths1=lengths1, lengths2=lengths2, K=K, sorted=sort
)
self._check_knn_result(out1, out2, sort)