mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Expose knn_check_version in python
Summary: We have multiple KNN CUDA implementations. From python, users can currently request a particular implementation via the `version` flag, but they have no way of knowing which implementations can be used for a given problem. This diff exposes a function `pytorch3d._C.knn_check_version(version, D, K)` that returns whether a particular version can be used. Reviewed By: nikhilaravi Differential Revision: D21162573 fbshipit-source-id: 6061960bdcecba454fd920b00036f4e9ff3fdbc0
This commit is contained in:
parent
e38cbfe5e3
commit
9f31a4fd46
@ -18,6 +18,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
|
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
|
||||||
m.def("packed_to_padded", &PackedToPadded);
|
m.def("packed_to_padded", &PackedToPadded);
|
||||||
m.def("padded_to_packed", &PaddedToPacked);
|
m.def("padded_to_packed", &PaddedToPacked);
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
m.def("knn_check_version", &KnnCheckVersion);
|
||||||
|
#endif
|
||||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||||
m.def("knn_points_backward", &KNearestNeighborBackward);
|
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||||
m.def("gather_scatter", &gather_scatter);
|
m.def("gather_scatter", &gather_scatter);
|
||||||
|
@ -267,7 +267,7 @@ bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
|
|||||||
return min <= x && x <= max;
|
return min <= x && x <= max;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckVersion(int version, const int64_t D, const int64_t K) {
|
bool KnnCheckVersion(int version, const int64_t D, const int64_t K) {
|
||||||
if (version == 0) {
|
if (version == 0) {
|
||||||
return true;
|
return true;
|
||||||
} else if (version == 1) {
|
} else if (version == 1) {
|
||||||
@ -282,7 +282,7 @@ bool CheckVersion(int version, const int64_t D, const int64_t K) {
|
|||||||
|
|
||||||
int ChooseVersion(const int64_t D, const int64_t K) {
|
int ChooseVersion(const int64_t D, const int64_t K) {
|
||||||
for (int version = 3; version >= 1; version--) {
|
for (int version = 3; version >= 1; version--) {
|
||||||
if (CheckVersion(version, D, K)) {
|
if (KnnCheckVersion(version, D, K)) {
|
||||||
return version;
|
return version;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -309,7 +309,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
|
|
||||||
if (version < 0) {
|
if (version < 0) {
|
||||||
version = ChooseVersion(D, K);
|
version = ChooseVersion(D, K);
|
||||||
} else if (!CheckVersion(version, D, K)) {
|
} else if (!KnnCheckVersion(version, D, K)) {
|
||||||
int new_version = ChooseVersion(D, K);
|
int new_version = ChooseVersion(D, K);
|
||||||
std::cout << "WARNING: Requested KNN version " << version
|
std::cout << "WARNING: Requested KNN version " << version
|
||||||
<< " is not compatible with D = " << D << "; K = " << K
|
<< " is not compatible with D = " << D << "; K = " << K
|
||||||
@ -321,7 +321,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|||||||
// gave us. But we can check once more to be sure; however this time
|
// gave us. But we can check once more to be sure; however this time
|
||||||
// assert fail since failing at this point means we have a bug in our version
|
// assert fail since failing at this point means we have a bug in our version
|
||||||
// selection or checking code.
|
// selection or checking code.
|
||||||
AT_ASSERTM(CheckVersion(version, D, K), "Invalid version");
|
AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version");
|
||||||
|
|
||||||
const size_t threads = 256;
|
const size_t threads = 256;
|
||||||
const size_t blocks = 256;
|
const size_t blocks = 256;
|
||||||
|
@ -128,3 +128,15 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
|||||||
return KNearestNeighborBackwardCpu(
|
return KNearestNeighborBackwardCpu(
|
||||||
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Utility to check whether a KNN version can be used.
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// version: Integer in the range 0 <= version <= 3 indicating one of our
|
||||||
|
// KNN implementations.
|
||||||
|
// D: Number of dimensions for the input and query point clouds
|
||||||
|
// K: Number of neighbors to be found
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// Whether the indicated KNN version can be used.
|
||||||
|
bool KnnCheckVersion(int version, const int64_t D, const int64_t K);
|
||||||
|
@ -155,6 +155,26 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))
|
self.assertTrue(torch.all(y_nn[n, p1, k] == 0.0))
|
||||||
|
|
||||||
|
def test_knn_check_version(self):
|
||||||
|
try:
|
||||||
|
from pytorch3d._C import knn_check_version
|
||||||
|
except ImportError:
|
||||||
|
# knn_check_version will only be defined if we compiled with CUDA support
|
||||||
|
return
|
||||||
|
for D in range(-10, 10):
|
||||||
|
for K in range(-10, 20):
|
||||||
|
v0 = True
|
||||||
|
v1 = 1 <= D <= 32
|
||||||
|
v2 = 1 <= D <= 8 and 1 <= K <= 32
|
||||||
|
v3 = 1 <= D <= 8 and 1 <= K <= 4
|
||||||
|
all_expected = [v0, v1, v2, v3]
|
||||||
|
for version in range(-10, 10):
|
||||||
|
actual = knn_check_version(version, D, K)
|
||||||
|
expected = False
|
||||||
|
if 0 <= version < len(all_expected):
|
||||||
|
expected = all_expected[version]
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
|
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user