mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
arbitrary shape input to voxel_grids
Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`. Reviewed By: bottler Differential Revision: D39574373 fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6ae6ff9cf7
commit
db3c12abfb
@@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
return
|
||||
|
||||
def get_random_normalized_points(
|
||||
self, n_grids, n_points, dimension=3
|
||||
self, n_grids, n_points=None, dimension=3
|
||||
) -> torch.Tensor:
|
||||
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
|
||||
# create random query points
|
||||
return torch.rand(n_grids, n_points, dimension) * 2 - 1
|
||||
return (
|
||||
torch.rand(
|
||||
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
|
||||
)
|
||||
* 2
|
||||
- 1
|
||||
)
|
||||
|
||||
def _test_query_with_constant_init_cp(
|
||||
self,
|
||||
@@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_components: int,
|
||||
resolution: Tuple[int],
|
||||
value: float = 1,
|
||||
n_points: int = 1,
|
||||
) -> None:
|
||||
# set everything to 'value' and do query for elementsthe result should
|
||||
# be of shape (n_grids, n_points, n_features) and be filled with n_components
|
||||
@@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
)
|
||||
|
||||
points = self.get_random_normalized_points(n_grids)
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(
|
||||
self.get_random_normalized_points(n_grids, n_points), params
|
||||
),
|
||||
torch.ones(n_grids, n_points, n_features) * n_components * value,
|
||||
grid.evaluate_local(points, params),
|
||||
torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
|
||||
rtol=0.0001,
|
||||
)
|
||||
|
||||
def _test_query_with_constant_init_vm(
|
||||
@@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
expected_element = (
|
||||
n_components * value if distribution is None else sum(distribution) * value
|
||||
)
|
||||
points = self.get_random_normalized_points(n_grids)
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(
|
||||
self.get_random_normalized_points(n_grids, n_points), params
|
||||
),
|
||||
torch.ones(n_grids, n_points, n_features) * expected_element,
|
||||
grid.evaluate_local(points, params),
|
||||
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
|
||||
)
|
||||
|
||||
def _test_query_with_constant_init_full(
|
||||
@@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
expected_element = value
|
||||
points = self.get_random_normalized_points(n_grids)
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(
|
||||
self.get_random_normalized_points(n_grids, n_points), params
|
||||
),
|
||||
torch.ones(n_grids, n_points, n_features) * expected_element,
|
||||
grid.evaluate_local(points, params),
|
||||
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
|
||||
)
|
||||
|
||||
def test_query_with_constant_init(self):
|
||||
with self.subTest("Full"):
|
||||
self._test_query_with_constant_init_full(
|
||||
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
|
||||
n_grids=5, n_features=6, resolution=(3, 4, 5)
|
||||
)
|
||||
with self.subTest("Full with 1 in dimensions"):
|
||||
self._test_query_with_constant_init_full(
|
||||
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
|
||||
n_grids=5, n_features=1, resolution=(33, 41, 1)
|
||||
)
|
||||
with self.subTest("CP"):
|
||||
self._test_query_with_constant_init_cp(
|
||||
@@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=6,
|
||||
n_components=7,
|
||||
resolution=(3, 4, 5),
|
||||
n_points=2,
|
||||
)
|
||||
with self.subTest("CP with 1 in dimensions"):
|
||||
self._test_query_with_constant_init_cp(
|
||||
@@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=1,
|
||||
n_components=3,
|
||||
resolution=(3, 1, 1),
|
||||
n_points=4,
|
||||
)
|
||||
with self.subTest("VM with symetric distribution"):
|
||||
self._test_query_with_constant_init_vm(
|
||||
@@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=9,
|
||||
resolution=(2, 12, 2),
|
||||
n_components=12,
|
||||
n_points=3,
|
||||
)
|
||||
with self.subTest("VM with distribution"):
|
||||
self._test_query_with_constant_init_vm(
|
||||
@@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=1,
|
||||
resolution=(5, 9, 7),
|
||||
distribution=(33, 41, 1),
|
||||
n_points=7,
|
||||
)
|
||||
|
||||
def test_query_with_zero_init(self):
|
||||
@@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=6,
|
||||
n_components=7,
|
||||
resolution=(3, 2, 5),
|
||||
n_points=3,
|
||||
value=0,
|
||||
)
|
||||
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
|
||||
@@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=9,
|
||||
resolution=(2, 11, 3),
|
||||
n_components=3,
|
||||
n_points=3,
|
||||
value=0,
|
||||
)
|
||||
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
|
||||
self._test_query_with_constant_init_full(
|
||||
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
|
||||
n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
@@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
padding_mode="zeros",
|
||||
mode="bilinear",
|
||||
),
|
||||
rtol=0.0001,
|
||||
)
|
||||
|
||||
def test_floating_point_query(self):
|
||||
|
||||
Reference in New Issue
Block a user