From 717493cb79f16e67a0d64653bbfd36558683f78b Mon Sep 17 00:00:00 2001 From: Kyle Vedder Date: Mon, 17 Jun 2024 06:00:13 -0700 Subject: [PATCH] Fixed last dimension size check so that it doesn't trivially pass. (#1815) Summary: Currently, it checks that the `2`th dimension of `p2` is the same size as the `2`th dimension of `p2` instead of `p1`. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1815 Reviewed By: MichaelRamamonjisoa Differential Revision: D58586966 Pulled By: bottler fbshipit-source-id: d4f723fa264f90fe368c10825c1acdfdc4c406dc --- pytorch3d/csrc/knn/knn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu index 93a3060b..ad9dce24 100644 --- a/pytorch3d/csrc/knn/knn.cu +++ b/pytorch3d/csrc/knn/knn.cu @@ -338,7 +338,7 @@ std::tuple KNearestNeighborIdxCuda( TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2."); - TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension"); + TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); auto long_dtype = lengths1.options().dtype(at::kLong); auto idxs = at::zeros({N, P1, K}, long_dtype); auto dists = at::zeros({N, P1, K}, p1.options());