From 37c5c8e0b68445536d1b3eabd2ab1ffe4d87b61a Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Sun, 29 Mar 2020 14:01:15 -0700 Subject: [PATCH] Linter, deprecated type() Summary: Run linter after recent changes. Fix long comment in knn.h which clang-format has reflowed badly. Add crude test that code doesn't call deprecated `.type()` or `.data()`. Reviewed By: nikhilaravi Differential Revision: D20692935 fbshipit-source-id: 28ce0308adae79a870cb41a810b7cf8744f41ab8 --- pytorch3d/csrc/dispatch.cuh | 240 ++++++++++++++++++++++----------- pytorch3d/csrc/index_utils.cuh | 232 ++++++++++++++++++++++--------- pytorch3d/csrc/knn/knn.cu | 8 +- pytorch3d/csrc/knn/knn.h | 12 +- pytorch3d/csrc/knn/knn_cpu.cpp | 75 +++++------ pytorch3d/csrc/mink.cuh | 31 ++--- pytorch3d/ops/knn.py | 1 + tests/bm_knn.py | 65 +++------ tests/test_build.py | 21 +++ tests/test_knn.py | 4 +- 10 files changed, 430 insertions(+), 259 deletions(-) diff --git a/pytorch3d/csrc/dispatch.cuh b/pytorch3d/csrc/dispatch.cuh index 5226a5d7..001a09e3 100644 --- a/pytorch3d/csrc/dispatch.cuh +++ b/pytorch3d/csrc/dispatch.cuh @@ -1,9 +1,9 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. // -// This file provides utilities for dispatching to specialized versions of functions. -// This is especially useful for CUDA kernels, since specializing them to particular -// input sizes can often allow the compiler to unroll loops and place arrays into -// registers, which can give huge performance speedups. +// This file provides utilities for dispatching to specialized versions of +// functions. This is especially useful for CUDA kernels, since specializing +// them to particular input sizes can often allow the compiler to unroll loops +// and place arrays into registers, which can give huge performance speedups. // // As an example, suppose we have the following function which is specialized // based on a compile-time int64_t value: @@ -92,14 +92,13 @@ namespace { // In order to dispatch, we will take an additional template argument curN, // and increment it via template recursion until it is equal to the run-time // argument N. -template< - template class Kernel, - typename T, - int64_t minN, - int64_t maxN, - int64_t curN, - typename... Args -> +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + typename... Args> struct DispatchKernelHelper1D { static void run(const int64_t N, Args... args) { if (curN == N) { @@ -108,22 +107,21 @@ struct DispatchKernelHelper1D { Kernel::run(args...); } else if (curN < N) { // Increment curN via template recursion - DispatchKernelHelper1D::run(N, args...); + DispatchKernelHelper1D::run( + N, args...); } // We shouldn't get here -- throw an error? } }; - // 1D dispatch: Specialization when curN == maxN // We need this base case to avoid infinite template recursion. -template< - template class Kernel, - typename T, - int64_t minN, - int64_t maxN, - typename... Args -> +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args> struct DispatchKernelHelper1D { static void run(const int64_t N, Args... args) { if (N == maxN) { @@ -133,19 +131,21 @@ struct DispatchKernelHelper1D { } }; - // 2D dispatch, general case. // This is similar to the 1D case: we take additional template args curN and // curM, and increment them via template recursion until they are equal to // the run-time values of N and M, at which point we dispatch to the run // method of the kernel. 
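[Editorial note: a minimal, self-contained sketch of the 1D dispatch pattern the comments above describe, for readers following along. MyKernel, Dispatch1D, and the bounds are hypothetical illustrations, not the code in dispatch.cuh; the real helpers add error handling considerations, a 2D variant, and CUDA specifics.]

#include <cstdint>
#include <cstdio>

template <typename T, int64_t N>
struct MyKernel {
  // With N known at compile time, the loop below can be fully unrolled.
  static void run(const T* x) {
    T sum = T(0);
    for (int64_t i = 0; i < N; ++i) {
      sum += x[i];
    }
    printf("sum over %lld elements: %f\n", (long long)N, (double)sum);
  }
};

// General case: walk curN up toward the runtime value N via template
// recursion, dispatching to the matching specialization when they agree.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN>
struct Dispatch1D {
  static void run(const int64_t N, const T* x) {
    if (curN == N) {
      Kernel<T, curN>::run(x);
    } else if (curN < N) {
      Dispatch1D<Kernel, T, minN, maxN, curN + 1>::run(N, x);
    }
  }
};

// Base case: curN == maxN stops the template recursion.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN>
struct Dispatch1D<Kernel, T, minN, maxN, maxN> {
  static void run(const int64_t N, const T* x) {
    if (N == maxN) {
      Kernel<T, maxN>::run(x);
    }
  }
};

int main() {
  float x[4] = {1.f, 2.f, 3.f, 4.f};
  // Runtime N = 4 is matched to the compile-time specialization N = 4.
  Dispatch1D<MyKernel, float, 1, 8, 1>::run(4, x);
}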
-template< - template class Kernel, - typename T, - int64_t minN, int64_t maxN, int64_t curN, - int64_t minM, int64_t maxM, int64_t curM, - typename... Args -> +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + int64_t minM, + int64_t maxM, + int64_t curM, + typename... Args> struct DispatchKernelHelper2D { static void run(const int64_t N, const int64_t M, Args... args) { if (curN == N && curM == M) { @@ -154,67 +154,141 @@ struct DispatchKernelHelper2D { // Increment both curN and curM. This isn't strictly necessary; we could // just increment one or the other at each step. But this helps to cut // on the number of recursive calls we make. - DispatchKernelHelper2D::run(N, M, args...); + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); } else if (curN < N) { // Increment curN only - DispatchKernelHelper2D::run(N, M, args...); + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + curM, + Args...>::run(N, M, args...); } else if (curM < M) { // Increment curM only - DispatchKernelHelper2D::run(N, M, args...); + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); } } }; - // 2D dispatch, specialization for curN == maxN -template< - template class Kernel, - typename T, - int64_t minN, int64_t maxN, - int64_t minM, int64_t maxM, int64_t curM, - typename... Args -> -struct DispatchKernelHelper2D { +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + int64_t curM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + curM, + Args...> { static void run(const int64_t N, const int64_t M, Args... args) { if (maxN == N && curM == M) { Kernel::run(args...); } else if (curM < maxM) { - DispatchKernelHelper2D::run(N, M, args...); + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); } // We should not get here -- throw an error? } }; - // 2D dispatch, specialization for curM == maxM -template< - template class Kernel, - typename T, - int64_t minN, int64_t maxN, int64_t curN, - int64_t minM, int64_t maxM, - typename... Args -> -struct DispatchKernelHelper2D { +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + int64_t minM, + int64_t maxM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN, + minM, + maxM, + maxM, + Args...> { static void run(const int64_t N, const int64_t M, Args... args) { if (curN == N && maxM == M) { Kernel::run(args...); } else if (curN < maxN) { - DispatchKernelHelper2D::run(N, M, args...); + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + maxM, + Args...>::run(N, M, args...); } // We should not get here -- throw an error? } }; - // 2D dispatch, specialization for curN == maxN, curM == maxM -template< - template class Kernel, - typename T, - int64_t minN, int64_t maxN, - int64_t minM, int64_t maxM, - typename... Args -> -struct DispatchKernelHelper2D { +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + maxM, + Args...> { static void run(const int64_t N, const int64_t M, Args... 
args) { if (maxN == N && maxM == M) { Kernel::run(args...); @@ -225,37 +299,45 @@ struct DispatchKernelHelper2D class Kernel, - typename T, - int64_t minN, - int64_t maxN, - typename... Args -> +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args> void DispatchKernel1D(const int64_t N, Args... args) { if (minN <= N && N <= maxN) { // Kick off the template recursion by calling the Helper with curN = minN - DispatchKernelHelper1D::run(N, args...); + DispatchKernelHelper1D::run( + N, args...); } // Maybe throw an error if we tried to dispatch outside the allowed range? } - // This is the function we expect users to call to dispatch to 2D functions -template< - template class Kernel, - typename T, - int64_t minN, int64_t maxN, - int64_t minM, int64_t maxM, - typename... Args -> +template < + template class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + typename... Args> void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { if (minN <= N && N <= maxN && minM <= M && M <= maxM) { // Kick off the template recursion by calling the Helper with curN = minN // and curM = minM - DispatchKernelHelper2D::run(N, M, args...); + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + minN, + minM, + maxM, + minM, + Args...>::run(N, M, args...); } // Maybe throw an error if we tried to dispatch outside the specified range? } diff --git a/pytorch3d/csrc/index_utils.cuh b/pytorch3d/csrc/index_utils.cuh index 66460ebf..26fda57d 100644 --- a/pytorch3d/csrc/index_utils.cuh +++ b/pytorch3d/csrc/index_utils.cuh @@ -39,82 +39,180 @@ // approach for this might lead to extra function calls at runtime if the // compiler fails to optimize them away, which could be very slow on device. // However I didn't actually benchmark or test this. 
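[Editorial note: a hedged sketch of the static-indexing idea that index_utils.cuh implements, shown for a hypothetical fixed size of 4; the real RegisterIndexUtils below is templated on N and handles up to 32 entries. Because every array access goes through a switch over literal indices, the compiler is free to keep the array entirely in registers. TinyRegisterArray and demo_kernel are illustrations only.]

template <typename T>
struct TinyRegisterArray {
  __device__ __forceinline__ static T get(const T arr[4], int idx) {
    switch (idx) {
      case 0:
        return arr[0];
      case 1:
        return arr[1];
      case 2:
        return arr[2];
      case 3:
        return arr[3];
    }
    return T(); // Out-of-range reads return a default-constructed value.
  }
  __device__ __forceinline__ static void set(T arr[4], int idx, T val) {
    switch (idx) {
      case 0:
        arr[0] = val;
        break;
      case 1:
        arr[1] = val;
        break;
      case 2:
        arr[2] = val;
        break;
      case 3:
        arr[3] = val;
        break;
    }
  }
};

// Hypothetical usage inside a kernel: `vals` never needs a memory address,
// even though the index `i` is a runtime value.
__global__ void demo_kernel(float* out) {
  float vals[4];
  for (int i = 0; i < 4; ++i) {
    TinyRegisterArray<float>::set(vals, i, float(i) * 2.f);
  }
  float sum = 0.f;
  for (int i = 0; i < 4; ++i) {
    sum += TinyRegisterArray<float>::get(vals, i);
  }
  out[threadIdx.x] = sum;
}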
-template +template struct RegisterIndexUtils { __device__ __forceinline__ static T get(const T arr[N], int idx) { - if (idx < 0 || idx >= N) return T(); + if (idx < 0 || idx >= N) + return T(); switch (idx) { - case 0: return arr[0]; - case 1: return arr[1]; - case 2: return arr[2]; - case 3: return arr[3]; - case 4: return arr[4]; - case 5: return arr[5]; - case 6: return arr[6]; - case 7: return arr[7]; - case 8: return arr[8]; - case 9: return arr[9]; - case 10: return arr[10]; - case 11: return arr[11]; - case 12: return arr[12]; - case 13: return arr[13]; - case 14: return arr[14]; - case 15: return arr[15]; - case 16: return arr[16]; - case 17: return arr[17]; - case 18: return arr[18]; - case 19: return arr[19]; - case 20: return arr[20]; - case 21: return arr[21]; - case 22: return arr[22]; - case 23: return arr[23]; - case 24: return arr[24]; - case 25: return arr[25]; - case 26: return arr[26]; - case 27: return arr[27]; - case 28: return arr[28]; - case 29: return arr[29]; - case 30: return arr[30]; - case 31: return arr[31]; + case 0: + return arr[0]; + case 1: + return arr[1]; + case 2: + return arr[2]; + case 3: + return arr[3]; + case 4: + return arr[4]; + case 5: + return arr[5]; + case 6: + return arr[6]; + case 7: + return arr[7]; + case 8: + return arr[8]; + case 9: + return arr[9]; + case 10: + return arr[10]; + case 11: + return arr[11]; + case 12: + return arr[12]; + case 13: + return arr[13]; + case 14: + return arr[14]; + case 15: + return arr[15]; + case 16: + return arr[16]; + case 17: + return arr[17]; + case 18: + return arr[18]; + case 19: + return arr[19]; + case 20: + return arr[20]; + case 21: + return arr[21]; + case 22: + return arr[22]; + case 23: + return arr[23]; + case 24: + return arr[24]; + case 25: + return arr[25]; + case 26: + return arr[26]; + case 27: + return arr[27]; + case 28: + return arr[28]; + case 29: + return arr[29]; + case 30: + return arr[30]; + case 31: + return arr[31]; }; return T(); } __device__ __forceinline__ static void set(T arr[N], int idx, T val) { - if (idx < 0 || idx >= N) return; + if (idx < 0 || idx >= N) + return; switch (idx) { - case 0: arr[0] = val; break; - case 1: arr[1] = val; break; - case 2: arr[2] = val; break; - case 3: arr[3] = val; break; - case 4: arr[4] = val; break; - case 5: arr[5] = val; break; - case 6: arr[6] = val; break; - case 7: arr[7] = val; break; - case 8: arr[8] = val; break; - case 9: arr[9] = val; break; - case 10: arr[10] = val; break; - case 11: arr[11] = val; break; - case 12: arr[12] = val; break; - case 13: arr[13] = val; break; - case 14: arr[14] = val; break; - case 15: arr[15] = val; break; - case 16: arr[16] = val; break; - case 17: arr[17] = val; break; - case 18: arr[18] = val; break; - case 19: arr[19] = val; break; - case 20: arr[20] = val; break; - case 21: arr[21] = val; break; - case 22: arr[22] = val; break; - case 23: arr[23] = val; break; - case 24: arr[24] = val; break; - case 25: arr[25] = val; break; - case 26: arr[26] = val; break; - case 27: arr[27] = val; break; - case 28: arr[28] = val; break; - case 29: arr[29] = val; break; - case 30: arr[30] = val; break; - case 31: arr[31] = val; break; + case 0: + arr[0] = val; + break; + case 1: + arr[1] = val; + break; + case 2: + arr[2] = val; + break; + case 3: + arr[3] = val; + break; + case 4: + arr[4] = val; + break; + case 5: + arr[5] = val; + break; + case 6: + arr[6] = val; + break; + case 7: + arr[7] = val; + break; + case 8: + arr[8] = val; + break; + case 9: + arr[9] = val; + break; + case 10: + arr[10] = val; + 
break; + case 11: + arr[11] = val; + break; + case 12: + arr[12] = val; + break; + case 13: + arr[13] = val; + break; + case 14: + arr[14] = val; + break; + case 15: + arr[15] = val; + break; + case 16: + arr[16] = val; + break; + case 17: + arr[17] = val; + break; + case 18: + arr[18] = val; + break; + case 19: + arr[19] = val; + break; + case 20: + arr[20] = val; + break; + case 21: + arr[21] = val; + break; + case 22: + arr[22] = val; + break; + case 23: + arr[23] = val; + break; + case 24: + arr[24] = val; + break; + case 25: + arr[25] = val; + break; + case 26: + arr[26] = val; + break; + case 27: + arr[27] = val; + break; + case 28: + arr[28] = val; + break; + case 29: + arr[29] = val; + break; + case 30: + arr[30] = val; + break; + case 31: + arr[31] = val; + break; } } }; diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu index b065d969..a6d53951 100644 --- a/pytorch3d/csrc/knn/knn.cu +++ b/pytorch3d/csrc/knn/knn.cu @@ -289,7 +289,7 @@ std::tuple KNearestNeighborIdxCuda( const size_t threads = 256; const size_t blocks = 256; if (version == 0) { - AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { KNearestNeighborKernelV0 <<>>( p1.data_ptr(), @@ -303,7 +303,7 @@ std::tuple KNearestNeighborIdxCuda( K); })); } else if (version == 1) { - AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { DispatchKernel1D< KNearestNeighborV1Functor, scalar_t, @@ -322,7 +322,7 @@ std::tuple KNearestNeighborIdxCuda( K); })); } else if (version == 2) { - AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { DispatchKernel2D< KNearestNeighborKernelV2Functor, scalar_t, @@ -343,7 +343,7 @@ std::tuple KNearestNeighborIdxCuda( P2); })); } else if (version == 3) { - AT_DISPATCH_FLOATING_TYPES(p1.type(), "knn_kernel_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { DispatchKernel2D< KNearestNeighborKernelV3Functor, scalar_t, diff --git a/pytorch3d/csrc/knn/knn.h b/pytorch3d/csrc/knn/knn.h index cb760c34..65c3732b 100644 --- a/pytorch3d/csrc/knn/knn.h +++ b/pytorch3d/csrc/knn/knn.h @@ -13,11 +13,11 @@ // containing P1 points of dimension D. // p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each // containing P2 points of dimension D. -// K: int giving the number of nearest points to return. -// sorted: bool telling whether to sort the K returned points by their -// distance version: Integer telling which implementation to use. -// TODO(jcjohns): Document this more, or maybe remove it before -// landing. +// K: int giving the number of nearest points to return. +// sorted: bool telling whether to sort the K returned points by their +// distance. +// version: Integer telling which implementation to use. +// TODO(jcjohns): Document this more, or maybe remove it before landing. 
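[Editorial note: the knn.cu hunks above replace the deprecated p1.type() with p1.scalar_type() inside AT_DISPATCH_FLOATING_TYPES. A hedged sketch of that idiom follows; my_kernel and the launch geometry are hypothetical placeholders, not the actual KNN kernels.]

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

template <typename scalar_t>
__global__ void my_kernel(const scalar_t* in, scalar_t* out, int64_t n) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * in[i];
  }
}

void launch(const at::Tensor& in, at::Tensor& out) {
  const int64_t n = in.numel();
  const dim3 block(256);
  const dim3 grid((n + 255) / 256);
  AT_DISPATCH_FLOATING_TYPES(in.scalar_type(), "my_kernel", ([&] {
    // scalar_t is bound to float or double inside this lambda, so the
    // typed data_ptr<scalar_t>() replaces the deprecated raw .data().
    my_kernel<scalar_t><<<grid, block>>>(
        in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), n);
  }));
}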
// // Returns: // p1_neighbor_idx: LongTensor of shape (N, P1, K), where @@ -41,7 +41,7 @@ std::tuple KNearestNeighborIdx( const at::Tensor& p2, int K, int version) { - if (p1.type().is_cuda() || p2.type().is_cuda()) { + if (p1.is_cuda() || p2.is_cuda()) { #ifdef WITH_CUDA CHECK_CONTIGUOUS_CUDA(p1); CHECK_CONTIGUOUS_CUDA(p2); diff --git a/pytorch3d/csrc/knn/knn_cpu.cpp b/pytorch3d/csrc/knn/knn_cpu.cpp index dada972a..a2a55d2c 100644 --- a/pytorch3d/csrc/knn/knn_cpu.cpp +++ b/pytorch3d/csrc/knn/knn_cpu.cpp @@ -4,49 +4,48 @@ #include #include - std::tuple KNearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2, int K) { - const int N = p1.size(0); - const int P1 = p1.size(1); - const int D = p1.size(2); - const int P2 = p2.size(1); + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + const int P2 = p2.size(1); - auto long_opts = p1.options().dtype(torch::kInt64); - torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts); - torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); + auto long_opts = p1.options().dtype(torch::kInt64); + torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts); + torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); - auto p1_a = p1.accessor(); - auto p2_a = p2.accessor(); - auto idxs_a = idxs.accessor(); - auto dists_a = dists.accessor(); + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto idxs_a = idxs.accessor(); + auto dists_a = dists.accessor(); - for (int n = 0; n < N; ++n) { - for (int i1 = 0; i1 < P1; ++i1) { - // Use a priority queue to store (distance, index) tuples. - std::priority_queue> q; - for (int i2 = 0; i2 < P2; ++i2) { - float dist = 0; - for (int d = 0; d < D; ++d) { - float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; - dist += diff * diff; - } - int size = static_cast(q.size()); - if (size < K || dist < std::get<0>(q.top())) { - q.emplace(dist, i2); - if (size >= K) { - q.pop(); - } - } - } - while (!q.empty()) { - auto t = q.top(); - q.pop(); - const int k = q.size(); - dists_a[n][i1][k] = std::get<0>(t); - idxs_a[n][i1][k] = std::get<1>(t); - } + for (int n = 0; n < N; ++n) { + for (int i1 = 0; i1 < P1; ++i1) { + // Use a priority queue to store (distance, index) tuples. + std::priority_queue> q; + for (int i2 = 0; i2 < P2; ++i2) { + float dist = 0; + for (int d = 0; d < D; ++d) { + float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; + dist += diff * diff; } + int size = static_cast(q.size()); + if (size < K || dist < std::get<0>(q.top())) { + q.emplace(dist, i2); + if (size >= K) { + q.pop(); + } + } + } + while (!q.empty()) { + auto t = q.top(); + q.pop(); + const int k = q.size(); + dists_a[n][i1][k] = std::get<0>(t); + idxs_a[n][i1][k] = std::get<1>(t); + } } - return std::make_tuple(idxs, dists); + } + return std::make_tuple(idxs, dists); } diff --git a/pytorch3d/csrc/mink.cuh b/pytorch3d/csrc/mink.cuh index 5d7eb730..221b816f 100644 --- a/pytorch3d/csrc/mink.cuh +++ b/pytorch3d/csrc/mink.cuh @@ -5,7 +5,6 @@ #include "index_utils.cuh" - // A data structure to keep track of the smallest K keys seen so far as well // as their associated values, intended to be used in device code. // This data structure doesn't allocate any memory; keys and values are stored @@ -32,18 +31,17 @@ // float key_k = keys[k]; // int value_k = values[k]; // } -template +template class MinK { public: - // Constructor. 
// // Arguments: // keys: Array in which to store keys // values: Array in which to store values // K: How many values to keep track of - __device__ MinK(key_t *keys, value_t *vals, int K) : - keys(keys), vals(vals), K(K), _size(0) { } + __device__ MinK(key_t* keys, value_t* vals, int K) + : keys(keys), vals(vals), K(K), _size(0) {} // Try to add a new key and associated value to the data structure. If the key // is one of the smallest K seen so far then it will be kept; otherwise it @@ -55,7 +53,7 @@ class MinK { // Arguments: // key: The key to add // val: The value associated to the key - __device__ __forceinline__ void add(const key_t &key, const value_t &val) { + __device__ __forceinline__ void add(const key_t& key, const value_t& val) { if (_size < K) { keys[_size] = key; vals[_size] = val; @@ -71,8 +69,8 @@ class MinK { for (int k = 0; k < K; ++k) { key_t cur_key = keys[k]; if (cur_key > max_key) { - max_key = cur_key; - max_idx = k; + max_key = cur_key; + max_idx = k; } } } @@ -102,15 +100,14 @@ class MinK { } private: - key_t *keys; - value_t *vals; + key_t* keys; + value_t* vals; int K; int _size; key_t max_key; int max_idx; }; - // This is a version of MinK that only touches the arrays using static indexing // via RegisterIndexUtils. If the keys and values are stored in thread-local // arrays, then this may allow the compiler to place them in registers for @@ -120,13 +117,13 @@ class MinK { // We found that sorting via RegisterIndexUtils gave very poor performance, // and suspect it may have prevented the compiler from placing the arrays // into registers. -template +template class RegisterMinK { public: - __device__ RegisterMinK(key_t *keys, value_t *vals) : - keys(keys), vals(vals), _size(0) {} + __device__ RegisterMinK(key_t* keys, value_t* vals) + : keys(keys), vals(vals), _size(0) {} - __device__ __forceinline__ void add(const key_t &key, const value_t &val) { + __device__ __forceinline__ void add(const key_t& key, const value_t& val) { if (_size < K) { RegisterIndexUtils::set(keys, _size, key); RegisterIndexUtils::set(vals, _size, val); @@ -154,8 +151,8 @@ class RegisterMinK { } private: - key_t *keys; - value_t *vals; + key_t* keys; + value_t* vals; int _size; key_t max_key; int max_idx; diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py index 2ec35992..3986b9bf 100644 --- a/pytorch3d/ops/knn.py +++ b/pytorch3d/ops/knn.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import torch + from pytorch3d import _C diff --git a/tests/bm_knn.py b/tests/bm_knn.py index fd391532..1bf935d2 100644 --- a/tests/bm_knn.py +++ b/tests/bm_knn.py @@ -1,7 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
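[Editorial note: to clarify the MinK insertion strategy shown in mink.cuh above, here is a host-side analogue for illustration only: keep the K smallest (key, value) pairs seen so far by tracking the current maximum and overwriting it when a smaller key arrives. The device class in mink.cuh is the real implementation; mink_add is a hypothetical free-function rendering of its add() method.]

#include <cstdio>

template <typename key_t, typename value_t>
void mink_add(
    key_t* keys,
    value_t* vals,
    int K,
    int& size,
    key_t& max_key,
    int& max_idx,
    const key_t& key,
    const value_t& val) {
  if (size < K) {
    // Still filling up: append, and remember the running maximum.
    keys[size] = key;
    vals[size] = val;
    if (size == 0 || key > max_key) {
      max_key = key;
      max_idx = size;
    }
    size++;
  } else if (key < max_key) {
    // Full: replace the current maximum, then rescan for the new maximum.
    keys[max_idx] = key;
    vals[max_idx] = val;
    max_key = key;
    for (int k = 0; k < K; ++k) {
      if (keys[k] > max_key) {
        max_key = keys[k];
        max_idx = k;
      }
    }
  }
}

int main() {
  float keys[3];
  int vals[3];
  int size = 0, max_idx = 0;
  float max_key = 0.f;
  const float dists[6] = {5.f, 1.f, 4.f, 2.f, 9.f, 0.5f};
  for (int i = 0; i < 6; ++i) {
    mink_add(keys, vals, 3, size, max_key, max_idx, dists[i], i);
  }
  // keys now holds {2, 1, 0.5} (the three smallest distances, unsorted),
  // with the matching source indices in vals.
  for (int k = 0; k < size; ++k) {
    printf("dist=%.1f idx=%d\n", keys[k], vals[k]);
  }
}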
from itertools import product - import torch from fvcore.common.benchmark import benchmark @@ -30,21 +29,13 @@ def benchmark_knn_cuda_versions() -> None: continue if version == 3 and K > 4: continue - knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K, 'v': version}) + knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version}) for N, P, D in product(Ns, Ps, Ds): - nn_kwargs.append({'N': N, 'D': D, 'P': P}) + nn_kwargs.append({"N": N, "D": D, "P": P}) benchmark( - knn_cuda_with_init, - 'KNN_CUDA_VERSIONS', - knn_kwargs, - warmup_iters=1, - ) - benchmark( - nn_cuda_with_init, - 'NN_CUDA', - nn_kwargs, - warmup_iters=1, + knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1 ) + benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1) def benchmark_knn_cuda_vs_naive() -> None: @@ -55,21 +46,16 @@ def benchmark_knn_cuda_vs_naive() -> None: Ks = [1, 2, 4, 8, 16] knn_kwargs, naive_kwargs = [], [] for N, P, D, K in product(Ns, Ps, Ds, Ks): - knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K}) + knn_kwargs.append({"N": N, "D": D, "P": P, "K": K}) if P <= 4096: - naive_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K}) + naive_kwargs.append({"N": N, "D": D, "P": P, "K": K}) benchmark( knn_python_cuda_with_init, - 'KNN_CUDA_PYTHON', + "KNN_CUDA_PYTHON", naive_kwargs, warmup_iters=1, ) - benchmark( - knn_cuda_with_init, - 'KNN_CUDA', - knn_kwargs, - warmup_iters=1, - ) + benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1) def benchmark_knn_cpu() -> None: @@ -79,31 +65,18 @@ def benchmark_knn_cpu() -> None: Ks = [1, 2, 4] knn_kwargs, nn_kwargs = [], [] for N, P, D, K in product(Ns, Ps, Ds, Ks): - knn_kwargs.append({'N': N, 'D': D, 'P': P, 'K': K}) + knn_kwargs.append({"N": N, "D": D, "P": P, "K": K}) for N, P, D in product(Ns, Ps, Ds): - nn_kwargs.append({'N': N, 'D': D, 'P': P}) + nn_kwargs.append({"N": N, "D": D, "P": P}) benchmark( - knn_python_cpu_with_init, - 'KNN_CPU_PYTHON', - knn_kwargs, - warmup_iters=1, - ) - benchmark( - knn_cpu_with_init, - 'KNN_CPU_CPP', - knn_kwargs, - warmup_iters=1, - ) - benchmark( - nn_cpu_with_init, - 'NN_CPU_CPP', - nn_kwargs, - warmup_iters=1, + knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1 ) + benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1) + benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1) def knn_cuda_with_init(N, D, P, K, v=-1): - device = torch.device('cuda:0') + device = torch.device("cuda:0") x = torch.randn(N, P, D, device=device) y = torch.randn(N, P, D, device=device) torch.cuda.synchronize() @@ -116,7 +89,7 @@ def knn_cuda_with_init(N, D, P, K, v=-1): def knn_cpu_with_init(N, D, P, K): - device = torch.device('cpu') + device = torch.device("cpu") x = torch.randn(N, P, D, device=device) y = torch.randn(N, P, D, device=device) @@ -127,7 +100,7 @@ def knn_cpu_with_init(N, D, P, K): def knn_python_cuda_with_init(N, D, P, K): - device = torch.device('cuda') + device = torch.device("cuda") x = torch.randn(N, P, D, device=device) y = torch.randn(N, P, D, device=device) torch.cuda.synchronize() @@ -140,7 +113,7 @@ def knn_python_cuda_with_init(N, D, P, K): def knn_python_cpu_with_init(N, D, P, K): - device = torch.device('cpu') + device = torch.device("cpu") x = torch.randn(N, P, D, device=device) y = torch.randn(N, P, D, device=device) @@ -151,7 +124,7 @@ def knn_python_cpu_with_init(N, D, P, K): def nn_cuda_with_init(N, D, P): - device = torch.device('cuda') + device = torch.device("cuda") x = torch.randn(N, P, D, device=device) y = torch.randn(N, 
P, D, device=device) torch.cuda.synchronize() @@ -164,7 +137,7 @@ def nn_cuda_with_init(N, D, P): def nn_cpu_with_init(N, D, P): - device = torch.device('cpu') + device = torch.device("cpu") x = torch.randn(N, P, D, device=device) y = torch.randn(N, P, D, device=device) diff --git a/tests/test_build.py b/tests/test_build.py index 53e98701..865a7fd5 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -22,6 +22,27 @@ class TestBuild(unittest.TestCase): for k, v in counter.items(): self.assertEqual(v, 1, f"Too many files with stem {k}.") + def test_deprecated_usage(self): + # Check certain expressions do not occur in the csrc code + test_dir = Path(__file__).resolve().parent + source_dir = test_dir.parent / "pytorch3d" / "csrc" + + files = sorted(source_dir.glob("**/*.*")) + self.assertGreater(len(files), 4) + + patterns = [".type()", ".data()"] + + for file in files: + with open(file) as f: + text = f.read() + for pattern in patterns: + found = pattern in text + msg = ( + f"{pattern} found in {file.name}" + + ", this has been deprecated." + ) + self.assertFalse(found, msg) + def test_copyright(self): test_dir = Path(__file__).resolve().parent root_dir = test_dir.parent diff --git a/tests/test_knn.py b/tests/test_knn.py index 9c9483d6..1a090e3a 100644 --- a/tests/test_knn.py +++ b/tests/test_knn.py @@ -28,7 +28,7 @@ class TestKNN(unittest.TestCase): def test_knn_vs_python_cpu(self): """ Test CPU output vs PyTorch implementation """ - device = torch.device('cpu') + device = torch.device("cpu") Ns = [1, 4] Ds = [2, 3] P1s = [1, 10, 101] @@ -45,7 +45,7 @@ class TestKNN(unittest.TestCase): def test_knn_vs_python_cuda(self): """ Test CUDA output vs PyTorch implementation """ - device = torch.device('cuda') + device = torch.device("cuda") Ns = [1, 4] Ds = [2, 3, 8] P1s = [1, 8, 64, 128, 1001]
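[Editorial note: for reference, a compact sketch of the accessor replacements this patch standardizes on, and which the new test_build check greps for. `t` and `modern_accessors` are hypothetical; the comments show the deprecated spellings being retired.]

#include <ATen/ATen.h>

void modern_accessors(const at::Tensor& t) {
  // Deprecated: t.type().is_cuda()
  bool on_gpu = t.is_cuda();

  // Deprecated: passing t.type() to the AT_DISPATCH_* macros
  at::ScalarType st = t.scalar_type();

  // Deprecated: raw t.data() / t.data<float>()
  const float* ptr = t.data_ptr<float>();

  (void)on_gpu;
  (void)st;
  (void)ptr;
}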