VoxelGridModule

Summary: Simple wrapper around voxel grids to make them a module

Reviewed By: bottler

Differential Revision: D38829762

fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
This commit is contained in:
Darijan Gudelj
2022-08-25 09:40:43 -07:00
committed by Facebook GitHub Bot
parent 6653f4400b
commit 24f5f4a3e7
2 changed files with 112 additions and 2 deletions

View File

@@ -8,14 +8,23 @@
This file contains classes that implement voxel grids, both in their full-resolution form
and in factorized form. Two factorized forms are implemented: tensor rank decomposition,
also known as CANDECOMP/PARAFAC (here CP), and Vector Matrix (here VM) factorization from the
https://arxiv.org/abs/2203.09517.
TensoRF (https://arxiv.org/abs/2203.09517) paper.
In addition, the module VoxelGridModule implements a trainable instance of one of
these classes.
"""
from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.structures.volumes import VolumeLocator
from .utils import interpolate_line, interpolate_plane, interpolate_volume
@@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
)
return shape_dict
class VoxelGridModule(Configurable, torch.nn.Module):
    """
    A wrapper torch.nn.Module for the VoxelGrid classes, which
    contains parameters that are needed to train the VoxelGrid classes.

    Members:
        voxel_grid_class_type: The name of the class to use for voxel_grid,
            which must be available in the registry. Default FullResolutionVoxelGrid.
        voxel_grid: An instance of `VoxelGridBase`. This is the object which
            this class wraps.
        extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
            in world units. Default (1.0, 1.0, 1.0).
        translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
            Default (0.0, 0.0, 0.0).
        init_std: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.1
        init_mean: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.
    """

    voxel_grid_class_type: str = "FullResolutionVoxelGrid"
    voxel_grid: VoxelGridBase

    # The default is spelled as a 3-tuple so that it agrees with the annotation
    # and the docstring (previously the scalar 1.0).
    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)

    init_std: float = 0.1
    init_mean: float = 0.0

    def __post_init__(self):
        """
        Creates the wrapped voxel grid and allocates one randomly initialised
        trainable parameter for every entry returned by `voxel_grid.get_shapes()`.
        """
        super().__init__()
        run_auto_creation(self)
        n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
        shapes = self.voxel_grid.get_shapes()
        params = {
            # Wrap explicitly in Parameter instead of relying on ParameterDict
            # converting raw tensors implicitly.
            name: torch.nn.Parameter(
                torch.normal(
                    mean=torch.zeros((n_grids, *shape)) + self.init_mean,
                    std=self.init_std,
                )
            )
            for name, shape in shapes.items()
        }
        self.params = torch.nn.ParameterDict(params)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """
        Evaluates points in the world coordinate frame on the voxel_grid.

        Args:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_points, 3)
        Returns:
            torch.Tensor of shape (n_points, n_features)
        """
        locator = VolumeLocator(
            batch_size=1,
            # The resolution of the voxel grid does not need to be known
            # to the locator object. It is easiest to fix the resolution of the locator.
            # In particular we fix it to (2,2,2) so that there is exactly one voxel of the
            # desired size. The locator object uses (z, y, x) convention for the grid_size,
            # and this module uses (x, y, z) convention so the order has to be reversed
            # (irrelevant in this case since they are all equal).
            # It is (2, 2, 2) because the VolumeLocator object behaves like
            # align_corners=True, which means that the points are in the corners of
            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
            grid_sizes=(2, 2, 2),
            # The locator object uses (x, y, z) convention for the
            # voxel size and translation.
            voxel_size=self.extents,
            volume_translation=self.translation,
            # `values()` may be a plain view rather than an iterator, so wrap it
            # in iter() before taking the first parameter to read its device.
            device=next(iter(self.params.values())).device,
        )
        grid_values = self.voxel_grid.values_type(**self.params)
        # voxel grids operate with extra n_grids dimension, which we fix to one
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]