Linter, deprecated type()

Summary: Run the linter after recent changes. Fix a long comment in knn.h which clang-format had reflowed badly. Add a crude test that the code doesn't call the deprecated `.type()` or `.data()`.

Reviewed By: nikhilaravi

Differential Revision: D20692935

fbshipit-source-id: 28ce0308adae79a870cb41a810b7cf8744f41ab8
Jeremy Reizenstein 2020-03-29 14:01:15 -07:00 committed by Facebook GitHub Bot
parent 3061c5b663
commit 37c5c8e0b6
10 changed files with 430 additions and 259 deletions
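
For reference, the deprecated ATen calls this diff removes, next to the replacements it uses. A minimal compilable sketch (not code from the diff; `modernized_api_example` is a made-up name):

#include <ATen/ATen.h>

void modernized_api_example(const at::Tensor& t) {
  bool on_gpu = t.is_cuda(); // was: t.type().is_cuda()
  AT_DISPATCH_FLOATING_TYPES(
      t.scalar_type(), "example", ([&] { // was: t.type()
        const scalar_t* data = t.data_ptr<scalar_t>(); // was: t.data<scalar_t>()
        (void)data;
      }));
  (void)on_gpu;
}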

pytorch3d/csrc/utils/dispatch.cuh

@@ -1,9 +1,9 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of functions.
// This is especially useful for CUDA kernels, since specializing them to particular
// input sizes can often allow the compiler to unroll loops and place arrays into
// registers, which can give huge performance speedups.
// This file provides utilities for dispatching to specialized versions of
// functions. This is especially useful for CUDA kernels, since specializing
// them to particular input sizes can often allow the compiler to unroll loops
// and place arrays into registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
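
(The example function itself is elided by the hunks shown here; a hypothetical sketch of such a kernel, with made-up names:)

#include <cstdint>

template <typename T, int64_t N>
struct SquareKernel {
  static void run(const T* in, T* out) {
    // N is a compile-time constant, so the compiler can fully unroll
    // this loop and keep intermediates in registers.
    for (int64_t i = 0; i < N; ++i) {
      out[i] = in[i] * in[i];
    }
  }
};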
@@ -92,14 +92,13 @@ namespace {
// In order to dispatch, we will take an additional template argument curN,
// and increment it via template recursion until it is equal to the run-time
// argument N.
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
typename... Args
>
template <
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
typename... Args>
struct DispatchKernelHelper1D {
static void run(const int64_t N, Args... args) {
if (curN == N) {
@@ -108,22 +107,21 @@ struct DispatchKernelHelper1D {
Kernel<T, curN>::run(args...);
} else if (curN < N) {
// Increment curN via template recursion
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
N, args...);
}
// We shouldn't get here -- throw an error?
}
};
// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args
>
template <
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
static void run(const int64_t N, Args... args) {
if (N == maxN) {
@@ -133,19 +131,21 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
}
};
// 2D dispatch, general case.
// This is similar to the 1D case: we take additional template args curN and
// curM, and increment them via template recursion until they are equal to
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN, int64_t curN,
int64_t minM, int64_t maxM, int64_t curM,
typename... Args
>
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
int64_t minM,
int64_t maxM,
int64_t curM,
typename... Args>
struct DispatchKernelHelper2D {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && curM == M) {
@@ -154,67 +154,141 @@ struct DispatchKernelHelper2D {
// Increment both curN and curM. This isn't strictly necessary; we could
// just increment one or the other at each step. But this helps to cut
// on the number of recursive calls we make.
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
} else if (curN < N) {
// Increment curN only
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
curM,
Args...>::run(N, M, args...);
} else if (curM < M) {
// Increment curM only
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
}
}
};
// 2D dispatch, specialization for curN == maxN
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM, int64_t curM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
int64_t curM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
curM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && curM == M) {
Kernel<T, maxN, curM>::run(args...);
} else if (curM < maxM) {
DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curM == maxM
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN, int64_t curN,
int64_t minM, int64_t maxM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
int64_t minM,
int64_t maxM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN,
minM,
maxM,
maxM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && maxM == M) {
Kernel<T, curN, maxM>::run(args...);
} else if (curN < maxN) {
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
maxM,
Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curN == maxN, curM == maxM
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM,
typename... Args
>
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
maxM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && maxM == M) {
Kernel<T, maxN, maxM>::run(args...);
@@ -225,37 +299,45 @@ struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Arg
} // namespace
// This is the function we expect users to call to dispatch to 1D functions
template<
template<typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args
>
template <
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args>
void DispatchKernel1D(const int64_t N, Args... args) {
if (minN <= N && N <= maxN) {
// Kick off the template recursion by calling the Helper with curN = minN
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
N, args...);
}
// Maybe throw an error if we tried to dispatch outside the allowed range?
}
// This is the function we expect users to call to dispatch to 2D functions
template<
template<typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN, int64_t maxN,
int64_t minM, int64_t maxM,
typename... Args
>
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
typename... Args>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
// Kick off the template recursion by calling the Helper with curN = minN
// and curM = minM
DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
minN,
minM,
maxM,
minM,
Args...>::run(N, M, args...);
}
// Maybe throw an error if we tried to dispatch outside the specified range?
}
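
A hypothetical call site for the dispatchers above, assuming the SquareKernel sketch from earlier; the bounds 1 and 32 are illustrative:

#include <cstdint>

// Instantiates SquareKernel<float, N> for every N in [1, 32], then picks
// the one matching the runtime n. Does nothing if n is out of range,
// matching the behavior noted in the comments above.
void square_for_runtime_n(const float* in, float* out, int64_t n) {
  DispatchKernel1D<SquareKernel, float, 1, 32>(n, in, out);
}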

pytorch3d/csrc/utils/index_utils.cuh

@@ -39,82 +39,180 @@
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template<typename T, int N>
template <typename T, int N>
struct RegisterIndexUtils {
__device__ __forceinline__ static T get(const T arr[N], int idx) {
if (idx < 0 || idx >= N) return T();
if (idx < 0 || idx >= N)
return T();
switch (idx) {
case 0: return arr[0];
case 1: return arr[1];
case 2: return arr[2];
case 3: return arr[3];
case 4: return arr[4];
case 5: return arr[5];
case 6: return arr[6];
case 7: return arr[7];
case 8: return arr[8];
case 9: return arr[9];
case 10: return arr[10];
case 11: return arr[11];
case 12: return arr[12];
case 13: return arr[13];
case 14: return arr[14];
case 15: return arr[15];
case 16: return arr[16];
case 17: return arr[17];
case 18: return arr[18];
case 19: return arr[19];
case 20: return arr[20];
case 21: return arr[21];
case 22: return arr[22];
case 23: return arr[23];
case 24: return arr[24];
case 25: return arr[25];
case 26: return arr[26];
case 27: return arr[27];
case 28: return arr[28];
case 29: return arr[29];
case 30: return arr[30];
case 31: return arr[31];
case 0:
return arr[0];
case 1:
return arr[1];
case 2:
return arr[2];
case 3:
return arr[3];
case 4:
return arr[4];
case 5:
return arr[5];
case 6:
return arr[6];
case 7:
return arr[7];
case 8:
return arr[8];
case 9:
return arr[9];
case 10:
return arr[10];
case 11:
return arr[11];
case 12:
return arr[12];
case 13:
return arr[13];
case 14:
return arr[14];
case 15:
return arr[15];
case 16:
return arr[16];
case 17:
return arr[17];
case 18:
return arr[18];
case 19:
return arr[19];
case 20:
return arr[20];
case 21:
return arr[21];
case 22:
return arr[22];
case 23:
return arr[23];
case 24:
return arr[24];
case 25:
return arr[25];
case 26:
return arr[26];
case 27:
return arr[27];
case 28:
return arr[28];
case 29:
return arr[29];
case 30:
return arr[30];
case 31:
return arr[31];
};
return T();
}
__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
if (idx < 0 || idx >= N) return;
if (idx < 0 || idx >= N)
return;
switch (idx) {
case 0: arr[0] = val; break;
case 1: arr[1] = val; break;
case 2: arr[2] = val; break;
case 3: arr[3] = val; break;
case 4: arr[4] = val; break;
case 5: arr[5] = val; break;
case 6: arr[6] = val; break;
case 7: arr[7] = val; break;
case 8: arr[8] = val; break;
case 9: arr[9] = val; break;
case 10: arr[10] = val; break;
case 11: arr[11] = val; break;
case 12: arr[12] = val; break;
case 13: arr[13] = val; break;
case 14: arr[14] = val; break;
case 15: arr[15] = val; break;
case 16: arr[16] = val; break;
case 17: arr[17] = val; break;
case 18: arr[18] = val; break;
case 19: arr[19] = val; break;
case 20: arr[20] = val; break;
case 21: arr[21] = val; break;
case 22: arr[22] = val; break;
case 23: arr[23] = val; break;
case 24: arr[24] = val; break;
case 25: arr[25] = val; break;
case 26: arr[26] = val; break;
case 27: arr[27] = val; break;
case 28: arr[28] = val; break;
case 29: arr[29] = val; break;
case 30: arr[30] = val; break;
case 31: arr[31] = val; break;
case 0:
arr[0] = val;
break;
case 1:
arr[1] = val;
break;
case 2:
arr[2] = val;
break;
case 3:
arr[3] = val;
break;
case 4:
arr[4] = val;
break;
case 5:
arr[5] = val;
break;
case 6:
arr[6] = val;
break;
case 7:
arr[7] = val;
break;
case 8:
arr[8] = val;
break;
case 9:
arr[9] = val;
break;
case 10:
arr[10] = val;
break;
case 11:
arr[11] = val;
break;
case 12:
arr[12] = val;
break;
case 13:
arr[13] = val;
break;
case 14:
arr[14] = val;
break;
case 15:
arr[15] = val;
break;
case 16:
arr[16] = val;
break;
case 17:
arr[17] = val;
break;
case 18:
arr[18] = val;
break;
case 19:
arr[19] = val;
break;
case 20:
arr[20] = val;
break;
case 21:
arr[21] = val;
break;
case 22:
arr[22] = val;
break;
case 23:
arr[23] = val;
break;
case 24:
arr[24] = val;
break;
case 25:
arr[25] = val;
break;
case 26:
arr[26] = val;
break;
case 27:
arr[27] = val;
break;
case 28:
arr[28] = val;
break;
case 29:
arr[29] = val;
break;
case 30:
arr[30] = val;
break;
case 31:
arr[31] = val;
break;
}
}
};
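
A hypothetical device-side sketch of how RegisterIndexUtils is meant to be used: with a fixed trip count and unrolled loops, every get/set collapses to a direct register access. Assumes the RegisterIndexUtils template above is visible:

__global__ void sum_of_eight(const float* in, float* out) {
  float buf[8]; // small thread-local array the compiler may keep in registers
#pragma unroll
  for (int i = 0; i < 8; ++i) {
    // After unrolling, i is a compile-time constant in each iteration, so
    // the switch in set() collapses to a direct register write.
    RegisterIndexUtils<float, 8>::set(buf, i, in[threadIdx.x * 8 + i]);
  }
  float sum = 0;
#pragma unroll
  for (int i = 0; i < 8; ++i) {
    sum += RegisterIndexUtils<float, 8>::get(buf, i);
  }
  out[threadIdx.x] = sum;
}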

pytorch3d/csrc/knn/knn.cu

@@ -289,7 +289,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const size_t threads = 256;
const size_t blocks = 256;
if (version == 0) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t>
<<<blocks, threads>>>(
p1.data_ptr<scalar_t>(),
@@ -303,7 +303,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
K);
}));
} else if (version == 1) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel1D<
KNearestNeighborV1Functor,
scalar_t,
@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
K);
}));
} else if (version == 2) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV2Functor,
scalar_t,
@@ -343,7 +343,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
P2);
}));
} else if (version == 3) {
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV3Functor,
scalar_t,

pytorch3d/csrc/knn/knn.h

@@ -13,11 +13,11 @@
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
// K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance version: Integer telling which implementation to use.
// TODO(jcjohns): Document this more, or maybe remove it before
// landing.
// K: int giving the number of nearest points to return.
// sorted: bool telling whether to sort the K returned points by their
// distance.
// version: Integer telling which implementation to use.
// TODO(jcjohns): Document this more, or maybe remove it before landing.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -41,7 +41,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p2,
int K,
int version) {
if (p1.type().is_cuda() || p2.type().is_cuda()) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(p1);
CHECK_CONTIGUOUS_CUDA(p2);
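
A hypothetical call site for the entry point declared above, following the documented shapes; the tensor sizes are made up:

#include <ATen/ATen.h>
#include <tuple>

// p1 is (N, P1, D) and p2 is (N, P2, D), per the docs above.
void knn_example() {
  at::Tensor p1 = at::rand({4, 128, 3});
  at::Tensor p2 = at::rand({4, 256, 3});
  auto result = KNearestNeighborIdx(p1, p2, /*K=*/8, /*version=*/0);
  at::Tensor idx = std::get<0>(result); // (N, P1, K) neighbor indices
  at::Tensor dists = std::get<1>(result); // (N, P1, K) squared distances
}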

pytorch3d/csrc/knn/knn_cpu.cpp

@@ -4,49 +4,48 @@
#include <queue>
#include <tuple>
std::tuple<at::Tensor, at::Tensor>
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
const int P2 = p2.size(1);
auto long_opts = p1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
auto long_opts = p1.options().dtype(torch::kInt64);
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();
auto p1_a = p1.accessor<float, 3>();
auto p2_a = p2.accessor<float, 3>();
auto idxs_a = idxs.accessor<int64_t, 3>();
auto dists_a = dists.accessor<float, 3>();
for (int n = 0; n < N; ++n) {
for (int i1 = 0; i1 < P1; ++i1) {
// Use a priority queue to store (distance, index) tuples.
std::priority_queue<std::tuple<float, int>> q;
for (int i2 = 0; i2 < P2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
dist += diff * diff;
}
int size = static_cast<int>(q.size());
if (size < K || dist < std::get<0>(q.top())) {
q.emplace(dist, i2);
if (size >= K) {
q.pop();
}
}
}
while (!q.empty()) {
auto t = q.top();
q.pop();
const int k = q.size();
dists_a[n][i1][k] = std::get<0>(t);
idxs_a[n][i1][k] = std::get<1>(t);
}
for (int n = 0; n < N; ++n) {
for (int i1 = 0; i1 < P1; ++i1) {
// Use a priority queue to store (distance, index) tuples.
std::priority_queue<std::tuple<float, int>> q;
for (int i2 = 0; i2 < P2; ++i2) {
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
dist += diff * diff;
}
int size = static_cast<int>(q.size());
if (size < K || dist < std::get<0>(q.top())) {
q.emplace(dist, i2);
if (size >= K) {
q.pop();
}
}
}
while (!q.empty()) {
auto t = q.top();
q.pop();
const int k = q.size();
dists_a[n][i1][k] = std::get<0>(t);
idxs_a[n][i1][k] = std::get<1>(t);
}
}
return std::make_tuple(idxs, dists);
}
return std::make_tuple(idxs, dists);
}
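
The CPU implementation above uses a classic bounded max-heap for top-K: the queue's top is the largest of the K smallest distances seen so far, and the drain loop writes results in ascending order because q.size() after each pop equals the rank of the popped element. The same pattern in isolation (a sketch, not code from the diff):

#include <queue>
#include <vector>

// Return the k smallest values of v in ascending order.
std::vector<float> smallest_k(const std::vector<float>& v, int k) {
  std::priority_queue<float> q; // max-heap: q.top() is the largest kept value
  for (float x : v) {
    if (static_cast<int>(q.size()) < k || x < q.top()) {
      q.push(x);
      if (static_cast<int>(q.size()) > k) {
        q.pop(); // evict the current largest
      }
    }
  }
  std::vector<float> out(q.size());
  while (!q.empty()) {
    out[q.size() - 1] = q.top(); // largest remaining goes to the back
    q.pop();
  }
  return out;
}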

pytorch3d/csrc/utils/mink.cuh

@@ -5,7 +5,6 @@
#include "index_utils.cuh"
// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
@@ -32,18 +31,17 @@
// float key_k = keys[k];
// int value_k = values[k];
// }
template<typename key_t, typename value_t>
template <typename key_t, typename value_t>
class MinK {
public:
// Constructor.
//
// Arguments:
// keys: Array in which to store keys
// values: Array in which to store values
// K: How many values to keep track of
__device__ MinK(key_t *keys, value_t *vals, int K) :
keys(keys), vals(vals), K(K), _size(0) { }
__device__ MinK(key_t* keys, value_t* vals, int K)
: keys(keys), vals(vals), K(K), _size(0) {}
// Try to add a new key and associated value to the data structure. If the key
// is one of the smallest K seen so far then it will be kept; otherwise it
@@ -55,7 +53,7 @@ class MinK {
// Arguments:
// key: The key to add
// val: The value associated to the key
__device__ __forceinline__ void add(const key_t &key, const value_t &val) {
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
if (_size < K) {
keys[_size] = key;
vals[_size] = val;
@@ -71,8 +69,8 @@ class MinK {
for (int k = 0; k < K; ++k) {
key_t cur_key = keys[k];
if (cur_key > max_key) {
max_key = cur_key;
max_idx = k;
max_key = cur_key;
max_idx = k;
}
}
}
@@ -102,15 +100,14 @@ class MinK {
}
private:
key_t *keys;
value_t *vals;
key_t* keys;
value_t* vals;
int K;
int _size;
key_t max_key;
int max_idx;
};
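
A hypothetical device-side sketch of MinK usage, assuming caller-provided key/value buffers such as thread-local arrays ("k_smallest" is a made-up name):

__device__ void k_smallest(
    const float* dists, int P, float* keys, int* vals, int K) {
  MinK<float, int> mink(keys, vals, K);
  for (int p = 0; p < P; ++p) {
    // Kept only if dists[p] is among the K smallest seen so far.
    mink.add(dists[p], p);
  }
  // keys/vals now hold the K smallest (distance, index) pairs,
  // in no particular order.
}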
// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
@@ -120,13 +117,13 @@ class MinK {
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template<typename key_t, typename value_t, int K>
template <typename key_t, typename value_t, int K>
class RegisterMinK {
public:
__device__ RegisterMinK(key_t *keys, value_t *vals) :
keys(keys), vals(vals), _size(0) {}
__device__ RegisterMinK(key_t* keys, value_t* vals)
: keys(keys), vals(vals), _size(0) {}
__device__ __forceinline__ void add(const key_t &key, const value_t &val) {
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
if (_size < K) {
RegisterIndexUtils<key_t, K>::set(keys, _size, key);
RegisterIndexUtils<value_t, K>::set(vals, _size, val);
@@ -154,8 +151,8 @@ class RegisterMinK {
}
private:
key_t *keys;
value_t *vals;
key_t* keys;
value_t* vals;
int _size;
key_t max_key;
int max_idx;

pytorch3d/ops/knn.py

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d import _C

tests/bm_knn.py

@@ -1,7 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
@@ -30,21 +29,13 @@ def benchmark_knn_cuda_versions() -> None:
continue
if version == 3 and K > 4:
continue
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({'N': N, 'D': D, 'P': P})
nn_kwargs.append({"N": N, "D": D, "P": P})
benchmark(
knn_cuda_with_init,
'KNN_CUDA_VERSIONS',
knn_kwargs,
warmup_iters=1,
)
benchmark(
nn_cuda_with_init,
'NN_CUDA',
nn_kwargs,
warmup_iters=1,
knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1
)
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
def benchmark_knn_cuda_vs_naive() -> None:
@@ -55,21 +46,16 @@ def benchmark_knn_cuda_vs_naive() -> None:
Ks = [1, 2, 4, 8, 16]
knn_kwargs, naive_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
if P <= 4096:
naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
benchmark(
knn_python_cuda_with_init,
'KNN_CUDA_PYTHON',
"KNN_CUDA_PYTHON",
naive_kwargs,
warmup_iters=1,
)
benchmark(
knn_cuda_with_init,
'KNN_CUDA',
knn_kwargs,
warmup_iters=1,
)
benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)
def benchmark_knn_cpu() -> None:
@@ -79,31 +65,18 @@ def benchmark_knn_cpu() -> None:
Ks = [1, 2, 4]
knn_kwargs, nn_kwargs = [], []
for N, P, D, K in product(Ns, Ps, Ds, Ks):
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({'N': N, 'D': D, 'P': P})
nn_kwargs.append({"N": N, "D": D, "P": P})
benchmark(
knn_python_cpu_with_init,
'KNN_CPU_PYTHON',
knn_kwargs,
warmup_iters=1,
)
benchmark(
knn_cpu_with_init,
'KNN_CPU_CPP',
knn_kwargs,
warmup_iters=1,
)
benchmark(
nn_cpu_with_init,
'NN_CPU_CPP',
nn_kwargs,
warmup_iters=1,
knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1
)
benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)
def knn_cuda_with_init(N, D, P, K, v=-1):
device = torch.device('cuda:0')
device = torch.device("cuda:0")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
@@ -116,7 +89,7 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
def knn_cpu_with_init(N, D, P, K):
device = torch.device('cpu')
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
@@ -127,7 +100,7 @@ def knn_cpu_with_init(N, D, P, K):
def knn_python_cuda_with_init(N, D, P, K):
device = torch.device('cuda')
device = torch.device("cuda")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
@@ -140,7 +113,7 @@ def knn_python_cuda_with_init(N, D, P, K):
def knn_python_cpu_with_init(N, D, P, K):
device = torch.device('cpu')
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
@@ -151,7 +124,7 @@ def knn_python_cpu_with_init(N, D, P, K):
def nn_cuda_with_init(N, D, P):
device = torch.device('cuda')
device = torch.device("cuda")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)
torch.cuda.synchronize()
@@ -164,7 +137,7 @@ def nn_cuda_with_init(N, D, P):
def nn_cpu_with_init(N, D, P):
device = torch.device('cpu')
device = torch.device("cpu")
x = torch.randn(N, P, D, device=device)
y = torch.randn(N, P, D, device=device)

tests/test_build.py

@@ -22,6 +22,27 @@ class TestBuild(unittest.TestCase):
for k, v in counter.items():
self.assertEqual(v, 1, f"Too many files with stem {k}.")
def test_deprecated_usage(self):
# Check certain expressions do not occur in the csrc code
test_dir = Path(__file__).resolve().parent
source_dir = test_dir.parent / "pytorch3d" / "csrc"
files = sorted(source_dir.glob("**/*.*"))
self.assertGreater(len(files), 4)
patterns = [".type()", ".data()"]
for file in files:
with open(file) as f:
text = f.read()
for pattern in patterns:
found = pattern in text
msg = (
f"{pattern} found in {file.name}"
+ ", this has been deprecated."
)
self.assertFalse(found, msg)
def test_copyright(self):
test_dir = Path(__file__).resolve().parent
root_dir = test_dir.parent

tests/test_knn.py

@@ -28,7 +28,7 @@ class TestKNN(unittest.TestCase):
def test_knn_vs_python_cpu(self):
""" Test CPU output vs PyTorch implementation """
device = torch.device('cpu')
device = torch.device("cpu")
Ns = [1, 4]
Ds = [2, 3]
P1s = [1, 10, 101]
@@ -45,7 +45,7 @@ class TestKNN(unittest.TestCase):
def test_knn_vs_python_cuda(self):
""" Test CUDA output vs PyTorch implementation """
device = torch.device('cuda')
device = torch.device("cuda")
Ns = [1, 4]
Ds = [2, 3, 8]
P1s = [1, 8, 64, 128, 1001]