Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
Linter, deprecated type()
Summary: Run linter after recent changes. Fix the long comment in knn.h which clang-format had reflowed badly. Add a crude test that the code doesn't call the deprecated `.type()` or `.data()`.

Reviewed By: nikhilaravi

Differential Revision: D20692935

fbshipit-source-id: 28ce0308adae79a870cb41a810b7cf8744f41ab8
This commit is contained in:
parent
3061c5b663
commit
37c5c8e0b6
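For context, the deprecation that the commit's new test guards against is ATen's old `Tensor::type()` accessor (and the untemplated `Tensor::data()`), which current ATen replaces with `scalar_type()` and `data_ptr<T>()`. Below is a minimal sketch of the migration pattern applied throughout the diff; the function itself is hypothetical, only the two accessor swaps are from the commit (the new `test_deprecated_usage` test near the end scans csrc for the old spellings):

```cpp
#include <ATen/ATen.h>

// Hypothetical CPU example: scale a floating-point tensor in place.
void scale_inplace(at::Tensor t, double factor) {
  // Deprecated: AT_DISPATCH_FLOATING_TYPES(t.type(), ...).
  // Current: dispatch on the ScalarType directly.
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale_inplace", ([&] {
    scalar_t* data = t.data_ptr<scalar_t>(); // rather than deprecated data()
    const int64_t n = t.numel();
    for (int64_t i = 0; i < n; ++i) {
      data[i] = static_cast<scalar_t>(data[i] * factor);
    }
  }));
}
```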
@@ -1,9 +1,9 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 //
-// This file provides utilities for dispatching to specialized versions of functions.
-// This is especially useful for CUDA kernels, since specializing them to particular
-// input sizes can often allow the compiler to unroll loops and place arrays into
-// registers, which can give huge performance speedups.
+// This file provides utilities for dispatching to specialized versions of
+// functions. This is especially useful for CUDA kernels, since specializing
+// them to particular input sizes can often allow the compiler to unroll loops
+// and place arrays into registers, which can give huge performance speedups.
 //
 // As an example, suppose we have the following function which is specialized
 // based on a compile-time int64_t value:
@@ -92,14 +92,13 @@ namespace {
 // In order to dispatch, we will take an additional template argument curN,
 // and increment it via template recursion until it is equal to the run-time
 // argument N.
-template<
-  template<typename, int64_t> class Kernel,
-  typename T,
-  int64_t minN,
-  int64_t maxN,
-  int64_t curN,
-  typename... Args
->
+template <
+    template <typename, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t curN,
+    typename... Args>
 struct DispatchKernelHelper1D {
   static void run(const int64_t N, Args... args) {
     if (curN == N) {
@@ -108,22 +107,21 @@ struct DispatchKernelHelper1D {
       Kernel<T, curN>::run(args...);
     } else if (curN < N) {
       // Increment curN via template recursion
-      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
+      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
+          N, args...);
     }
     // We shouldn't get here -- throw an error?
   }
 };

 // 1D dispatch: Specialization when curN == maxN
 // We need this base case to avoid infinite template recursion.
-template<
-  template<typename, int64_t> class Kernel,
-  typename T,
-  int64_t minN,
-  int64_t maxN,
-  typename... Args
->
+template <
+    template <typename, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    typename... Args>
 struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
   static void run(const int64_t N, Args... args) {
     if (N == maxN) {
@@ -133,19 +131,21 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
   }
 };

 // 2D dispatch, general case.
 // This is similar to the 1D case: we take additional template args curN and
 // curM, and increment them via template recursion until they are equal to
 // the run-time values of N and M, at which point we dispatch to the run
 // method of the kernel.
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN, int64_t curN,
-  int64_t minM, int64_t maxM, int64_t curM,
-  typename... Args
->
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t curN,
+    int64_t minM,
+    int64_t maxM,
+    int64_t curM,
+    typename... Args>
 struct DispatchKernelHelper2D {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (curN == N && curM == M) {
@@ -154,67 +154,141 @@ struct DispatchKernelHelper2D {
       // Increment both curN and curM. This isn't strictly necessary; we could
       // just increment one or the other at each step. But this helps to cut
       // down on the number of recursive calls we make.
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN + 1,
+          minM,
+          maxM,
+          curM + 1,
+          Args...>::run(N, M, args...);
     } else if (curN < N) {
       // Increment curN only
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN + 1,
+          minM,
+          maxM,
+          curM,
+          Args...>::run(N, M, args...);
     } else if (curM < M) {
       // Increment curM only
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN,
+          minM,
+          maxM,
+          curM + 1,
+          Args...>::run(N, M, args...);
     }
   }
 };

 // 2D dispatch, specialization for curN == maxN
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN,
-  int64_t minM, int64_t maxM, int64_t curM,
-  typename... Args
->
-struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t minM,
+    int64_t maxM,
+    int64_t curM,
+    typename... Args>
+struct DispatchKernelHelper2D<
+    Kernel,
+    T,
+    minN,
+    maxN,
+    maxN,
+    minM,
+    maxM,
+    curM,
+    Args...> {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (maxN == N && curM == M) {
       Kernel<T, maxN, curM>::run(args...);
     } else if (curM < maxM) {
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          maxN,
+          minM,
+          maxM,
+          curM + 1,
+          Args...>::run(N, M, args...);
     }
     // We should not get here -- throw an error?
   }
 };

 // 2D dispatch, specialization for curM == maxM
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN, int64_t curN,
-  int64_t minM, int64_t maxM,
-  typename... Args
->
-struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t curN,
+    int64_t minM,
+    int64_t maxM,
+    typename... Args>
+struct DispatchKernelHelper2D<
+    Kernel,
+    T,
+    minN,
+    maxN,
+    curN,
+    minM,
+    maxM,
+    maxM,
+    Args...> {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (curN == N && maxM == M) {
       Kernel<T, curN, maxM>::run(args...);
     } else if (curN < maxN) {
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN + 1,
+          minM,
+          maxM,
+          maxM,
+          Args...>::run(N, M, args...);
     }
     // We should not get here -- throw an error?
   }
 };

 // 2D dispatch, specialization for curN == maxN, curM == maxM
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN,
-  int64_t minM, int64_t maxM,
-  typename... Args
->
-struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t minM,
+    int64_t maxM,
+    typename... Args>
+struct DispatchKernelHelper2D<
+    Kernel,
+    T,
+    minN,
+    maxN,
+    maxN,
+    minM,
+    maxM,
+    maxM,
+    Args...> {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (maxN == N && maxM == M) {
       Kernel<T, maxN, maxM>::run(args...);
@@ -225,37 +299,45 @@ struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {

 } // namespace

 // This is the function we expect users to call to dispatch to 1D functions
-template<
-  template<typename, int64_t> class Kernel,
-  typename T,
-  int64_t minN,
-  int64_t maxN,
-  typename... Args
->
+template <
+    template <typename, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    typename... Args>
 void DispatchKernel1D(const int64_t N, Args... args) {
   if (minN <= N && N <= maxN) {
     // Kick off the template recursion by calling the Helper with curN = minN
-    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
+    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
+        N, args...);
   }
   // Maybe throw an error if we tried to dispatch outside the allowed range?
 }

 // This is the function we expect users to call to dispatch to 2D functions
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN,
-  int64_t minM, int64_t maxM,
-  typename... Args
->
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t minM,
+    int64_t maxM,
+    typename... Args>
 void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
   if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
     // Kick off the template recursion by calling the Helper with curN = minN
     // and curM = minM
-    DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
+    DispatchKernelHelper2D<
+        Kernel,
+        T,
+        minN,
+        maxN,
+        minN,
+        minM,
+        maxM,
+        minM,
+        Args...>::run(N, M, args...);
   }
   // Maybe throw an error if we tried to dispatch outside the specified range?
 }
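The diff above elides the worked example those header comments refer to, so here is a hedged host-side sketch of how a kernel functor plugs into `DispatchKernel1D`; the functor name and body are hypothetical, and the include path is an assumption:

```cpp
#include <cstdint>
#include "dispatch.cuh" // assumed path for the header shown above

// Hypothetical kernel functor, specialized on the compile-time size kN.
template <typename T, int64_t kN>
struct FillKernel {
  static void run(T* out) {
    // kN is a compile-time constant, so this loop can be fully unrolled.
    for (int64_t i = 0; i < kN; ++i) {
      out[i] = static_cast<T>(i);
    }
  }
};

void fill_first_n(float* out, int64_t n) {
  // Recurses at compile time from curN = minN = 1 up to maxN = 8, then calls
  // FillKernel<float, n>::run(out) for whichever n arrives at run time.
  DispatchKernel1D<FillKernel, float, 1, 8>(n, out);
}
```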
@@ -39,82 +39,180 @@
 // approach for this might lead to extra function calls at runtime if the
 // compiler fails to optimize them away, which could be very slow on device.
 // However I didn't actually benchmark or test this.
-template<typename T, int N>
+template <typename T, int N>
 struct RegisterIndexUtils {
   __device__ __forceinline__ static T get(const T arr[N], int idx) {
-    if (idx < 0 || idx >= N) return T();
+    if (idx < 0 || idx >= N)
+      return T();
     switch (idx) {
-      case 0: return arr[0];
-      case 1: return arr[1];
-      case 2: return arr[2];
-      case 3: return arr[3];
-      case 4: return arr[4];
-      case 5: return arr[5];
-      case 6: return arr[6];
-      case 7: return arr[7];
-      case 8: return arr[8];
-      case 9: return arr[9];
-      case 10: return arr[10];
-      case 11: return arr[11];
-      case 12: return arr[12];
-      case 13: return arr[13];
-      case 14: return arr[14];
-      case 15: return arr[15];
-      case 16: return arr[16];
-      case 17: return arr[17];
-      case 18: return arr[18];
-      case 19: return arr[19];
-      case 20: return arr[20];
-      case 21: return arr[21];
-      case 22: return arr[22];
-      case 23: return arr[23];
-      case 24: return arr[24];
-      case 25: return arr[25];
-      case 26: return arr[26];
-      case 27: return arr[27];
-      case 28: return arr[28];
-      case 29: return arr[29];
-      case 30: return arr[30];
-      case 31: return arr[31];
+      case 0:
+        return arr[0];
+      case 1:
+        return arr[1];
+      case 2:
+        return arr[2];
+      case 3:
+        return arr[3];
+      case 4:
+        return arr[4];
+      case 5:
+        return arr[5];
+      case 6:
+        return arr[6];
+      case 7:
+        return arr[7];
+      case 8:
+        return arr[8];
+      case 9:
+        return arr[9];
+      case 10:
+        return arr[10];
+      case 11:
+        return arr[11];
+      case 12:
+        return arr[12];
+      case 13:
+        return arr[13];
+      case 14:
+        return arr[14];
+      case 15:
+        return arr[15];
+      case 16:
+        return arr[16];
+      case 17:
+        return arr[17];
+      case 18:
+        return arr[18];
+      case 19:
+        return arr[19];
+      case 20:
+        return arr[20];
+      case 21:
+        return arr[21];
+      case 22:
+        return arr[22];
+      case 23:
+        return arr[23];
+      case 24:
+        return arr[24];
+      case 25:
+        return arr[25];
+      case 26:
+        return arr[26];
+      case 27:
+        return arr[27];
+      case 28:
+        return arr[28];
+      case 29:
+        return arr[29];
+      case 30:
+        return arr[30];
+      case 31:
+        return arr[31];
     };
     return T();
   }

   __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
-    if (idx < 0 || idx >= N) return;
+    if (idx < 0 || idx >= N)
+      return;
     switch (idx) {
-      case 0: arr[0] = val; break;
-      case 1: arr[1] = val; break;
-      case 2: arr[2] = val; break;
-      case 3: arr[3] = val; break;
-      case 4: arr[4] = val; break;
-      case 5: arr[5] = val; break;
-      case 6: arr[6] = val; break;
-      case 7: arr[7] = val; break;
-      case 8: arr[8] = val; break;
-      case 9: arr[9] = val; break;
-      case 10: arr[10] = val; break;
-      case 11: arr[11] = val; break;
-      case 12: arr[12] = val; break;
-      case 13: arr[13] = val; break;
-      case 14: arr[14] = val; break;
-      case 15: arr[15] = val; break;
-      case 16: arr[16] = val; break;
-      case 17: arr[17] = val; break;
-      case 18: arr[18] = val; break;
-      case 19: arr[19] = val; break;
-      case 20: arr[20] = val; break;
-      case 21: arr[21] = val; break;
-      case 22: arr[22] = val; break;
-      case 23: arr[23] = val; break;
-      case 24: arr[24] = val; break;
-      case 25: arr[25] = val; break;
-      case 26: arr[26] = val; break;
-      case 27: arr[27] = val; break;
-      case 28: arr[28] = val; break;
-      case 29: arr[29] = val; break;
-      case 30: arr[30] = val; break;
-      case 31: arr[31] = val; break;
+      case 0:
+        arr[0] = val;
+        break;
+      case 1:
+        arr[1] = val;
+        break;
+      case 2:
+        arr[2] = val;
+        break;
+      case 3:
+        arr[3] = val;
+        break;
+      case 4:
+        arr[4] = val;
+        break;
+      case 5:
+        arr[5] = val;
+        break;
+      case 6:
+        arr[6] = val;
+        break;
+      case 7:
+        arr[7] = val;
+        break;
+      case 8:
+        arr[8] = val;
+        break;
+      case 9:
+        arr[9] = val;
+        break;
+      case 10:
+        arr[10] = val;
+        break;
+      case 11:
+        arr[11] = val;
+        break;
+      case 12:
+        arr[12] = val;
+        break;
+      case 13:
+        arr[13] = val;
+        break;
+      case 14:
+        arr[14] = val;
+        break;
+      case 15:
+        arr[15] = val;
+        break;
+      case 16:
+        arr[16] = val;
+        break;
+      case 17:
+        arr[17] = val;
+        break;
+      case 18:
+        arr[18] = val;
+        break;
+      case 19:
+        arr[19] = val;
+        break;
+      case 20:
+        arr[20] = val;
+        break;
+      case 21:
+        arr[21] = val;
+        break;
+      case 22:
+        arr[22] = val;
+        break;
+      case 23:
+        arr[23] = val;
+        break;
+      case 24:
+        arr[24] = val;
+        break;
+      case 25:
+        arr[25] = val;
+        break;
+      case 26:
+        arr[26] = val;
+        break;
+      case 27:
+        arr[27] = val;
+        break;
+      case 28:
+        arr[28] = val;
+        break;
+      case 29:
+        arr[29] = val;
+        break;
+      case 30:
+        arr[30] = val;
+        break;
+      case 31:
+        arr[31] = val;
+        break;
     }
   }
 };
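The point of the exhaustive switches above: a thread-local array indexed by a run-time value is normally spilled to CUDA local memory, while a switch over every valid index makes each access a static one, so the compiler may keep the array in registers. A hedged device-side sketch (kernel and names hypothetical; N must not exceed 32, the largest index handled above):

```cpp
// Hypothetical CUDA kernel exercising RegisterIndexUtils with a run-time slot.
template <int N>
__global__ void ring_sum_kernel(const float* in, float* out, int count) {
  float ring[N]; // thread-local; static indexing lets it stay in registers
  for (int k = 0; k < N; ++k) {
    ring[k] = 0.f;
  }
  for (int i = threadIdx.x; i < count; i += blockDim.x) {
    const int slot = i % N; // run-time index that would otherwise force spills
    RegisterIndexUtils<float, N>::set(ring, slot, in[i]);
  }
  float sum = 0.f;
  for (int k = 0; k < N; ++k) {
    sum += RegisterIndexUtils<float, N>::get(ring, k);
  }
  out[blockIdx.x * blockDim.x + threadIdx.x] = sum;
}
```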
@@ -289,7 +289,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   const size_t threads = 256;
   const size_t blocks = 256;
   if (version == 0) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
       KNearestNeighborKernelV0<scalar_t>
           <<<blocks, threads>>>(
               p1.data_ptr<scalar_t>(),
@@ -303,7 +303,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
               K);
     }));
   } else if (version == 1) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
       DispatchKernel1D<
           KNearestNeighborV1Functor,
           scalar_t,
@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           K);
     }));
   } else if (version == 2) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
      DispatchKernel2D<
          KNearestNeighborKernelV2Functor,
          scalar_t,
@@ -343,7 +343,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           P2);
     }));
   } else if (version == 3) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
       DispatchKernel2D<
           KNearestNeighborKernelV3Functor,
           scalar_t,
@@ -13,11 +13,11 @@
 //        containing P1 points of dimension D.
 //    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
 //        containing P2 points of dimension D.
-//    K: int giving the number of nearest points to return.
-//    sorted: bool telling whether to sort the K returned points by their
-//       distance version: Integer telling which implementation to use.
-//       TODO(jcjohns): Document this more, or maybe remove it before
-//       landing.
+//    K: int giving the number of nearest points to return.
+//    sorted: bool telling whether to sort the K returned points by their
+//        distance.
+//    version: Integer telling which implementation to use.
+//        TODO(jcjohns): Document this more, or maybe remove it before landing.
 //
 // Returns:
 //    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -41,7 +41,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     const at::Tensor& p2,
     int K,
     int version) {
-  if (p1.type().is_cuda() || p2.type().is_cuda()) {
+  if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CONTIGUOUS_CUDA(p1);
     CHECK_CONTIGUOUS_CUDA(p2);
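To make the documented interface concrete, here is a hedged host-side sketch of a call (shapes taken from the comment above; assumes the header declaring `KNearestNeighborIdx` is included):

```cpp
#include <ATen/ATen.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> knn_example() {
  at::Tensor p1 = at::randn({8, 128, 3}); // (N, P1, D)
  at::Tensor p2 = at::randn({8, 256, 3}); // (N, P2, D)
  // K = 4 nearest neighbors in p2 for each point of p1, implementation 0.
  // Returns (p1_neighbor_idx, distances); indices have shape (N, P1, K).
  return KNearestNeighborIdx(p1, p2, /*K=*/4, /*version=*/0);
}
```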
@@ -4,49 +4,48 @@
 #include <queue>
 #include <tuple>

 std::tuple<at::Tensor, at::Tensor>
 KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
-  const int N = p1.size(0);
-  const int P1 = p1.size(1);
-  const int D = p1.size(2);
-  const int P2 = p2.size(1);
+  const int N = p1.size(0);
+  const int P1 = p1.size(1);
+  const int D = p1.size(2);
+  const int P2 = p2.size(1);

-  auto long_opts = p1.options().dtype(torch::kInt64);
-  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
-  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
+  auto long_opts = p1.options().dtype(torch::kInt64);
+  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
+  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

-  auto p1_a = p1.accessor<float, 3>();
-  auto p2_a = p2.accessor<float, 3>();
-  auto idxs_a = idxs.accessor<int64_t, 3>();
-  auto dists_a = dists.accessor<float, 3>();
+  auto p1_a = p1.accessor<float, 3>();
+  auto p2_a = p2.accessor<float, 3>();
+  auto idxs_a = idxs.accessor<int64_t, 3>();
+  auto dists_a = dists.accessor<float, 3>();

-  for (int n = 0; n < N; ++n) {
-    for (int i1 = 0; i1 < P1; ++i1) {
-      // Use a priority queue to store (distance, index) tuples.
-      std::priority_queue<std::tuple<float, int>> q;
-      for (int i2 = 0; i2 < P2; ++i2) {
-        float dist = 0;
-        for (int d = 0; d < D; ++d) {
-          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
-          dist += diff * diff;
-        }
-        int size = static_cast<int>(q.size());
-        if (size < K || dist < std::get<0>(q.top())) {
-          q.emplace(dist, i2);
-          if (size >= K) {
-            q.pop();
-          }
-        }
-      }
-      while (!q.empty()) {
-        auto t = q.top();
-        q.pop();
-        const int k = q.size();
-        dists_a[n][i1][k] = std::get<0>(t);
-        idxs_a[n][i1][k] = std::get<1>(t);
-      }
-    }
-  }
-  return std::make_tuple(idxs, dists);
-}
+  for (int n = 0; n < N; ++n) {
+    for (int i1 = 0; i1 < P1; ++i1) {
+      // Use a priority queue to store (distance, index) tuples.
+      std::priority_queue<std::tuple<float, int>> q;
+      for (int i2 = 0; i2 < P2; ++i2) {
+        float dist = 0;
+        for (int d = 0; d < D; ++d) {
+          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
+          dist += diff * diff;
+        }
+        int size = static_cast<int>(q.size());
+        if (size < K || dist < std::get<0>(q.top())) {
+          q.emplace(dist, i2);
+          if (size >= K) {
+            q.pop();
+          }
+        }
+      }
+      while (!q.empty()) {
+        auto t = q.top();
+        q.pop();
+        const int k = q.size();
+        dists_a[n][i1][k] = std::get<0>(t);
+        idxs_a[n][i1][k] = std::get<1>(t);
+      }
+    }
+  }
+  return std::make_tuple(idxs, dists);
+}
@@ -5,7 +5,6 @@

 #include "index_utils.cuh"

-
 // A data structure to keep track of the smallest K keys seen so far as well
 // as their associated values, intended to be used in device code.
 // This data structure doesn't allocate any memory; keys and values are stored
@@ -32,18 +31,17 @@
 //       float key_k = keys[k];
 //       int value_k = values[k];
 //     }
-template<typename key_t, typename value_t>
+template <typename key_t, typename value_t>
 class MinK {
  public:
-
   // Constructor.
   //
   // Arguments:
   //   keys: Array in which to store keys
   //   values: Array in which to store values
   //   K: How many values to keep track of
-  __device__ MinK(key_t *keys, value_t *vals, int K) :
-    keys(keys), vals(vals), K(K), _size(0) { }
+  __device__ MinK(key_t* keys, value_t* vals, int K)
+      : keys(keys), vals(vals), K(K), _size(0) {}

   // Try to add a new key and associated value to the data structure. If the key
   // is one of the smallest K seen so far then it will be kept; otherwise it
@@ -55,7 +53,7 @@ class MinK {
   // Arguments:
   //   key: The key to add
   //   val: The value associated to the key
-  __device__ __forceinline__ void add(const key_t &key, const value_t &val) {
+  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
     if (_size < K) {
       keys[_size] = key;
       vals[_size] = val;
@@ -71,8 +69,8 @@ class MinK {
       for (int k = 0; k < K; ++k) {
         key_t cur_key = keys[k];
         if (cur_key > max_key) {
-          max_key = cur_key;
-          max_idx = k;
+          max_key = cur_key;
+          max_idx = k;
         }
       }
     }
@@ -102,15 +100,14 @@ class MinK {
   }

  private:
-  key_t *keys;
-  value_t *vals;
+  key_t* keys;
+  value_t* vals;
   int K;
   int _size;
   key_t max_key;
   int max_idx;
 };

 // This is a version of MinK that only touches the arrays using static indexing
 // via RegisterIndexUtils. If the keys and values are stored in thread-local
 // arrays, then this may allow the compiler to place them in registers for
@@ -120,13 +117,13 @@ class MinK {
 // We found that sorting via RegisterIndexUtils gave very poor performance,
 // and suspect it may have prevented the compiler from placing the arrays
 // into registers.
-template<typename key_t, typename value_t, int K>
+template <typename key_t, typename value_t, int K>
 class RegisterMinK {
  public:
-  __device__ RegisterMinK(key_t *keys, value_t *vals) :
-    keys(keys), vals(vals), _size(0) {}
+  __device__ RegisterMinK(key_t* keys, value_t* vals)
+      : keys(keys), vals(vals), _size(0) {}

-  __device__ __forceinline__ void add(const key_t &key, const value_t &val) {
+  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
     if (_size < K) {
       RegisterIndexUtils<key_t, K>::set(keys, _size, key);
       RegisterIndexUtils<value_t, K>::set(vals, _size, val);
@@ -154,8 +151,8 @@ class RegisterMinK {
   }

  private:
-  key_t *keys;
-  value_t *vals;
+  key_t* keys;
+  value_t* vals;
   int _size;
   key_t max_key;
   int max_idx;
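A device-side sketch of the usage pattern the MinK header comment describes, with a hypothetical kernel name and sizes:

```cpp
// Hypothetical CUDA kernel: each thread tracks its K smallest distances.
__global__ void mink_sketch(
    const float* dists, int count, float* out_keys, int* out_vals) {
  constexpr int K = 8;
  float keys[K]; // the smallest K keys seen so far
  int vals[K];   // their associated values (source indices here)
  MinK<float, int> mink(keys, vals, K);
  for (int i = threadIdx.x; i < count; i += blockDim.x) {
    mink.add(dists[i], i); // kept only if among the K smallest so far
  }
  const int base = (blockIdx.x * blockDim.x + threadIdx.x) * K;
  for (int k = 0; k < K; ++k) {
    out_keys[base + k] = keys[k];
    out_vals[base + k] = vals[k];
  }
}
```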
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 import torch
+
 from pytorch3d import _C

@@ -1,7 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-
 from itertools import product

 import torch
 from fvcore.common.benchmark import benchmark

@@ -30,21 +29,13 @@ def benchmark_knn_cuda_versions() -> None:
                 continue
             if version == 3 and K > 4:
                 continue
-            knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
+            knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
     for N, P, D in product(Ns, Ps, Ds):
-        nn_kwargs.append({'N': N, 'D': D, 'P': P})
+        nn_kwargs.append({"N": N, "D": D, "P": P})
     benchmark(
-        knn_cuda_with_init,
-        'KNN_CUDA_VERSIONS',
-        knn_kwargs,
-        warmup_iters=1,
-    )
-    benchmark(
-        nn_cuda_with_init,
-        'NN_CUDA',
-        nn_kwargs,
-        warmup_iters=1,
+        knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1
     )
+    benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)


 def benchmark_knn_cuda_vs_naive() -> None:
@@ -55,21 +46,16 @@ def benchmark_knn_cuda_vs_naive() -> None:
     Ks = [1, 2, 4, 8, 16]
     knn_kwargs, naive_kwargs = [], []
     for N, P, D, K in product(Ns, Ps, Ds, Ks):
-        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
+        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
         if P <= 4096:
-            naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
+            naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
     benchmark(
         knn_python_cuda_with_init,
-        'KNN_CUDA_PYTHON',
+        "KNN_CUDA_PYTHON",
         naive_kwargs,
         warmup_iters=1,
     )
-    benchmark(
-        knn_cuda_with_init,
-        'KNN_CUDA',
-        knn_kwargs,
-        warmup_iters=1,
-    )
+    benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)


 def benchmark_knn_cpu() -> None:
@@ -79,31 +65,18 @@ def benchmark_knn_cpu() -> None:
     Ks = [1, 2, 4]
     knn_kwargs, nn_kwargs = [], []
     for N, P, D, K in product(Ns, Ps, Ds, Ks):
-        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
+        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
     for N, P, D in product(Ns, Ps, Ds):
-        nn_kwargs.append({'N': N, 'D': D, 'P': P})
+        nn_kwargs.append({"N": N, "D": D, "P": P})
     benchmark(
-        knn_python_cpu_with_init,
-        'KNN_CPU_PYTHON',
-        knn_kwargs,
-        warmup_iters=1,
-    )
-    benchmark(
-        knn_cpu_with_init,
-        'KNN_CPU_CPP',
-        knn_kwargs,
-        warmup_iters=1,
-    )
-    benchmark(
-        nn_cpu_with_init,
-        'NN_CPU_CPP',
-        nn_kwargs,
-        warmup_iters=1,
+        knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1
     )
+    benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
+    benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)


 def knn_cuda_with_init(N, D, P, K, v=-1):
-    device = torch.device('cuda:0')
+    device = torch.device("cuda:0")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
     torch.cuda.synchronize()
@@ -116,7 +89,7 @@ def knn_cuda_with_init(N, D, P, K, v=-1):


 def knn_cpu_with_init(N, D, P, K):
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)

@@ -127,7 +100,7 @@ def knn_cpu_with_init(N, D, P, K):


 def knn_python_cuda_with_init(N, D, P, K):
-    device = torch.device('cuda')
+    device = torch.device("cuda")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
     torch.cuda.synchronize()
@@ -140,7 +113,7 @@ def knn_python_cuda_with_init(N, D, P, K):


 def knn_python_cpu_with_init(N, D, P, K):
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)

@@ -151,7 +124,7 @@ def knn_python_cpu_with_init(N, D, P, K):


 def nn_cuda_with_init(N, D, P):
-    device = torch.device('cuda')
+    device = torch.device("cuda")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
     torch.cuda.synchronize()
@@ -164,7 +137,7 @@ def nn_cuda_with_init(N, D, P):


 def nn_cpu_with_init(N, D, P):
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
@@ -22,6 +22,27 @@ class TestBuild(unittest.TestCase):
         for k, v in counter.items():
             self.assertEqual(v, 1, f"Too many files with stem {k}.")

+    def test_deprecated_usage(self):
+        # Check certain expressions do not occur in the csrc code
+        test_dir = Path(__file__).resolve().parent
+        source_dir = test_dir.parent / "pytorch3d" / "csrc"
+
+        files = sorted(source_dir.glob("**/*.*"))
+        self.assertGreater(len(files), 4)
+
+        patterns = [".type()", ".data()"]
+
+        for file in files:
+            with open(file) as f:
+                text = f.read()
+            for pattern in patterns:
+                found = pattern in text
+                msg = (
+                    f"{pattern} found in {file.name}"
+                    + ", this has been deprecated."
+                )
+                self.assertFalse(found, msg)
+
     def test_copyright(self):
         test_dir = Path(__file__).resolve().parent
         root_dir = test_dir.parent
@@ -28,7 +28,7 @@ class TestKNN(unittest.TestCase):

     def test_knn_vs_python_cpu(self):
         """ Test CPU output vs PyTorch implementation """
-        device = torch.device('cpu')
+        device = torch.device("cpu")
         Ns = [1, 4]
         Ds = [2, 3]
         P1s = [1, 10, 101]
@@ -45,7 +45,7 @@ class TestKNN(unittest.TestCase):

     def test_knn_vs_python_cuda(self):
         """ Test CUDA output vs PyTorch implementation """
-        device = torch.device('cuda')
+        device = torch.device("cuda")
         Ns = [1, 4]
         Ds = [2, 3, 8]
         P1s = [1, 8, 64, 128, 1001]