pytorch3d/tests/bm_knn.py
Georgia Gkioxari b2b0c5a442 knn autograd
Summary:
Adds knn backward to return `grad_pts1` and `grad_pts2`. Adds `knn_gather` to return the nearest neighbors in pts2.
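
For readers unfamiliar with the two operations named above, here is a minimal pure-Python sketch of a brute-force K-nearest-neighbor lookup followed by a gather of the neighbor coordinates. It is only an illustration, not the library's implementation: the helper names `knn_indices` and `gather_neighbors` are hypothetical, it handles a single unbatched point cloud, and it assumes squared Euclidean distance.

```python
from typing import List, Sequence

Point = Sequence[float]


def knn_indices(pts1: Sequence[Point], pts2: Sequence[Point], K: int) -> List[List[int]]:
    """For each point in pts1, return indices of its K closest points in pts2.

    Hypothetical helper: brute force, squared Euclidean distance, nearest first.
    """
    result = []
    for p in pts1:
        # Squared distance from p to every candidate in pts2.
        dists = [sum((a - b) ** 2 for a, b in zip(p, q)) for q in pts2]
        # Sort candidate indices by distance and keep the first K.
        order = sorted(range(len(pts2)), key=lambda j: dists[j])
        result.append(order[:K])
    return result


def gather_neighbors(pts2: Sequence[Point], idx: List[List[int]]) -> List[List[List[float]]]:
    """Look up the coordinates in pts2 selected by the index lists (hypothetical helper)."""
    return [[list(pts2[j]) for j in row] for row in idx]


pts1 = [(0.0, 0.0, 0.0), (5.0, 5.0, 5.0)]
pts2 = [(1.0, 0.0, 0.0), (4.0, 5.0, 5.0), (0.0, 0.5, 0.0), (9.0, 9.0, 9.0)]

idx = knn_indices(pts1, pts2, K=2)
neighbors = gather_neighbors(pts2, idx)
print(idx)
print(neighbors[0])
```

In this sketch the index lookup and the gather are kept as separate steps, so the indices can be reused to gather other per-point data from the second cloud.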

The BM tests include the backward pass and were run on an M40.
```
Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_SQUARE_32_256_128_3_24_cpu              39558           43485             13
KNN_SQUARE_32_256_128_3_24_cuda:0            1080            1404            463
KNN_SQUARE_32_256_512_3_24_cpu              81950           85781              7
KNN_SQUARE_32_256_512_3_24_cuda:0            1519            1641            330
--------------------------------------------------------------------------------

Benchmark                               Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
KNN_RAGGED_32_256_128_3_24_cpu              13798           14650             37
KNN_RAGGED_32_256_128_3_24_cuda:0            1576            1713            318
KNN_RAGGED_32_256_512_3_24_cpu              31255           32210             16
KNN_RAGGED_32_256_512_3_24_cuda:0            2024            2162            248
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20945556

fbshipit-source-id: a16f616029c6b5f8c2afceb5f2bc12c5c20d2f3c
2020-04-14 17:22:56 -07:00


# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from itertools import product

from fvcore.common.benchmark import benchmark
from test_knn import TestKNN


def bm_knn() -> None:

    backends = ["cpu", "cuda:0"]

    kwargs_list = []
    Ns = [32]
    P1s = [256]
    P2s = [128, 512]
    Ds = [3]
    Ks = [24]
    test_cases = product(Ns, P1s, P2s, Ds, Ks, backends)
    for case in test_cases:
        N, P1, P2, D, K, b = case
        kwargs_list.append({"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b})

    benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1)
    benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1)
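
To make the parameter grid concrete, the following standalone sketch expands the same lists with `itertools.product` and prints one label per case. The `label` helper is hypothetical: it assumes, judging only from the benchmark table in the commit message, that each row name is the prefix joined with the keyword-argument values by underscores. It does not call `fvcore` or `TestKNN`.

```python
from itertools import product

backends = ["cpu", "cuda:0"]
Ns, P1s, P2s, Ds, Ks = [32], [256], [128, 512], [3], [24]

# Same grid as bm_knn(): product() varies the last iterable fastest,
# so the device alternates within each P2 value.
kwargs_list = [
    {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b}
    for N, P1, P2, D, K, b in product(Ns, P1s, P2s, Ds, Ks, backends)
]


def label(prefix: str, kwargs: dict) -> str:
    """Hypothetical helper: join the prefix and kwarg values with underscores."""
    return "_".join([prefix] + [str(v) for v in kwargs.values()])


labels = [label("KNN_SQUARE", kw) for kw in kwargs_list]
for name in labels:
    print(name)
```

With the lists above the grid has four cases (two `P2` values times two devices), which matches the four rows per benchmark in the table.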