VoxelGridModule

Summary: Simple wrapper around voxel grids to make them a module

Reviewed By: bottler

Differential Revision: D38829762

fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
This commit is contained in:
Darijan Gudelj 2022-08-25 09:40:43 -07:00 committed by Facebook GitHub Bot
parent 6653f4400b
commit 24f5f4a3e7
2 changed files with 112 additions and 2 deletions

View File

@ -8,14 +8,23 @@
This file contains classes that implement Voxel grids, both in their full resolution
and in their factorized form. There are two factorized forms implemented, Tensor rank decomposition
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
https://arxiv.org/abs/2203.09517.
TensoRF (https://arxiv.org/abs/2203.09517) paper.
In addition, the module VoxelGridModule implements a trainable instance of one of
these classes.
"""
from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.structures.volumes import VolumeLocator
from .utils import interpolate_line, interpolate_plane, interpolate_volume
@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
)
return shape_dict
class VoxelGridModule(Configurable, torch.nn.Module):
    """
    A wrapper torch.nn.Module for the VoxelGrid classes, which
    contains parameters that are needed to train the VoxelGrid classes.

    Members:
        voxel_grid_class_type: The name of the class to use for voxel_grid,
            which must be available in the registry. Default FullResolutionVoxelGrid.
        voxel_grid: An instance of `VoxelGridBase`. This is the object which
            this class wraps.
        extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
            in world units. Default (1.0, 1.0, 1.0).
        translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
        init_std: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.1
        init_mean: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.
    """

    voxel_grid_class_type: str = "FullResolutionVoxelGrid"
    voxel_grid: VoxelGridBase

    # The default is spelled out as a 3-tuple so that it matches the annotation
    # (and the `translation` field below) instead of being a bare scalar.
    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)

    init_std: float = 0.1
    init_mean: float = 0

    def __post_init__(self):
        """
        Creates the wrapped `voxel_grid` and allocates one randomly initialized
        parameter per entry of `voxel_grid.get_shapes()`.
        """
        super().__init__()
        run_auto_creation(self)
        n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
        shapes = self.voxel_grid.get_shapes()
        params = {
            # Wrap explicitly so that every tensor is registered as a trainable
            # parameter without relying on ParameterDict converting plain tensors.
            name: torch.nn.Parameter(
                torch.normal(
                    mean=torch.zeros((n_grids, *shape)) + self.init_mean,
                    std=self.init_std,
                )
            )
            for name, shape in shapes.items()
        }
        self.params = torch.nn.ParameterDict(params)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """
        Evaluates points in the world coordinate frame on the voxel_grid.

        Args:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_points, 3)
        Returns:
            torch.Tensor of shape (n_points, n_features)
        """
        locator = VolumeLocator(
            batch_size=1,
            # The resolution of the voxel grid does not need to be known
            # to the locator object. It is easiest to fix the resolution of the locator.
            # In particular we fix it to (2,2,2) so that there is exactly one voxel of the
            # desired size. The locator object uses (z, y, x) convention for the grid_size,
            # and this module uses (x, y, z) convention so the order has to be reversed
            # (irrelevant in this case since they are all equal).
            # It is (2, 2, 2) because the VolumeLocator object behaves like
            # align_corners=True, which means that the points are in the corners of
            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
            grid_sizes=(2, 2, 2),
            # The locator object uses (x, y, z) convention for the
            # voxel size and translation.
            voxel_size=self.extents,
            volume_translation=self.translation,
            # `values()` may be a view rather than an iterator, so go through
            # `iter` before taking the first parameter to read its device.
            device=next(iter(self.params.values())).device,
        )
        grid_values = self.voxel_grid.values_type(**self.params)
        # voxel grids operate with extra n_grids dimension, which we fix to one
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]

View File

@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
VMFactorizedVoxelGrid,
VoxelGridModule,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
expand_args_fields(FullResolutionVoxelGrid)
expand_args_fields(CPFactorizedVoxelGrid)
expand_args_fields(VMFactorizedVoxelGrid)
expand_args_fields(VoxelGridModule)
def _interpolate_1D(
self, points: torch.Tensor, vectors: torch.Tensor
@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=10,
n_components=3,
)
def test_voxel_grid_module_location(self, n_times=10):
    """
    Checks that VoxelGridModule uses the locator correctly, i.e. that it
    does not permute the (x, y, z) dimensions.

    Assuming that voxel grids work for (x, y, z) in local coordinates, we
    create local coordinates, pass them through the verified voxel grid and
    compare the result with the one obtained by converting the coordinates
    to world coordinates and passing them through the VoxelGridModule.

    Args:
        n_times: number of randomly sized grids to check.
    """
    for _ in range(n_times):
        extents = tuple(torch.randint(1, 50, size=(3,)).tolist())
        grid = VoxelGridModule(extents=extents)
        # Local coordinates are sampled from [-1, 1); scaling by half of the
        # extents maps them to world coordinates of a grid centred at the origin.
        local_point = torch.rand(1, 3) * 2 - 1
        world_point = local_point * torch.tensor(extents) / 2
        grid_values = grid.voxel_grid.values_type(**grid.params)
        from_world = grid(world_point)[0, 0]
        from_local = grid.voxel_grid.evaluate_local(local_point[None], grid_values)[
            0, 0, 0
        ]
        # Use the TestCase assertion instead of a bare `assert` so the check is
        # not stripped under `python -O` and reports the offending extents.
        self.assertTrue(
            torch.allclose(from_world, from_local, rtol=0.0001),
            f"world and local evaluations differ for extents={extents}",
        )