mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Linter, deprecated type()
Summary: Run linter after recent changes. Fix long comment in knn.h which clang-format has reflowed badly. Add crude test that code doesn't call deprecated `.type()` or `.data()`.
Reviewed By: nikhilaravi
Differential Revision: D20692935
fbshipit-source-id: 28ce0308adae79a870cb41a810b7cf8744f41ab8
This commit is contained in:
parent
3061c5b663
commit
37c5c8e0b6
@ -1,9 +1,9 @@
|
|||||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
//
|
//
|
||||||
// This file provides utilities for dispatching to specialized versions of functions.
|
// This file provides utilities for dispatching to specialized versions of
|
||||||
// This is especially useful for CUDA kernels, since specializing them to particular
|
// functions. This is especially useful for CUDA kernels, since specializing
|
||||||
// input sizes can often allow the compiler to unroll loops and place arrays into
|
// them to particular input sizes can often allow the compiler to unroll loops
|
||||||
// registers, which can give huge performance speedups.
|
// and place arrays into registers, which can give huge performance speedups.
|
||||||
//
|
//
|
||||||
// As an example, suppose we have the following function which is specialized
|
// As an example, suppose we have the following function which is specialized
|
||||||
// based on a compile-time int64_t value:
|
// based on a compile-time int64_t value:
|
||||||
@ -92,14 +92,13 @@ namespace {
|
|||||||
// In order to dispatch, we will take an additional template argument curN,
|
// In order to dispatch, we will take an additional template argument curN,
|
||||||
// and increment it via template recursion until it is equal to the run-time
|
// and increment it via template recursion until it is equal to the run-time
|
||||||
// argument N.
|
// argument N.
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t> class Kernel,
|
template <typename, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
int64_t curN,
|
int64_t curN,
|
||||||
typename... Args
|
typename... Args>
|
||||||
>
|
|
||||||
struct DispatchKernelHelper1D {
|
struct DispatchKernelHelper1D {
|
||||||
static void run(const int64_t N, Args... args) {
|
static void run(const int64_t N, Args... args) {
|
||||||
if (curN == N) {
|
if (curN == N) {
|
||||||
@ -108,22 +107,21 @@ struct DispatchKernelHelper1D {
|
|||||||
Kernel<T, curN>::run(args...);
|
Kernel<T, curN>::run(args...);
|
||||||
} else if (curN < N) {
|
} else if (curN < N) {
|
||||||
// Increment curN via template recursion
|
// Increment curN via template recursion
|
||||||
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
|
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
|
||||||
|
N, args...);
|
||||||
}
|
}
|
||||||
// We shouldn't get here -- throw an error?
|
// We shouldn't get here -- throw an error?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// 1D dispatch: Specialization when curN == maxN
|
// 1D dispatch: Specialization when curN == maxN
|
||||||
// We need this base case to avoid infinite template recursion.
|
// We need this base case to avoid infinite template recursion.
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t> class Kernel,
|
template <typename, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
typename... Args
|
typename... Args>
|
||||||
>
|
|
||||||
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
|
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
|
||||||
static void run(const int64_t N, Args... args) {
|
static void run(const int64_t N, Args... args) {
|
||||||
if (N == maxN) {
|
if (N == maxN) {
|
||||||
@ -133,19 +131,21 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// 2D dispatch, general case.
|
// 2D dispatch, general case.
|
||||||
// This is similar to the 1D case: we take additional template args curN and
|
// This is similar to the 1D case: we take additional template args curN and
|
||||||
// curM, and increment them via template recursion until they are equal to
|
// curM, and increment them via template recursion until they are equal to
|
||||||
// the run-time values of N and M, at which point we dispatch to the run
|
// the run-time values of N and M, at which point we dispatch to the run
|
||||||
// method of the kernel.
|
// method of the kernel.
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN, int64_t maxN, int64_t curN,
|
int64_t minN,
|
||||||
int64_t minM, int64_t maxM, int64_t curM,
|
int64_t maxN,
|
||||||
typename... Args
|
int64_t curN,
|
||||||
>
|
int64_t minM,
|
||||||
|
int64_t maxM,
|
||||||
|
int64_t curM,
|
||||||
|
typename... Args>
|
||||||
struct DispatchKernelHelper2D {
|
struct DispatchKernelHelper2D {
|
||||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||||
if (curN == N && curM == M) {
|
if (curN == N && curM == M) {
|
||||||
@ -154,67 +154,141 @@ struct DispatchKernelHelper2D {
|
|||||||
// Increment both curN and curM. This isn't strictly necessary; we could
|
// Increment both curN and curM. This isn't strictly necessary; we could
|
||||||
// just increment one or the other at each step. But this helps to cut
|
// just increment one or the other at each step. But this helps to cut
|
||||||
// down on the number of recursive calls we make.
|
// down on the number of recursive calls we make.
|
||||||
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
|
DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
curN + 1,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
curM + 1,
|
||||||
|
Args...>::run(N, M, args...);
|
||||||
} else if (curN < N) {
|
} else if (curN < N) {
|
||||||
// Increment curN only
|
// Increment curN only
|
||||||
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
|
DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
curN + 1,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
curM,
|
||||||
|
Args...>::run(N, M, args...);
|
||||||
} else if (curM < M) {
|
} else if (curM < M) {
|
||||||
// Increment curM only
|
// Increment curM only
|
||||||
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
|
DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
curN,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
curM + 1,
|
||||||
|
Args...>::run(N, M, args...);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// 2D dispatch, specialization for curN == maxN
|
// 2D dispatch, specialization for curN == maxN
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN, int64_t maxN,
|
int64_t minN,
|
||||||
int64_t minM, int64_t maxM, int64_t curM,
|
int64_t maxN,
|
||||||
typename... Args
|
int64_t minM,
|
||||||
>
|
int64_t maxM,
|
||||||
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
|
int64_t curM,
|
||||||
|
typename... Args>
|
||||||
|
struct DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
maxN,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
curM,
|
||||||
|
Args...> {
|
||||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||||
if (maxN == N && curM == M) {
|
if (maxN == N && curM == M) {
|
||||||
Kernel<T, maxN, curM>::run(args...);
|
Kernel<T, maxN, curM>::run(args...);
|
||||||
} else if (curM < maxM) {
|
} else if (curM < maxM) {
|
||||||
DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
|
DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
maxN,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
curM + 1,
|
||||||
|
Args...>::run(N, M, args...);
|
||||||
}
|
}
|
||||||
// We should not get here -- throw an error?
|
// We should not get here -- throw an error?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// 2D dispatch, specialization for curM == maxM
|
// 2D dispatch, specialization for curM == maxM
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN, int64_t maxN, int64_t curN,
|
int64_t minN,
|
||||||
int64_t minM, int64_t maxM,
|
int64_t maxN,
|
||||||
typename... Args
|
int64_t curN,
|
||||||
>
|
int64_t minM,
|
||||||
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
|
int64_t maxM,
|
||||||
|
typename... Args>
|
||||||
|
struct DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
curN,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
maxM,
|
||||||
|
Args...> {
|
||||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||||
if (curN == N && maxM == M) {
|
if (curN == N && maxM == M) {
|
||||||
Kernel<T, curN, maxM>::run(args...);
|
Kernel<T, curN, maxM>::run(args...);
|
||||||
} else if (curN < maxN) {
|
} else if (curN < maxN) {
|
||||||
DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
|
DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
curN + 1,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
maxM,
|
||||||
|
Args...>::run(N, M, args...);
|
||||||
}
|
}
|
||||||
// We should not get here -- throw an error?
|
// We should not get here -- throw an error?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// 2D dispatch, specialization for curN == maxN, curM == maxM
|
// 2D dispatch, specialization for curN == maxN, curM == maxM
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN, int64_t maxN,
|
int64_t minN,
|
||||||
int64_t minM, int64_t maxM,
|
int64_t maxN,
|
||||||
typename... Args
|
int64_t minM,
|
||||||
>
|
int64_t maxM,
|
||||||
struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
|
typename... Args>
|
||||||
|
struct DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
maxN,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
maxM,
|
||||||
|
Args...> {
|
||||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||||
if (maxN == N && maxM == M) {
|
if (maxN == N && maxM == M) {
|
||||||
Kernel<T, maxN, maxM>::run(args...);
|
Kernel<T, maxN, maxM>::run(args...);
|
||||||
@ -225,37 +299,45 @@ struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Arg
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
// This is the function we expect users to call to dispatch to 1D functions
|
// This is the function we expect users to call to dispatch to 1D functions
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t> class Kernel,
|
template <typename, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
typename... Args
|
typename... Args>
|
||||||
>
|
|
||||||
void DispatchKernel1D(const int64_t N, Args... args) {
|
void DispatchKernel1D(const int64_t N, Args... args) {
|
||||||
if (minN <= N && N <= maxN) {
|
if (minN <= N && N <= maxN) {
|
||||||
// Kick off the template recursion by calling the Helper with curN = minN
|
// Kick off the template recursion by calling the Helper with curN = minN
|
||||||
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
|
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
|
||||||
|
N, args...);
|
||||||
}
|
}
|
||||||
// Maybe throw an error if we tried to dispatch outside the allowed range?
|
// Maybe throw an error if we tried to dispatch outside the allowed range?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// This is the function we expect users to call to dispatch to 2D functions
|
// This is the function we expect users to call to dispatch to 2D functions
|
||||||
template<
|
template <
|
||||||
template<typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t> class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN, int64_t maxN,
|
int64_t minN,
|
||||||
int64_t minM, int64_t maxM,
|
int64_t maxN,
|
||||||
typename... Args
|
int64_t minM,
|
||||||
>
|
int64_t maxM,
|
||||||
|
typename... Args>
|
||||||
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
|
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
|
||||||
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
|
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
|
||||||
// Kick off the template recursion by calling the Helper with curN = minN
|
// Kick off the template recursion by calling the Helper with curN = minN
|
||||||
// and curM = minM
|
// and curM = minM
|
||||||
DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
|
DispatchKernelHelper2D<
|
||||||
|
Kernel,
|
||||||
|
T,
|
||||||
|
minN,
|
||||||
|
maxN,
|
||||||
|
minN,
|
||||||
|
minM,
|
||||||
|
maxM,
|
||||||
|
minM,
|
||||||
|
Args...>::run(N, M, args...);
|
||||||
}
|
}
|
||||||
// Maybe throw an error if we tried to dispatch outside the specified range?
|
// Maybe throw an error if we tried to dispatch outside the specified range?
|
||||||
}
|
}
|
||||||
|
@ -39,82 +39,180 @@
|
|||||||
// approach for this might lead to extra function calls at runtime if the
|
// approach for this might lead to extra function calls at runtime if the
|
||||||
// compiler fails to optimize them away, which could be very slow on device.
|
// compiler fails to optimize them away, which could be very slow on device.
|
||||||
// However I didn't actually benchmark or test this.
|
// However I didn't actually benchmark or test this.
|
||||||
template<typename T, int N>
|
template <typename T, int N>
|
||||||
struct RegisterIndexUtils {
|
struct RegisterIndexUtils {
|
||||||
__device__ __forceinline__ static T get(const T arr[N], int idx) {
|
__device__ __forceinline__ static T get(const T arr[N], int idx) {
|
||||||
if (idx < 0 || idx >= N) return T();
|
if (idx < 0 || idx >= N)
|
||||||
|
return T();
|
||||||
switch (idx) {
|
switch (idx) {
|
||||||
case 0: return arr[0];
|
case 0:
|
||||||
case 1: return arr[1];
|
return arr[0];
|
||||||
case 2: return arr[2];
|
case 1:
|
||||||
case 3: return arr[3];
|
return arr[1];
|
||||||
case 4: return arr[4];
|
case 2:
|
||||||
case 5: return arr[5];
|
return arr[2];
|
||||||
case 6: return arr[6];
|
case 3:
|
||||||
case 7: return arr[7];
|
return arr[3];
|
||||||
case 8: return arr[8];
|
case 4:
|
||||||
case 9: return arr[9];
|
return arr[4];
|
||||||
case 10: return arr[10];
|
case 5:
|
||||||
case 11: return arr[11];
|
return arr[5];
|
||||||
case 12: return arr[12];
|
case 6:
|
||||||
case 13: return arr[13];
|
return arr[6];
|
||||||
case 14: return arr[14];
|
case 7:
|
||||||
case 15: return arr[15];
|
return arr[7];
|
||||||
case 16: return arr[16];
|
case 8:
|
||||||
case 17: return arr[17];
|
return arr[8];
|
||||||
case 18: return arr[18];
|
case 9:
|
||||||
case 19: return arr[19];
|
return arr[9];
|
||||||
case 20: return arr[20];
|
case 10:
|
||||||
case 21: return arr[21];
|
return arr[10];
|
||||||
case 22: return arr[22];
|
case 11:
|
||||||
case 23: return arr[23];
|
return arr[11];
|
||||||
case 24: return arr[24];
|
case 12:
|
||||||
case 25: return arr[25];
|
return arr[12];
|
||||||
case 26: return arr[26];
|
case 13:
|
||||||
case 27: return arr[27];
|
return arr[13];
|
||||||
case 28: return arr[28];
|
case 14:
|
||||||
case 29: return arr[29];
|
return arr[14];
|
||||||
case 30: return arr[30];
|
case 15:
|
||||||
case 31: return arr[31];
|
return arr[15];
|
||||||
|
case 16:
|
||||||
|
return arr[16];
|
||||||
|
case 17:
|
||||||
|
return arr[17];
|
||||||
|
case 18:
|
||||||
|
return arr[18];
|
||||||
|
case 19:
|
||||||
|
return arr[19];
|
||||||
|
case 20:
|
||||||
|
return arr[20];
|
||||||
|
case 21:
|
||||||
|
return arr[21];
|
||||||
|
case 22:
|
||||||
|
return arr[22];
|
||||||
|
case 23:
|
||||||
|
return arr[23];
|
||||||
|
case 24:
|
||||||
|
return arr[24];
|
||||||
|
case 25:
|
||||||
|
return arr[25];
|
||||||
|
case 26:
|
||||||
|
return arr[26];
|
||||||
|
case 27:
|
||||||
|
return arr[27];
|
||||||
|
case 28:
|
||||||
|
return arr[28];
|
||||||
|
case 29:
|
||||||
|
return arr[29];
|
||||||
|
case 30:
|
||||||
|
return arr[30];
|
||||||
|
case 31:
|
||||||
|
return arr[31];
|
||||||
};
|
};
|
||||||
return T();
|
return T();
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
|
__device__ __forceinline__ static void set(T arr[N], int idx, T val) {
|
||||||
if (idx < 0 || idx >= N) return;
|
if (idx < 0 || idx >= N)
|
||||||
|
return;
|
||||||
switch (idx) {
|
switch (idx) {
|
||||||
case 0: arr[0] = val; break;
|
case 0:
|
||||||
case 1: arr[1] = val; break;
|
arr[0] = val;
|
||||||
case 2: arr[2] = val; break;
|
break;
|
||||||
case 3: arr[3] = val; break;
|
case 1:
|
||||||
case 4: arr[4] = val; break;
|
arr[1] = val;
|
||||||
case 5: arr[5] = val; break;
|
break;
|
||||||
case 6: arr[6] = val; break;
|
case 2:
|
||||||
case 7: arr[7] = val; break;
|
arr[2] = val;
|
||||||
case 8: arr[8] = val; break;
|
break;
|
||||||
case 9: arr[9] = val; break;
|
case 3:
|
||||||
case 10: arr[10] = val; break;
|
arr[3] = val;
|
||||||
case 11: arr[11] = val; break;
|
break;
|
||||||
case 12: arr[12] = val; break;
|
case 4:
|
||||||
case 13: arr[13] = val; break;
|
arr[4] = val;
|
||||||
case 14: arr[14] = val; break;
|
break;
|
||||||
case 15: arr[15] = val; break;
|
case 5:
|
||||||
case 16: arr[16] = val; break;
|
arr[5] = val;
|
||||||
case 17: arr[17] = val; break;
|
break;
|
||||||
case 18: arr[18] = val; break;
|
case 6:
|
||||||
case 19: arr[19] = val; break;
|
arr[6] = val;
|
||||||
case 20: arr[20] = val; break;
|
break;
|
||||||
case 21: arr[21] = val; break;
|
case 7:
|
||||||
case 22: arr[22] = val; break;
|
arr[7] = val;
|
||||||
case 23: arr[23] = val; break;
|
break;
|
||||||
case 24: arr[24] = val; break;
|
case 8:
|
||||||
case 25: arr[25] = val; break;
|
arr[8] = val;
|
||||||
case 26: arr[26] = val; break;
|
break;
|
||||||
case 27: arr[27] = val; break;
|
case 9:
|
||||||
case 28: arr[28] = val; break;
|
arr[9] = val;
|
||||||
case 29: arr[29] = val; break;
|
break;
|
||||||
case 30: arr[30] = val; break;
|
case 10:
|
||||||
case 31: arr[31] = val; break;
|
arr[10] = val;
|
||||||
|
break;
|
||||||
|
case 11:
|
||||||
|
arr[11] = val;
|
||||||
|
break;
|
||||||
|
case 12:
|
||||||
|
arr[12] = val;
|
||||||
|
break;
|
||||||
|
case 13:
|
||||||
|
arr[13] = val;
|
||||||
|
break;
|
||||||
|
case 14:
|
||||||
|
arr[14] = val;
|
||||||
|
break;
|
||||||
|
case 15:
|
||||||
|
arr[15] = val;
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
arr[16] = val;
|
||||||
|
break;
|
||||||
|
case 17:
|
||||||
|
arr[17] = val;
|
||||||
|
break;
|
||||||
|
case 18:
|
||||||
|
arr[18] = val;
|
||||||
|
break;
|
||||||
|
case 19:
|
||||||
|
arr[19] = val;
|
||||||
|
break;
|
||||||
|
case 20:
|
||||||
|
arr[20] = val;
|
||||||
|
break;
|
||||||
|
case 21:
|
||||||
|
arr[21] = val;
|
||||||
|
break;
|
||||||
|
case 22:
|
||||||
|
arr[22] = val;
|
||||||
|
break;
|
||||||
|
case 23:
|
||||||
|
arr[23] = val;
|
||||||
|
break;
|
||||||
|
case 24:
|
||||||
|
arr[24] = val;
|
||||||
|
break;
|
||||||
|
case 25:
|
||||||
|
arr[25] = val;
|
||||||
|
break;
|
||||||
|
case 26:
|
||||||
|
arr[26] = val;
|
||||||
|
break;
|
||||||
|
case 27:
|
||||||
|
arr[27] = val;
|
||||||
|
break;
|
||||||
|
case 28:
|
||||||
|
arr[28] = val;
|
||||||
|
break;
|
||||||
|
case 29:
|
||||||
|
arr[29] = val;
|
||||||
|
break;
|
||||||
|
case 30:
|
||||||
|
arr[30] = val;
|
||||||
|
break;
|
||||||
|
case 31:
|
||||||
|
arr[31] = val;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -289,7 +289,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
const size_t threads = 256;
|
const size_t threads = 256;
|
||||||
const size_t blocks = 256;
|
const size_t blocks = 256;
|
||||||
if (version == 0) {
|
if (version == 0) {
|
||||||
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||||
KNearestNeighborKernelV0<scalar_t>
|
KNearestNeighborKernelV0<scalar_t>
|
||||||
<<<blocks, threads>>>(
|
<<<blocks, threads>>>(
|
||||||
p1.data_ptr<scalar_t>(),
|
p1.data_ptr<scalar_t>(),
|
||||||
@ -303,7 +303,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
K);
|
K);
|
||||||
}));
|
}));
|
||||||
} else if (version == 1) {
|
} else if (version == 1) {
|
||||||
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||||
DispatchKernel1D<
|
DispatchKernel1D<
|
||||||
KNearestNeighborV1Functor,
|
KNearestNeighborV1Functor,
|
||||||
scalar_t,
|
scalar_t,
|
||||||
@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
K);
|
K);
|
||||||
}));
|
}));
|
||||||
} else if (version == 2) {
|
} else if (version == 2) {
|
||||||
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||||
DispatchKernel2D<
|
DispatchKernel2D<
|
||||||
KNearestNeighborKernelV2Functor,
|
KNearestNeighborKernelV2Functor,
|
||||||
scalar_t,
|
scalar_t,
|
||||||
@ -343,7 +343,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
P2);
|
P2);
|
||||||
}));
|
}));
|
||||||
} else if (version == 3) {
|
} else if (version == 3) {
|
||||||
AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||||
DispatchKernel2D<
|
DispatchKernel2D<
|
||||||
KNearestNeighborKernelV3Functor,
|
KNearestNeighborKernelV3Functor,
|
||||||
scalar_t,
|
scalar_t,
|
||||||
|
@ -13,11 +13,11 @@
|
|||||||
// containing P1 points of dimension D.
|
// containing P1 points of dimension D.
|
||||||
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||||
// containing P2 points of dimension D.
|
// containing P2 points of dimension D.
|
||||||
// K: int giving the number of nearest points to return.
|
// K: int giving the number of nearest points to return.
|
||||||
// sorted: bool telling whether to sort the K returned points by their
|
// sorted: bool telling whether to sort the K returned points by their
|
||||||
// distance version: Integer telling which implementation to use.
|
// distance.
|
||||||
// TODO(jcjohns): Document this more, or maybe remove it before
|
// version: Integer telling which implementation to use.
|
||||||
// landing.
|
// TODO(jcjohns): Document this more, or maybe remove it before landing.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||||
@ -41,7 +41,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
|||||||
const at::Tensor& p2,
|
const at::Tensor& p2,
|
||||||
int K,
|
int K,
|
||||||
int version) {
|
int version) {
|
||||||
if (p1.type().is_cuda() || p2.type().is_cuda()) {
|
if (p1.is_cuda() || p2.is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
CHECK_CONTIGUOUS_CUDA(p1);
|
CHECK_CONTIGUOUS_CUDA(p1);
|
||||||
CHECK_CONTIGUOUS_CUDA(p2);
|
CHECK_CONTIGUOUS_CUDA(p2);
|
||||||
|
@ -4,49 +4,48 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
|
KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
|
||||||
const int N = p1.size(0);
|
const int N = p1.size(0);
|
||||||
const int P1 = p1.size(1);
|
const int P1 = p1.size(1);
|
||||||
const int D = p1.size(2);
|
const int D = p1.size(2);
|
||||||
const int P2 = p2.size(1);
|
const int P2 = p2.size(1);
|
||||||
|
|
||||||
auto long_opts = p1.options().dtype(torch::kInt64);
|
auto long_opts = p1.options().dtype(torch::kInt64);
|
||||||
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
|
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
|
||||||
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
||||||
|
|
||||||
auto p1_a = p1.accessor<float, 3>();
|
auto p1_a = p1.accessor<float, 3>();
|
||||||
auto p2_a = p2.accessor<float, 3>();
|
auto p2_a = p2.accessor<float, 3>();
|
||||||
auto idxs_a = idxs.accessor<int64_t, 3>();
|
auto idxs_a = idxs.accessor<int64_t, 3>();
|
||||||
auto dists_a = dists.accessor<float, 3>();
|
auto dists_a = dists.accessor<float, 3>();
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
for (int i1 = 0; i1 < P1; ++i1) {
|
for (int i1 = 0; i1 < P1; ++i1) {
|
||||||
// Use a priority queue to store (distance, index) tuples.
|
// Use a priority queue to store (distance, index) tuples.
|
||||||
std::priority_queue<std::tuple<float, int>> q;
|
std::priority_queue<std::tuple<float, int>> q;
|
||||||
for (int i2 = 0; i2 < P2; ++i2) {
|
for (int i2 = 0; i2 < P2; ++i2) {
|
||||||
float dist = 0;
|
float dist = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
||||||
dist += diff * diff;
|
dist += diff * diff;
|
||||||
}
|
|
||||||
int size = static_cast<int>(q.size());
|
|
||||||
if (size < K || dist < std::get<0>(q.top())) {
|
|
||||||
q.emplace(dist, i2);
|
|
||||||
if (size >= K) {
|
|
||||||
q.pop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
while (!q.empty()) {
|
|
||||||
auto t = q.top();
|
|
||||||
q.pop();
|
|
||||||
const int k = q.size();
|
|
||||||
dists_a[n][i1][k] = std::get<0>(t);
|
|
||||||
idxs_a[n][i1][k] = std::get<1>(t);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
int size = static_cast<int>(q.size());
|
||||||
|
if (size < K || dist < std::get<0>(q.top())) {
|
||||||
|
q.emplace(dist, i2);
|
||||||
|
if (size >= K) {
|
||||||
|
q.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (!q.empty()) {
|
||||||
|
auto t = q.top();
|
||||||
|
q.pop();
|
||||||
|
const int k = q.size();
|
||||||
|
dists_a[n][i1][k] = std::get<0>(t);
|
||||||
|
idxs_a[n][i1][k] = std::get<1>(t);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::make_tuple(idxs, dists);
|
}
|
||||||
|
return std::make_tuple(idxs, dists);
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
#include "index_utils.cuh"
|
#include "index_utils.cuh"
|
||||||
|
|
||||||
|
|
||||||
// A data structure to keep track of the smallest K keys seen so far as well
|
// A data structure to keep track of the smallest K keys seen so far as well
|
||||||
// as their associated values, intended to be used in device code.
|
// as their associated values, intended to be used in device code.
|
||||||
// This data structure doesn't allocate any memory; keys and values are stored
|
// This data structure doesn't allocate any memory; keys and values are stored
|
||||||
@ -32,18 +31,17 @@
|
|||||||
// float key_k = keys[k];
|
// float key_k = keys[k];
|
||||||
// int value_k = values[k];
|
// int value_k = values[k];
|
||||||
// }
|
// }
|
||||||
template<typename key_t, typename value_t>
|
template <typename key_t, typename value_t>
|
||||||
class MinK {
|
class MinK {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// Constructor.
|
// Constructor.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
// keys: Array in which to store keys
|
// keys: Array in which to store keys
|
||||||
// values: Array in which to store values
|
// values: Array in which to store values
|
||||||
// K: How many values to keep track of
|
// K: How many values to keep track of
|
||||||
__device__ MinK(key_t *keys, value_t *vals, int K) :
|
__device__ MinK(key_t* keys, value_t* vals, int K)
|
||||||
keys(keys), vals(vals), K(K), _size(0) { }
|
: keys(keys), vals(vals), K(K), _size(0) {}
|
||||||
|
|
||||||
// Try to add a new key and associated value to the data structure. If the key
|
// Try to add a new key and associated value to the data structure. If the key
|
||||||
// is one of the smallest K seen so far then it will be kept; otherwise it
|
// is one of the smallest K seen so far then it will be kept; otherwise it
|
||||||
@ -55,7 +53,7 @@ class MinK {
|
|||||||
// Arguments:
|
// Arguments:
|
||||||
// key: The key to add
|
// key: The key to add
|
||||||
// val: The value associated with the key
|
// val: The value associated with the key
|
||||||
__device__ __forceinline__ void add(const key_t &key, const value_t &val) {
|
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
|
||||||
if (_size < K) {
|
if (_size < K) {
|
||||||
keys[_size] = key;
|
keys[_size] = key;
|
||||||
vals[_size] = val;
|
vals[_size] = val;
|
||||||
@ -71,8 +69,8 @@ class MinK {
|
|||||||
for (int k = 0; k < K; ++k) {
|
for (int k = 0; k < K; ++k) {
|
||||||
key_t cur_key = keys[k];
|
key_t cur_key = keys[k];
|
||||||
if (cur_key > max_key) {
|
if (cur_key > max_key) {
|
||||||
max_key = cur_key;
|
max_key = cur_key;
|
||||||
max_idx = k;
|
max_idx = k;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,15 +100,14 @@ class MinK {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
key_t *keys;
|
key_t* keys;
|
||||||
value_t *vals;
|
value_t* vals;
|
||||||
int K;
|
int K;
|
||||||
int _size;
|
int _size;
|
||||||
key_t max_key;
|
key_t max_key;
|
||||||
int max_idx;
|
int max_idx;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// This is a version of MinK that only touches the arrays using static indexing
|
// This is a version of MinK that only touches the arrays using static indexing
|
||||||
// via RegisterIndexUtils. If the keys and values are stored in thread-local
|
// via RegisterIndexUtils. If the keys and values are stored in thread-local
|
||||||
// arrays, then this may allow the compiler to place them in registers for
|
// arrays, then this may allow the compiler to place them in registers for
|
||||||
@ -120,13 +117,13 @@ class MinK {
|
|||||||
// We found that sorting via RegisterIndexUtils gave very poor performance,
|
// We found that sorting via RegisterIndexUtils gave very poor performance,
|
||||||
// and suspect it may have prevented the compiler from placing the arrays
|
// and suspect it may have prevented the compiler from placing the arrays
|
||||||
// into registers.
|
// into registers.
|
||||||
template<typename key_t, typename value_t, int K>
|
template <typename key_t, typename value_t, int K>
|
||||||
class RegisterMinK {
|
class RegisterMinK {
|
||||||
public:
|
public:
|
||||||
__device__ RegisterMinK(key_t *keys, value_t *vals) :
|
__device__ RegisterMinK(key_t* keys, value_t* vals)
|
||||||
keys(keys), vals(vals), _size(0) {}
|
: keys(keys), vals(vals), _size(0) {}
|
||||||
|
|
||||||
__device__ __forceinline__ void add(const key_t &key, const value_t &val) {
|
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
|
||||||
if (_size < K) {
|
if (_size < K) {
|
||||||
RegisterIndexUtils<key_t, K>::set(keys, _size, key);
|
RegisterIndexUtils<key_t, K>::set(keys, _size, key);
|
||||||
RegisterIndexUtils<value_t, K>::set(vals, _size, val);
|
RegisterIndexUtils<value_t, K>::set(vals, _size, val);
|
||||||
@ -154,8 +151,8 @@ class RegisterMinK {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
key_t *keys;
|
key_t* keys;
|
||||||
value_t *vals;
|
value_t* vals;
|
||||||
int _size;
|
int _size;
|
||||||
key_t max_key;
|
key_t max_key;
|
||||||
int max_idx;
|
int max_idx;
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fvcore.common.benchmark import benchmark
|
from fvcore.common.benchmark import benchmark
|
||||||
|
|
||||||
@ -30,21 +29,13 @@ def benchmark_knn_cuda_versions() -> None:
|
|||||||
continue
|
continue
|
||||||
if version == 3 and K > 4:
|
if version == 3 and K > 4:
|
||||||
continue
|
continue
|
||||||
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
|
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
|
||||||
for N, P, D in product(Ns, Ps, Ds):
|
for N, P, D in product(Ns, Ps, Ds):
|
||||||
nn_kwargs.append({'N': N, 'D': D, 'P': P})
|
nn_kwargs.append({"N": N, "D": D, "P": P})
|
||||||
benchmark(
|
benchmark(
|
||||||
knn_cuda_with_init,
|
knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1
|
||||||
'KNN_CUDA_VERSIONS',
|
|
||||||
knn_kwargs,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
||||||
benchmark(
|
|
||||||
nn_cuda_with_init,
|
|
||||||
'NN_CUDA',
|
|
||||||
nn_kwargs,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
)
|
||||||
|
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
|
||||||
|
|
||||||
|
|
||||||
def benchmark_knn_cuda_vs_naive() -> None:
|
def benchmark_knn_cuda_vs_naive() -> None:
|
||||||
@ -55,21 +46,16 @@ def benchmark_knn_cuda_vs_naive() -> None:
|
|||||||
Ks = [1, 2, 4, 8, 16]
|
Ks = [1, 2, 4, 8, 16]
|
||||||
knn_kwargs, naive_kwargs = [], []
|
knn_kwargs, naive_kwargs = [], []
|
||||||
for N, P, D, K in product(Ns, Ps, Ds, Ks):
|
for N, P, D, K in product(Ns, Ps, Ds, Ks):
|
||||||
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
|
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
|
||||||
if P <= 4096:
|
if P <= 4096:
|
||||||
naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
|
naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
|
||||||
benchmark(
|
benchmark(
|
||||||
knn_python_cuda_with_init,
|
knn_python_cuda_with_init,
|
||||||
'KNN_CUDA_PYTHON',
|
"KNN_CUDA_PYTHON",
|
||||||
naive_kwargs,
|
naive_kwargs,
|
||||||
warmup_iters=1,
|
warmup_iters=1,
|
||||||
)
|
)
|
||||||
benchmark(
|
benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)
|
||||||
knn_cuda_with_init,
|
|
||||||
'KNN_CUDA',
|
|
||||||
knn_kwargs,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_knn_cpu() -> None:
|
def benchmark_knn_cpu() -> None:
|
||||||
@ -79,31 +65,18 @@ def benchmark_knn_cpu() -> None:
|
|||||||
Ks = [1, 2, 4]
|
Ks = [1, 2, 4]
|
||||||
knn_kwargs, nn_kwargs = [], []
|
knn_kwargs, nn_kwargs = [], []
|
||||||
for N, P, D, K in product(Ns, Ps, Ds, Ks):
|
for N, P, D, K in product(Ns, Ps, Ds, Ks):
|
||||||
knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
|
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
|
||||||
for N, P, D in product(Ns, Ps, Ds):
|
for N, P, D in product(Ns, Ps, Ds):
|
||||||
nn_kwargs.append({'N': N, 'D': D, 'P': P})
|
nn_kwargs.append({"N": N, "D": D, "P": P})
|
||||||
benchmark(
|
benchmark(
|
||||||
knn_python_cpu_with_init,
|
knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1
|
||||||
'KNN_CPU_PYTHON',
|
|
||||||
knn_kwargs,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
||||||
benchmark(
|
|
||||||
knn_cpu_with_init,
|
|
||||||
'KNN_CPU_CPP',
|
|
||||||
knn_kwargs,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
|
||||||
benchmark(
|
|
||||||
nn_cpu_with_init,
|
|
||||||
'NN_CPU_CPP',
|
|
||||||
nn_kwargs,
|
|
||||||
warmup_iters=1,
|
|
||||||
)
|
)
|
||||||
|
benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
|
||||||
|
benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)
|
||||||
|
|
||||||
|
|
||||||
def knn_cuda_with_init(N, D, P, K, v=-1):
|
def knn_cuda_with_init(N, D, P, K, v=-1):
|
||||||
device = torch.device('cuda:0')
|
device = torch.device("cuda:0")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -116,7 +89,7 @@ def knn_cuda_with_init(N, D, P, K, v=-1):
|
|||||||
|
|
||||||
|
|
||||||
def knn_cpu_with_init(N, D, P, K):
|
def knn_cpu_with_init(N, D, P, K):
|
||||||
device = torch.device('cpu')
|
device = torch.device("cpu")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
|
||||||
@ -127,7 +100,7 @@ def knn_cpu_with_init(N, D, P, K):
|
|||||||
|
|
||||||
|
|
||||||
def knn_python_cuda_with_init(N, D, P, K):
|
def knn_python_cuda_with_init(N, D, P, K):
|
||||||
device = torch.device('cuda')
|
device = torch.device("cuda")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -140,7 +113,7 @@ def knn_python_cuda_with_init(N, D, P, K):
|
|||||||
|
|
||||||
|
|
||||||
def knn_python_cpu_with_init(N, D, P, K):
|
def knn_python_cpu_with_init(N, D, P, K):
|
||||||
device = torch.device('cpu')
|
device = torch.device("cpu")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
|
||||||
@ -151,7 +124,7 @@ def knn_python_cpu_with_init(N, D, P, K):
|
|||||||
|
|
||||||
|
|
||||||
def nn_cuda_with_init(N, D, P):
|
def nn_cuda_with_init(N, D, P):
|
||||||
device = torch.device('cuda')
|
device = torch.device("cuda")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -164,7 +137,7 @@ def nn_cuda_with_init(N, D, P):
|
|||||||
|
|
||||||
|
|
||||||
def nn_cpu_with_init(N, D, P):
|
def nn_cpu_with_init(N, D, P):
|
||||||
device = torch.device('cpu')
|
device = torch.device("cpu")
|
||||||
x = torch.randn(N, P, D, device=device)
|
x = torch.randn(N, P, D, device=device)
|
||||||
y = torch.randn(N, P, D, device=device)
|
y = torch.randn(N, P, D, device=device)
|
||||||
|
|
||||||
|
@ -22,6 +22,27 @@ class TestBuild(unittest.TestCase):
|
|||||||
for k, v in counter.items():
|
for k, v in counter.items():
|
||||||
self.assertEqual(v, 1, f"Too many files with stem {k}.")
|
self.assertEqual(v, 1, f"Too many files with stem {k}.")
|
||||||
|
|
||||||
|
def test_deprecated_usage(self):
|
||||||
|
# Check certain expressions do not occur in the csrc code
|
||||||
|
test_dir = Path(__file__).resolve().parent
|
||||||
|
source_dir = test_dir.parent / "pytorch3d" / "csrc"
|
||||||
|
|
||||||
|
files = sorted(source_dir.glob("**/*.*"))
|
||||||
|
self.assertGreater(len(files), 4)
|
||||||
|
|
||||||
|
patterns = [".type()", ".data()"]
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
with open(file) as f:
|
||||||
|
text = f.read()
|
||||||
|
for pattern in patterns:
|
||||||
|
found = pattern in text
|
||||||
|
msg = (
|
||||||
|
f"{pattern} found in {file.name}"
|
||||||
|
+ ", this has been deprecated."
|
||||||
|
)
|
||||||
|
self.assertFalse(found, msg)
|
||||||
|
|
||||||
def test_copyright(self):
|
def test_copyright(self):
|
||||||
test_dir = Path(__file__).resolve().parent
|
test_dir = Path(__file__).resolve().parent
|
||||||
root_dir = test_dir.parent
|
root_dir = test_dir.parent
|
||||||
|
@ -28,7 +28,7 @@ class TestKNN(unittest.TestCase):
|
|||||||
|
|
||||||
def test_knn_vs_python_cpu(self):
|
def test_knn_vs_python_cpu(self):
|
||||||
""" Test CPU output vs PyTorch implementation """
|
""" Test CPU output vs PyTorch implementation """
|
||||||
device = torch.device('cpu')
|
device = torch.device("cpu")
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
Ds = [2, 3]
|
Ds = [2, 3]
|
||||||
P1s = [1, 10, 101]
|
P1s = [1, 10, 101]
|
||||||
@ -45,7 +45,7 @@ class TestKNN(unittest.TestCase):
|
|||||||
|
|
||||||
def test_knn_vs_python_cuda(self):
|
def test_knn_vs_python_cuda(self):
|
||||||
""" Test CUDA output vs PyTorch implementation """
|
""" Test CUDA output vs PyTorch implementation """
|
||||||
device = torch.device('cuda')
|
device = torch.device("cuda")
|
||||||
Ns = [1, 4]
|
Ns = [1, 4]
|
||||||
Ds = [2, 3, 8]
|
Ds = [2, 3, 8]
|
||||||
P1s = [1, 8, 64, 128, 1001]
|
P1s = [1, 8, 64, 128, 1001]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user