From a6508ac3dfaaf59d8bdce176bfbafad94c1d0604 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Mon, 6 Dec 2021 07:44:14 -0800 Subject: [PATCH] Fix: Pointclouds.inside_box reducing over spatial dimensions. Summary: As subj. Tests corrected accordingly. Also changed the test to provide a bit better diagnostics. Reviewed By: bottler Differential Revision: D32879498 fbshipit-source-id: 0a852e4a13dcb4ca3e54d71c6b263c5d2eeaf4eb --- pytorch3d/structures/pointclouds.py | 4 ++-- tests/common_testing.py | 13 +++++++++++-- tests/test_pointclouds.py | 10 ++++++---- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 9b65f496..6d17fb3e 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -1176,5 +1176,5 @@ class Pointclouds: ] box = torch.cat(box, 0) - idx = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1]) - return idx + coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1]) + return coord_inside.all(dim=-1) diff --git a/tests/common_testing.py b/tests/common_testing.py index e92bbb52..1056fcf6 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -163,8 +163,17 @@ class TestCaseMixin(unittest.TestCase): if close: return - diff = backend.abs(input + 0.0 - other) - ratio = diff / backend.abs(other) + # handle bool case + if backend == torch and input.dtype == torch.bool: + diff = (input != other).float() + ratio = diff + if backend == np and input.dtype == bool: + diff = (input != other).astype(float) + ratio = diff + else: + diff = backend.abs(input + 0.0 - other) + ratio = diff / backend.abs(other) + try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0)) if try_relative.all(): if backend == np: diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index f83de817..d92cf895 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -976,7 +976,9 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): def test_inside_box(self): def inside_box_naive(cloud, box_min, box_max): - return (cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3)) + return ((cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))).all( + dim=-1 + ) N, P, C = 5, 100, 4 @@ -994,7 +996,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): for i, cloud in enumerate(clouds.points_list()): within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1])) within_box_naive = torch.cat(within_box_naive, 0) - self.assertTrue(within_box.eq(within_box_naive).all()) + self.assertClose(within_box, within_box_naive) # box of shape 2x3 box2 = box[0, :] @@ -1005,13 +1007,13 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): for cloud in clouds.points_list(): within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1])) within_box_naive2 = torch.cat(within_box_naive2, 0) - self.assertTrue(within_box2.eq(within_box_naive2).all()) + self.assertClose(within_box2, within_box_naive2) # box of shape 1x2x3 box3 = box2.expand(1, 2, 3) within_box3 = clouds.inside_box(box3) - self.assertTrue(within_box2.eq(within_box3).all()) + self.assertClose(within_box2, within_box3) # invalid box invalid_box = torch.cat(