mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00
	Linter, deprecated type()
Summary: Run the linter after recent changes. Fix a long comment in knn.h which clang-format had reflowed badly. Add a crude test that the code doesn't call the deprecated `.type()` or `.data()`.

Reviewed By: nikhilaravi

Differential Revision: D20692935

fbshipit-source-id: 28ce0308adae79a870cb41a810b7cf8744f41ab8
This commit is contained in:
parent 3061c5b663
commit 37c5c8e0b6
@@ -1,9 +1,9 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 //
-// This file provides utilities for dispatching to specialized versions of functions.
-// This is especially useful for CUDA kernels, since specializing them to particular
-// input sizes can often allow the compiler to unroll loops and place arrays into
-// registers, which can give huge performance speedups.
+// This file provides utilities for dispatching to specialized versions of
+// functions. This is especially useful for CUDA kernels, since specializing
+// them to particular input sizes can often allow the compiler to unroll loops
+// and place arrays into registers, which can give huge performance speedups.
 //
 // As an example, suppose we have the following function which is specialized
 // based on a compile-time int64_t value:
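To make the comment above concrete, here is a minimal sketch (not part of this diff; the functor name and body are hypothetical) of the kind of kernel this file is meant to dispatch to -- a functor template specialized on a compile-time int64_t size:

// Hypothetical kernel functor. Because N is a compile-time constant, the
// compiler can fully unroll the loop and keep `acc` in registers.
template <typename T, int64_t N>
struct SquaredNormKernel {
  static void run(const T* x, T* out) {
    T acc = 0;
    for (int64_t i = 0; i < N; ++i) {
      acc += x[i] * x[i];
    }
    *out = acc;
  }
};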
@@ -92,14 +92,13 @@ namespace {
 // In order to dispatch, we will take an additional template argument curN,
 // and increment it via template recursion until it is equal to the run-time
 // argument N.
-template<
-  template<typename, int64_t> class Kernel,
-  typename T,
-  int64_t minN,
-  int64_t maxN,
-  int64_t curN,
-  typename... Args
->
+template <
+    template <typename, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t curN,
+    typename... Args>
 struct DispatchKernelHelper1D {
   static void run(const int64_t N, Args... args) {
     if (curN == N) {
@@ -108,22 +107,21 @@ struct DispatchKernelHelper1D {
       Kernel<T, curN>::run(args...);
     } else if (curN < N) {
       // Increment curN via template recursion
-      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(N, args...);
+      DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
+          N, args...);
     }
     // We shouldn't get here -- throw an error?
   }
 };

-
 // 1D dispatch: Specialization when curN == maxN
 // We need this base case to avoid infinite template recursion.
-template<
-  template<typename, int64_t> class Kernel,
-  typename T,
-  int64_t minN,
-  int64_t maxN,
-  typename... Args
->
+template <
+    template <typename, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    typename... Args>
 struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
   static void run(const int64_t N, Args... args) {
     if (N == maxN) {
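For intuition (illustrative, not from this diff): because curN is a template parameter, instantiating the helper over the range [minN, maxN] compiles down to a chain of comparisons against constants, each branch calling a distinct specialization of the kernel. For minN = 1, maxN = 3 the recursion behaves like:

// What DispatchKernelHelper1D<Kernel, T, 1, 3, 1, Args...>::run(N, args...)
// unrolls to after template instantiation.
if (N == 1) {
  Kernel<T, 1>::run(args...);
} else if (N == 2) {
  Kernel<T, 2>::run(args...);
} else if (N == 3) {
  Kernel<T, 3>::run(args...);
}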
@@ -133,19 +131,21 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
   }
 };

-
 // 2D dispatch, general case.
 // This is similar to the 1D case: we take additional template args curN and
 // curM, and increment them via template recursion until they are equal to
 // the run-time values of N and M, at which point we dispatch to the run
 // method of the kernel.
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN, int64_t curN,
-  int64_t minM, int64_t maxM, int64_t curM,
-  typename... Args
->
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t curN,
+    int64_t minM,
+    int64_t maxM,
+    int64_t curM,
+    typename... Args>
 struct DispatchKernelHelper2D {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (curN == N && curM == M) {
@@ -154,67 +154,141 @@ struct DispatchKernelHelper2D {
       // Increment both curN and curM. This isn't strictly necessary; we could
       // just increment one or the other at each step. But this helps to cut
       // on the number of recursive calls we make.
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM + 1, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN + 1,
+          minM,
+          maxM,
+          curM + 1,
+          Args...>::run(N, M, args...);
     } else if (curN < N) {
       // Increment curN only
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, curM, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN + 1,
+          minM,
+          maxM,
+          curM,
+          Args...>::run(N, M, args...);
     } else if (curM < M) {
       // Increment curM only
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN,
+          minM,
+          maxM,
+          curM + 1,
+          Args...>::run(N, M, args...);
     }
   }
 };

-
 // 2D dispatch, specialization for curN == maxN
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN,
-  int64_t minM, int64_t maxM, int64_t curM,
-  typename... Args
->
-struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM, Args...> {
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t minM,
+    int64_t maxM,
+    int64_t curM,
+    typename... Args>
+struct DispatchKernelHelper2D<
+    Kernel,
+    T,
+    minN,
+    maxN,
+    maxN,
+    minM,
+    maxM,
+    curM,
+    Args...> {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (maxN == N && curM == M) {
       Kernel<T, maxN, curM>::run(args...);
     } else if (curM < maxM) {
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, curM + 1, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          maxN,
+          minM,
+          maxM,
+          curM + 1,
+          Args...>::run(N, M, args...);
     }
     // We should not get here -- throw an error?
   }
 };

-
 // 2D dispatch, specialization for curM == maxM
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN, int64_t curN,
-  int64_t minM, int64_t maxM,
-  typename... Args
->
-struct DispatchKernelHelper2D<Kernel, T, minN, maxN, curN, minM, maxM, maxM, Args...> {
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t curN,
+    int64_t minM,
+    int64_t maxM,
+    typename... Args>
+struct DispatchKernelHelper2D<
+    Kernel,
+    T,
+    minN,
+    maxN,
+    curN,
+    minM,
+    maxM,
+    maxM,
+    Args...> {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (curN == N && maxM == M) {
       Kernel<T, curN, maxM>::run(args...);
     } else if (curN < maxN) {
-      DispatchKernelHelper2D<Kernel, T, minN, maxN, curN + 1, minM, maxM, maxM, Args...>::run(N, M, args...);
+      DispatchKernelHelper2D<
+          Kernel,
+          T,
+          minN,
+          maxN,
+          curN + 1,
+          minM,
+          maxM,
+          maxM,
+          Args...>::run(N, M, args...);
     }
     // We should not get here -- throw an error?
   }
 };

-
 // 2D dispatch, specialization for curN == maxN, curM == maxM
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN,
-  int64_t minM, int64_t maxM,
-  typename... Args
->
-struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Args...> {
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t minM,
+    int64_t maxM,
+    typename... Args>
+struct DispatchKernelHelper2D<
+    Kernel,
+    T,
+    minN,
+    maxN,
+    maxN,
+    minM,
+    maxM,
+    maxM,
+    Args...> {
   static void run(const int64_t N, const int64_t M, Args... args) {
     if (maxN == N && maxM == M) {
       Kernel<T, maxN, maxM>::run(args...);
@@ -225,37 +299,45 @@ struct DispatchKernelHelper2D<Kernel, T, minN, maxN, maxN, minM, maxM, maxM, Arg

 } // namespace

-
 // This is the function we expect users to call to dispatch to 1D functions
-template<
-  template<typename, int64_t> class Kernel,
-  typename T,
-  int64_t minN,
-  int64_t maxN,
-  typename... Args
->
+template <
+    template <typename, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    typename... Args>
 void DispatchKernel1D(const int64_t N, Args... args) {
   if (minN <= N && N <= maxN) {
     // Kick off the template recursion by calling the Helper with curN = minN
-    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(N, args...);
+    DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
+        N, args...);
   }
   // Maybe throw an error if we tried to dispatch outside the allowed range?
 }

-
 // This is the function we expect users to call to dispatch to 2D functions
-template<
-  template<typename, int64_t, int64_t> class Kernel,
-  typename T,
-  int64_t minN, int64_t maxN,
-  int64_t minM, int64_t maxM,
-  typename... Args
->
+template <
+    template <typename, int64_t, int64_t> class Kernel,
+    typename T,
+    int64_t minN,
+    int64_t maxN,
+    int64_t minM,
+    int64_t maxM,
+    typename... Args>
 void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
   if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
     // Kick off the template recursion by calling the Helper with curN = minN
     // and curM = minM
-    DispatchKernelHelper2D<Kernel, T, minN, maxN, minN, minM, maxM, minM, Args...>::run(N, M, args...);
+    DispatchKernelHelper2D<
+        Kernel,
+        T,
+        minN,
+        maxN,
+        minN,
+        minM,
+        maxM,
+        minM,
+        Args...>::run(N, M, args...);
   }
   // Maybe throw an error if we tried to dispatch outside the specified range?
 }
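A hypothetical call site (not from this diff), reusing the SquaredNormKernel sketch from earlier, shows how the entry point maps a run-time size onto a compile-time specialization:

// Dispatch on a size K in [1, 32] that is known only at run time. If K
// falls outside the range the call currently does nothing (see the
// comment above about maybe throwing instead).
int64_t K = 8;
DispatchKernel1D<SquaredNormKernel, float, 1, 32>(K, x_ptr, out_ptr);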
@@ -39,82 +39,180 @@
 // approach for this might lead to extra function calls at runtime if the
 // compiler fails to optimize them away, which could be very slow on device.
 // However I didn't actually benchmark or test this.
-template<typename T, int N>
+template <typename T, int N>
 struct RegisterIndexUtils {
   __device__ __forceinline__ static T get(const T arr[N], int idx) {
-    if (idx < 0 || idx >= N) return T();
+    if (idx < 0 || idx >= N)
+      return T();
     switch (idx) {
-      case 0: return arr[0];
-      case 1: return arr[1];
-      case 2: return arr[2];
-      case 3: return arr[3];
-      case 4: return arr[4];
-      case 5: return arr[5];
-      case 6: return arr[6];
-      case 7: return arr[7];
-      case 8: return arr[8];
-      case 9: return arr[9];
-      case 10: return arr[10];
-      case 11: return arr[11];
-      case 12: return arr[12];
-      case 13: return arr[13];
-      case 14: return arr[14];
-      case 15: return arr[15];
-      case 16: return arr[16];
-      case 17: return arr[17];
-      case 18: return arr[18];
-      case 19: return arr[19];
-      case 20: return arr[20];
-      case 21: return arr[21];
-      case 22: return arr[22];
-      case 23: return arr[23];
-      case 24: return arr[24];
-      case 25: return arr[25];
-      case 26: return arr[26];
-      case 27: return arr[27];
-      case 28: return arr[28];
-      case 29: return arr[29];
-      case 30: return arr[30];
-      case 31: return arr[31];
+      case 0:
+        return arr[0];
+      case 1:
+        return arr[1];
+      case 2:
+        return arr[2];
+      case 3:
+        return arr[3];
+      case 4:
+        return arr[4];
+      case 5:
+        return arr[5];
+      case 6:
+        return arr[6];
+      case 7:
+        return arr[7];
+      case 8:
+        return arr[8];
+      case 9:
+        return arr[9];
+      case 10:
+        return arr[10];
+      case 11:
+        return arr[11];
+      case 12:
+        return arr[12];
+      case 13:
+        return arr[13];
+      case 14:
+        return arr[14];
+      case 15:
+        return arr[15];
+      case 16:
+        return arr[16];
+      case 17:
+        return arr[17];
+      case 18:
+        return arr[18];
+      case 19:
+        return arr[19];
+      case 20:
+        return arr[20];
+      case 21:
+        return arr[21];
+      case 22:
+        return arr[22];
+      case 23:
+        return arr[23];
+      case 24:
+        return arr[24];
+      case 25:
+        return arr[25];
+      case 26:
+        return arr[26];
+      case 27:
+        return arr[27];
+      case 28:
+        return arr[28];
+      case 29:
+        return arr[29];
+      case 30:
+        return arr[30];
+      case 31:
+        return arr[31];
     };
     return T();
   }

   __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
-    if (idx < 0 || idx >= N) return;
+    if (idx < 0 || idx >= N)
+      return;
     switch (idx) {
-      case 0: arr[0] = val; break;
-      case 1: arr[1] = val; break;
-      case 2: arr[2] = val; break;
-      case 3: arr[3] = val; break;
-      case 4: arr[4] = val; break;
-      case 5: arr[5] = val; break;
-      case 6: arr[6] = val; break;
-      case 7: arr[7] = val; break;
-      case 8: arr[8] = val; break;
-      case 9: arr[9] = val; break;
-      case 10: arr[10] = val; break;
-      case 11: arr[11] = val; break;
-      case 12: arr[12] = val; break;
-      case 13: arr[13] = val; break;
-      case 14: arr[14] = val; break;
-      case 15: arr[15] = val; break;
-      case 16: arr[16] = val; break;
-      case 17: arr[17] = val; break;
-      case 18: arr[18] = val; break;
-      case 19: arr[19] = val; break;
-      case 20: arr[20] = val; break;
-      case 21: arr[21] = val; break;
-      case 22: arr[22] = val; break;
-      case 23: arr[23] = val; break;
-      case 24: arr[24] = val; break;
-      case 25: arr[25] = val; break;
-      case 26: arr[26] = val; break;
-      case 27: arr[27] = val; break;
-      case 28: arr[28] = val; break;
-      case 29: arr[29] = val; break;
-      case 30: arr[30] = val; break;
-      case 31: arr[31] = val; break;
+      case 0:
+        arr[0] = val;
+        break;
+      case 1:
+        arr[1] = val;
+        break;
+      case 2:
+        arr[2] = val;
+        break;
+      case 3:
+        arr[3] = val;
+        break;
+      case 4:
+        arr[4] = val;
+        break;
+      case 5:
+        arr[5] = val;
+        break;
+      case 6:
+        arr[6] = val;
+        break;
+      case 7:
+        arr[7] = val;
+        break;
+      case 8:
+        arr[8] = val;
+        break;
+      case 9:
+        arr[9] = val;
+        break;
+      case 10:
+        arr[10] = val;
+        break;
+      case 11:
+        arr[11] = val;
+        break;
+      case 12:
+        arr[12] = val;
+        break;
+      case 13:
+        arr[13] = val;
+        break;
+      case 14:
+        arr[14] = val;
+        break;
+      case 15:
+        arr[15] = val;
+        break;
+      case 16:
+        arr[16] = val;
+        break;
+      case 17:
+        arr[17] = val;
+        break;
+      case 18:
+        arr[18] = val;
+        break;
+      case 19:
+        arr[19] = val;
+        break;
+      case 20:
+        arr[20] = val;
+        break;
+      case 21:
+        arr[21] = val;
+        break;
+      case 22:
+        arr[22] = val;
+        break;
+      case 23:
+        arr[23] = val;
+        break;
+      case 24:
+        arr[24] = val;
+        break;
+      case 25:
+        arr[25] = val;
+        break;
+      case 26:
+        arr[26] = val;
+        break;
+      case 27:
+        arr[27] = val;
+        break;
+      case 28:
+        arr[28] = val;
+        break;
+      case 29:
+        arr[29] = val;
+        break;
+      case 30:
+        arr[30] = val;
+        break;
+      case 31:
+        arr[31] = val;
+        break;
     }
   }
 };
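A short aside on why this switch exists (standard CUDA behavior, not stated in this diff): a thread-local array can only live in registers when every access uses an index the compiler can resolve at compile time; a dynamic arr[idx] forces the array to spill into slow local memory. Routing accesses through the switch gives the compiler a constant index on every path, so once the surrounding loops are unrolled everything can stay in registers. A hypothetical device function using it:

// Sums an 8-element thread-local buffer. With constant-trip-count loops
// (which nvcc will typically unroll) every get/set resolves to a constant
// index, letting `vals` stay in registers.
__device__ float sum_buffer(const float* in, int base) {
  float vals[8];
  for (int k = 0; k < 8; ++k) {
    RegisterIndexUtils<float, 8>::set(vals, k, in[base + k]);
  }
  float total = 0;
  for (int k = 0; k < 8; ++k) {
    total += RegisterIndexUtils<float, 8>::get(vals, k);
  }
  return total;
}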
@@ -289,7 +289,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   const size_t threads = 256;
   const size_t blocks = 256;
   if (version == 0) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
                                  KNearestNeighborKernelV0<scalar_t>
                                      <<<blocks, threads>>>(
                                          p1.data_ptr<scalar_t>(),
@@ -303,7 +303,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                                          K);
                                }));
   } else if (version == 1) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
                                  DispatchKernel1D<
                                      KNearestNeighborV1Functor,
                                      scalar_t,
@@ -322,7 +322,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                                      K);
                                }));
   } else if (version == 2) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
                                  DispatchKernel2D<
                                      KNearestNeighborKernelV2Functor,
                                      scalar_t,
@@ -343,7 +343,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                                      P2);
                                }));
   } else if (version == 3) {
-    AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
                                  DispatchKernel2D<
                                      KNearestNeighborKernelV3Functor,
                                      scalar_t,
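The four hunks above are the commit's deprecation fix: at::Tensor::type() is the legacy way to obtain dispatch information, and scalar_type() is its replacement. A minimal standalone sketch of the preferred pattern (illustrative; the function and op name are made up, but AT_DISPATCH_FLOATING_TYPES and data_ptr are the same ATen APIs used above):

#include <ATen/ATen.h>

// Scales a floating-point CPU tensor in place, dispatching on scalar_type()
// rather than the deprecated type(), and using data_ptr<T>() rather than
// the deprecated data accessor.
void scale_inplace(at::Tensor t, double s) {
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale_inplace", ([&] {
                               scalar_t* data = t.data_ptr<scalar_t>();
                               for (int64_t i = 0; i < t.numel(); ++i) {
                                 data[i] = static_cast<scalar_t>(data[i] * s);
                               }
                             }));
}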
@@ -13,11 +13,11 @@
 //        containing P1 points of dimension D.
 //    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
 //        containing P2 points of dimension D.
-//     K: int giving the number of nearest points to return.
-//     sorted: bool telling whether to sort the K returned points by their
-//     distance version: Integer telling which implementation to use.
-//              TODO(jcjohns): Document this more, or maybe remove it before
-//              landing.
+//    K: int giving the number of nearest points to return.
+//    sorted: bool telling whether to sort the K returned points by their
+//        distance.
+//    version: Integer telling which implementation to use.
+//        TODO(jcjohns): Document this more, or maybe remove it before landing.
 //
 // Returns:
 //    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -41,7 +41,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     const at::Tensor& p2,
     int K,
     int version) {
-  if (p1.type().is_cuda() || p2.type().is_cuda()) {
+  if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CONTIGUOUS_CUDA(p1);
     CHECK_CONTIGUOUS_CUDA(p2);
@@ -4,49 +4,48 @@
 #include <queue>
 #include <tuple>

-
 std::tuple<at::Tensor, at::Tensor>
 KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) {
-    const int N = p1.size(0);
-    const int P1 = p1.size(1);
-    const int D = p1.size(2);
-    const int P2 = p2.size(1);
+  const int N = p1.size(0);
+  const int P1 = p1.size(1);
+  const int D = p1.size(2);
+  const int P2 = p2.size(1);

-    auto long_opts = p1.options().dtype(torch::kInt64);
-    torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
-    torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
+  auto long_opts = p1.options().dtype(torch::kInt64);
+  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
+  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

-    auto p1_a = p1.accessor<float, 3>();
-    auto p2_a = p2.accessor<float, 3>();
-    auto idxs_a = idxs.accessor<int64_t, 3>();
-    auto dists_a = dists.accessor<float, 3>();
+  auto p1_a = p1.accessor<float, 3>();
+  auto p2_a = p2.accessor<float, 3>();
+  auto idxs_a = idxs.accessor<int64_t, 3>();
+  auto dists_a = dists.accessor<float, 3>();

-    for (int n = 0; n < N; ++n) {
-        for (int i1 = 0; i1 < P1; ++i1) {
-            // Use a priority queue to store (distance, index) tuples.
-            std::priority_queue<std::tuple<float, int>> q;
-            for (int i2 = 0; i2 < P2; ++i2) {
-                float dist = 0;
-                for (int d = 0; d < D; ++d) {
-                    float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
-                    dist += diff * diff;
-                }
-                int size = static_cast<int>(q.size());
-                if (size < K || dist < std::get<0>(q.top())) {
-                    q.emplace(dist, i2);
-                    if (size >= K) {
-                        q.pop();
-                    }
-                }
-            }
-            while (!q.empty()) {
-                auto t = q.top();
-                q.pop();
-                const int k = q.size();
-                dists_a[n][i1][k] = std::get<0>(t);
-                idxs_a[n][i1][k] = std::get<1>(t);
-            }
-        }
-    }
-    return std::make_tuple(idxs, dists);
+  for (int n = 0; n < N; ++n) {
+    for (int i1 = 0; i1 < P1; ++i1) {
+      // Use a priority queue to store (distance, index) tuples.
+      std::priority_queue<std::tuple<float, int>> q;
+      for (int i2 = 0; i2 < P2; ++i2) {
+        float dist = 0;
+        for (int d = 0; d < D; ++d) {
+          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
+          dist += diff * diff;
+        }
+        int size = static_cast<int>(q.size());
+        if (size < K || dist < std::get<0>(q.top())) {
+          q.emplace(dist, i2);
+          if (size >= K) {
+            q.pop();
+          }
+        }
+      }
+      while (!q.empty()) {
+        auto t = q.top();
+        q.pop();
+        const int k = q.size();
+        dists_a[n][i1][k] = std::get<0>(t);
+        idxs_a[n][i1][k] = std::get<1>(t);
+      }
+    }
+  }
+  return std::make_tuple(idxs, dists);
 }
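One subtlety worth calling out in the CPU implementation above (a property of std::priority_queue, not something added by this diff): the queue is a max-heap, so q.top() is the current worst of the best-K candidates, which makes the evict-on-insert test cheap; and draining it yields distances in descending order, which is why writing each popped element at index k = q.size() lands the results in ascending order. A tiny standalone illustration with made-up values:

#include <cstdio>
#include <queue>
#include <tuple>

int main() {
  // Keep the K = 2 smallest of {5, 1, 3}, exactly as the CPU KNN loop does.
  const int K = 2;
  float dists[3] = {5.f, 1.f, 3.f};
  std::priority_queue<std::tuple<float, int>> q;
  for (int i = 0; i < 3; ++i) {
    int size = static_cast<int>(q.size());
    if (size < K || dists[i] < std::get<0>(q.top())) {
      q.emplace(dists[i], i);
      if (size >= K) {
        q.pop();  // evict the current worst of the kept candidates
      }
    }
  }
  while (!q.empty()) {
    auto t = q.top();
    q.pop();
    // After the pop, q.size() is this element's ascending-order slot.
    std::printf("slot %zu: dist=%.1f idx=%d\n", q.size(), std::get<0>(t), std::get<1>(t));
  }
  return 0;  // prints slot 1 (3.0, idx 2), then slot 0 (1.0, idx 1)
}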
@@ -5,7 +5,6 @@

 #include "index_utils.cuh"

-
 // A data structure to keep track of the smallest K keys seen so far as well
 // as their associated values, intended to be used in device code.
 // This data structure doesn't allocate any memory; keys and values are stored
@@ -32,18 +31,17 @@
 //   float key_k = keys[k];
 //   int value_k = values[k];
 // }
-template<typename key_t, typename value_t>
+template <typename key_t, typename value_t>
 class MinK {
  public:
-
   // Constructor.
   //
   // Arguments:
   //   keys: Array in which to store keys
   //   values: Array in which to store values
   //   K: How many values to keep track of
-  __device__ MinK(key_t *keys, value_t *vals, int K) :
-                  keys(keys), vals(vals), K(K), _size(0) { }
+  __device__ MinK(key_t* keys, value_t* vals, int K)
+      : keys(keys), vals(vals), K(K), _size(0) {}

   // Try to add a new key and associated value to the data structure. If the key
   // is one of the smallest K seen so far then it will be kept; otherwise it
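Putting the pieces together (a hypothetical device-side sketch using only the constructor and add() visible in this diff; distance_to is a stand-in for real code): MinK owns no memory, so the caller provides the backing arrays and reads results straight out of them afterwards, as the usage comment above describes:

// Track the 4 smallest (distance, index) pairs seen by this thread.
float keys[4];
int vals[4];
MinK<float, int> mink(keys, vals, 4);
for (int p = 0; p < num_points; ++p) {
  mink.add(distance_to(p), p);
}
// Afterwards keys[k] / vals[k] hold the kept candidates:
//   float key_k = keys[k];
//   int value_k = vals[k];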
@@ -55,7 +53,7 @@ class MinK {
   // Arguments:
   //   key: The key to add
   //   val: The value associated to the key
-  __device__ __forceinline__ void add(const key_t &key, const value_t &val) {
+  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
     if (_size < K) {
       keys[_size] = key;
       vals[_size] = val;
@@ -71,8 +69,8 @@ class MinK {
       for (int k = 0; k < K; ++k) {
         key_t cur_key = keys[k];
         if (cur_key > max_key) {
-            max_key = cur_key;
-            max_idx = k;
+          max_key = cur_key;
+          max_idx = k;
         }
       }
     }
@@ -102,15 +100,14 @@ class MinK {
   }

  private:
-  key_t *keys;
-  value_t *vals;
+  key_t* keys;
+  value_t* vals;
   int K;
   int _size;
   key_t max_key;
   int max_idx;
 };

-
 // This is a version of MinK that only touches the arrays using static indexing
 // via RegisterIndexUtils. If the keys and values are stored in thread-local
 // arrays, then this may allow the compiler to place them in registers for
@@ -120,13 +117,13 @@ class MinK {
 // We found that sorting via RegisterIndexUtils gave very poor performance,
 // and suspect it may have prevented the compiler from placing the arrays
 // into registers.
-template<typename key_t, typename value_t, int K>
+template <typename key_t, typename value_t, int K>
 class RegisterMinK {
  public:
-  __device__ RegisterMinK(key_t *keys, value_t *vals) :
-                          keys(keys), vals(vals), _size(0) {}
+  __device__ RegisterMinK(key_t* keys, value_t* vals)
+      : keys(keys), vals(vals), _size(0) {}

-  __device__ __forceinline__ void add(const key_t &key, const value_t &val) {
+  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
     if (_size < K) {
       RegisterIndexUtils<key_t, K>::set(keys, _size, key);
       RegisterIndexUtils<value_t, K>::set(vals, _size, val);
@@ -154,8 +151,8 @@ class RegisterMinK {
   }

  private:
-  key_t *keys;
-  value_t *vals;
+  key_t* keys;
+  value_t* vals;
   int _size;
   key_t max_key;
   int max_idx;
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 import torch
+
 from pytorch3d import _C

@@ -1,7 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 from itertools import product
-
 import torch
 from fvcore.common.benchmark import benchmark

@@ -30,21 +29,13 @@ def benchmark_knn_cuda_versions() -> None:
             continue
         if version == 3 and K > 4:
             continue
-        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version})
+        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
     for N, P, D in product(Ns, Ps, Ds):
-        nn_kwargs.append({'N': N, 'D': D, 'P': P})
+        nn_kwargs.append({"N": N, "D": D, "P": P})
     benchmark(
-        knn_cuda_with_init,
-        'KNN_CUDA_VERSIONS',
-        knn_kwargs,
-        warmup_iters=1,
-    )
-    benchmark(
-        nn_cuda_with_init,
-        'NN_CUDA',
-        nn_kwargs,
-        warmup_iters=1,
+        knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1
     )
+    benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)


 def benchmark_knn_cuda_vs_naive() -> None:
@@ -55,21 +46,16 @@ def benchmark_knn_cuda_vs_naive() -> None:
     Ks = [1, 2, 4, 8, 16]
     knn_kwargs, naive_kwargs = [], []
     for N, P, D, K in product(Ns, Ps, Ds, Ks):
-        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
+        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
         if P <= 4096:
-            naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
+            naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
     benchmark(
         knn_python_cuda_with_init,
-        'KNN_CUDA_PYTHON',
+        "KNN_CUDA_PYTHON",
         naive_kwargs,
         warmup_iters=1,
     )
-    benchmark(
-        knn_cuda_with_init,
-        'KNN_CUDA',
-        knn_kwargs,
-        warmup_iters=1,
-    )
+    benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)


 def benchmark_knn_cpu() -> None:
@@ -79,31 +65,18 @@ def benchmark_knn_cpu() -> None:
     Ks = [1, 2, 4]
     knn_kwargs, nn_kwargs = [], []
     for N, P, D, K in product(Ns, Ps, Ds, Ks):
-        knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K})
+        knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
     for N, P, D in product(Ns, Ps, Ds):
-        nn_kwargs.append({'N': N, 'D': D, 'P': P})
+        nn_kwargs.append({"N": N, "D": D, "P": P})
     benchmark(
-        knn_python_cpu_with_init,
-        'KNN_CPU_PYTHON',
-        knn_kwargs,
-        warmup_iters=1,
-    )
-    benchmark(
-        knn_cpu_with_init,
-        'KNN_CPU_CPP',
-        knn_kwargs,
-        warmup_iters=1,
-    )
-    benchmark(
-        nn_cpu_with_init,
-        'NN_CPU_CPP',
-        nn_kwargs,
-        warmup_iters=1,
+        knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1
     )
+    benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
+    benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)


 def knn_cuda_with_init(N, D, P, K, v=-1):
-    device = torch.device('cuda:0')
+    device = torch.device("cuda:0")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
     torch.cuda.synchronize()
@@ -116,7 +89,7 @@ def knn_cuda_with_init(N, D, P, K, v=-1):


 def knn_cpu_with_init(N, D, P, K):
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)

@@ -127,7 +100,7 @@ def knn_cpu_with_init(N, D, P, K):


 def knn_python_cuda_with_init(N, D, P, K):
-    device = torch.device('cuda')
+    device = torch.device("cuda")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
     torch.cuda.synchronize()
@@ -140,7 +113,7 @@ def knn_python_cuda_with_init(N, D, P, K):


 def knn_python_cpu_with_init(N, D, P, K):
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)

@@ -151,7 +124,7 @@ def knn_python_cpu_with_init(N, D, P, K):


 def nn_cuda_with_init(N, D, P):
-    device = torch.device('cuda')
+    device = torch.device("cuda")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)
     torch.cuda.synchronize()
@@ -164,7 +137,7 @@ def nn_cuda_with_init(N, D, P):


 def nn_cpu_with_init(N, D, P):
-    device = torch.device('cpu')
+    device = torch.device("cpu")
     x = torch.randn(N, P, D, device=device)
     y = torch.randn(N, P, D, device=device)

@@ -22,6 +22,27 @@ class TestBuild(unittest.TestCase):
         for k, v in counter.items():
             self.assertEqual(v, 1, f"Too many files with stem {k}.")

+    def test_deprecated_usage(self):
+        # Check certain expressions do not occur in the csrc code
+        test_dir = Path(__file__).resolve().parent
+        source_dir = test_dir.parent / "pytorch3d" / "csrc"
+
+        files = sorted(source_dir.glob("**/*.*"))
+        self.assertGreater(len(files), 4)
+
+        patterns = [".type()", ".data()"]
+
+        for file in files:
+            with open(file) as f:
+                text = f.read()
+                for pattern in patterns:
+                    found = pattern in text
+                    msg = (
+                        f"{pattern} found in {file.name}"
+                        + ", this has been deprecated."
+                    )
+                    self.assertFalse(found, msg)
+
     def test_copyright(self):
         test_dir = Path(__file__).resolve().parent
         root_dir = test_dir.parent
@@ -28,7 +28,7 @@ class TestKNN(unittest.TestCase):

     def test_knn_vs_python_cpu(self):
         """ Test CPU output vs PyTorch implementation """
-        device = torch.device('cpu')
+        device = torch.device("cpu")
         Ns = [1, 4]
         Ds = [2, 3]
         P1s = [1, 10, 101]
@@ -45,7 +45,7 @@ class TestKNN(unittest.TestCase):

     def test_knn_vs_python_cuda(self):
         """ Test CUDA output vs PyTorch implementation """
-        device = torch.device('cuda')
+        device = torch.device("cuda")
         Ns = [1, 4]
         Ds = [2, 3, 8]
         P1s = [1, 8, 64, 128, 1001]