mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Fix: Pointclouds.inside_box reducing over spatial dimensions.
Summary: As subj. Tests corrected accordingly. Also changed the test to provide a bit better diagnostics. Reviewed By: bottler Differential Revision: D32879498 fbshipit-source-id: 0a852e4a13dcb4ca3e54d71c6b263c5d2eeaf4eb
This commit is contained in:
		
							parent
							
								
									d9f709599b
								
							
						
					
					
						commit
						a6508ac3df
					
				@ -1176,5 +1176,5 @@ class Pointclouds:
 | 
			
		||||
            ]
 | 
			
		||||
            box = torch.cat(box, 0)
 | 
			
		||||
 | 
			
		||||
        idx = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
 | 
			
		||||
        return idx
 | 
			
		||||
        coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
 | 
			
		||||
        return coord_inside.all(dim=-1)
 | 
			
		||||
 | 
			
		||||
@ -163,8 +163,17 @@ class TestCaseMixin(unittest.TestCase):
 | 
			
		||||
        if close:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        diff = backend.abs(input + 0.0 - other)
 | 
			
		||||
        ratio = diff / backend.abs(other)
 | 
			
		||||
        # handle bool case
 | 
			
		||||
        if backend == torch and input.dtype == torch.bool:
 | 
			
		||||
            diff = (input != other).float()
 | 
			
		||||
            ratio = diff
 | 
			
		||||
        if backend == np and input.dtype == bool:
 | 
			
		||||
            diff = (input != other).astype(float)
 | 
			
		||||
            ratio = diff
 | 
			
		||||
        else:
 | 
			
		||||
            diff = backend.abs(input + 0.0 - other)
 | 
			
		||||
            ratio = diff / backend.abs(other)
 | 
			
		||||
 | 
			
		||||
        try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
 | 
			
		||||
        if try_relative.all():
 | 
			
		||||
            if backend == np:
 | 
			
		||||
 | 
			
		||||
@ -976,7 +976,9 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_inside_box(self):
 | 
			
		||||
        def inside_box_naive(cloud, box_min, box_max):
 | 
			
		||||
            return (cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))
 | 
			
		||||
            return ((cloud >= box_min.view(1, 3)) * (cloud <= box_max.view(1, 3))).all(
 | 
			
		||||
                dim=-1
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        N, P, C = 5, 100, 4
 | 
			
		||||
 | 
			
		||||
@ -994,7 +996,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        for i, cloud in enumerate(clouds.points_list()):
 | 
			
		||||
            within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1]))
 | 
			
		||||
        within_box_naive = torch.cat(within_box_naive, 0)
 | 
			
		||||
        self.assertTrue(within_box.eq(within_box_naive).all())
 | 
			
		||||
        self.assertClose(within_box, within_box_naive)
 | 
			
		||||
 | 
			
		||||
        # box of shape 2x3
 | 
			
		||||
        box2 = box[0, :]
 | 
			
		||||
@ -1005,13 +1007,13 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        for cloud in clouds.points_list():
 | 
			
		||||
            within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1]))
 | 
			
		||||
        within_box_naive2 = torch.cat(within_box_naive2, 0)
 | 
			
		||||
        self.assertTrue(within_box2.eq(within_box_naive2).all())
 | 
			
		||||
        self.assertClose(within_box2, within_box_naive2)
 | 
			
		||||
 | 
			
		||||
        # box of shape 1x2x3
 | 
			
		||||
        box3 = box2.expand(1, 2, 3)
 | 
			
		||||
 | 
			
		||||
        within_box3 = clouds.inside_box(box3)
 | 
			
		||||
        self.assertTrue(within_box2.eq(within_box3).all())
 | 
			
		||||
        self.assertClose(within_box2, within_box3)
 | 
			
		||||
 | 
			
		||||
        # invalid box
 | 
			
		||||
        invalid_box = torch.cat(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user