arbitrary shape input to voxel_grids

Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of  only `[n_grids, n_points, 3]`.

Reviewed By: bottler

Differential Revision: D39574373

fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
Darijan Gudelj
2022-09-22 03:35:11 -07:00
committed by Facebook GitHub Bot
parent 6ae6ff9cf7
commit db3c12abfb
2 changed files with 93 additions and 55 deletions

View File

@@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
return
def get_random_normalized_points(
self, n_grids, n_points, dimension=3
self, n_grids, n_points=None, dimension=3
) -> torch.Tensor:
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
# create random query points
return torch.rand(n_grids, n_points, dimension) * 2 - 1
return (
torch.rand(
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
)
* 2
- 1
)
def _test_query_with_constant_init_cp(
self,
@@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_components: int,
resolution: Tuple[int],
value: float = 1,
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elementsthe result should
# be of shape (n_grids, n_points, n_features) and be filled with n_components
@@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
points = self.get_random_normalized_points(n_grids)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * n_components * value,
grid.evaluate_local(points, params),
torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
rtol=0.0001,
)
def _test_query_with_constant_init_vm(
@@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
expected_element = (
n_components * value if distribution is None else sum(distribution) * value
)
points = self.get_random_normalized_points(n_grids)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * expected_element,
grid.evaluate_local(points, params),
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
)
def _test_query_with_constant_init_full(
@@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
)
expected_element = value
points = self.get_random_normalized_points(n_grids)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * expected_element,
grid.evaluate_local(points, params),
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
)
def test_query_with_constant_init(self):
with self.subTest("Full"):
self._test_query_with_constant_init_full(
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
n_grids=5, n_features=6, resolution=(3, 4, 5)
)
with self.subTest("Full with 1 in dimensions"):
self._test_query_with_constant_init_full(
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
n_grids=5, n_features=1, resolution=(33, 41, 1)
)
with self.subTest("CP"):
self._test_query_with_constant_init_cp(
@@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=6,
n_components=7,
resolution=(3, 4, 5),
n_points=2,
)
with self.subTest("CP with 1 in dimensions"):
self._test_query_with_constant_init_cp(
@@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=1,
n_components=3,
resolution=(3, 1, 1),
n_points=4,
)
with self.subTest("VM with symetric distribution"):
self._test_query_with_constant_init_vm(
@@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=9,
resolution=(2, 12, 2),
n_components=12,
n_points=3,
)
with self.subTest("VM with distribution"):
self._test_query_with_constant_init_vm(
@@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=1,
resolution=(5, 9, 7),
distribution=(33, 41, 1),
n_points=7,
)
def test_query_with_zero_init(self):
@@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=6,
n_components=7,
resolution=(3, 2, 5),
n_points=3,
value=0,
)
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
@@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=9,
resolution=(2, 11, 3),
n_components=3,
n_points=3,
value=0,
)
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
self._test_query_with_constant_init_full(
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
)
def setUp(self):
@@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
padding_mode="zeros",
mode="bilinear",
),
rtol=0.0001,
)
def test_floating_point_query(self):