VoxelGridModule

Summary: Simple wrapper around voxel grids to make them a module

Reviewed By: bottler

Differential Revision: D38829762

fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
This commit is contained in:
Darijan Gudelj
2022-08-25 09:40:43 -07:00
committed by Facebook GitHub Bot
parent 6653f4400b
commit 24f5f4a3e7
2 changed files with 112 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
VMFactorizedVoxelGrid,
VoxelGridModule,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
@@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
expand_args_fields(FullResolutionVoxelGrid)
expand_args_fields(CPFactorizedVoxelGrid)
expand_args_fields(VMFactorizedVoxelGrid)
expand_args_fields(VoxelGridModule)
def _interpolate_1D(
self, points: torch.Tensor, vectors: torch.Tensor
@@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=10,
n_components=3,
)
def test_voxel_grid_module_location(self, n_times=10):
"""
This checks the module uses locator correctly etc..
If we know that voxel grids work for (x, y, z) in local coordinates
to test if the VoxelGridModule does not have permuted dimensions we
create local coordinates, pass them through verified voxelgrids and
compare the result with the result that we get when we convert
coordinates to world and pass them through the VoxelGridModule
"""
for _ in range(n_times):
extents = tuple(torch.randint(1, 50, size=(3,)).tolist())
grid = VoxelGridModule(extents=extents)
local_point = torch.rand(1, 3) * 2 - 1
world_point = local_point * torch.tensor(extents) / 2
grid_values = grid.voxel_grid.values_type(**grid.params)
assert torch.allclose(
grid(world_point)[0, 0],
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
rtol=0.0001,
)