mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
VoxelGridModule
Summary: Simple wrapper around voxel grids to make them a module Reviewed By: bottler Differential Revision: D38829762 fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
This commit is contained in:
parent
6653f4400b
commit
24f5f4a3e7
@ -8,14 +8,23 @@
|
||||
This file contains classes that implement Voxel grids, both in their full resolution
|
||||
and in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
|
||||
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
|
||||
https://arxiv.org/abs/2203.09517.
|
||||
TensoRF (https://arxiv.org/abs/2203.09517) paper.
|
||||
|
||||
In addition, the module VoxelGridModule implements a trainable instance of one of
|
||||
these classes.
|
||||
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.structures.volumes import VolumeLocator
|
||||
|
||||
from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
||||
@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
)
|
||||
|
||||
return shape_dict
|
||||
|
||||
|
||||
class VoxelGridModule(Configurable, torch.nn.Module):
    """
    A wrapper torch.nn.Module for the VoxelGrid classes, which
    contains parameters that are needed to train the VoxelGrid classes.

    Members:
        voxel_grid_class_type: The name of the class to use for voxel_grid,
            which must be available in the registry. Default FullResolutionVoxelGrid.
        voxel_grid: An instance of `VoxelGridBase`. This is the object which
            this class wraps.
        extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
            in world units. Default (1.0, 1.0, 1.0).
        translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
            Default (0.0, 0.0, 0.0).
        init_std: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.1
        init_mean: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.
    """

    voxel_grid_class_type: str = "FullResolutionVoxelGrid"
    voxel_grid: VoxelGridBase

    # The default is spelled out as a 3-tuple so that it agrees with the
    # declared field type (the previous default was the bare float 1.0).
    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)

    init_std: float = 0.1
    init_mean: float = 0.0

    def __post_init__(self):
        """
        Creates the wrapped voxel grid and one trainable tensor for every
        entry reported by `voxel_grid.get_shapes()`.
        """
        super().__init__()
        run_auto_creation(self)
        n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
        shapes = self.voxel_grid.get_shapes()
        params = {
            # Wrap explicitly in Parameter so that registration does not depend
            # on ParameterDict converting raw tensors on our behalf.
            name: torch.nn.Parameter(
                torch.normal(
                    mean=torch.zeros((n_grids, *shape)) + self.init_mean,
                    std=self.init_std,
                )
            )
            for name, shape in shapes.items()
        }
        self.params = torch.nn.ParameterDict(params)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """
        Evaluates points in the world coordinate frame on the voxel_grid.

        Args:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_points, 3)
        Returns:
            torch.Tensor of shape (n_points, n_features)
        """
        locator = VolumeLocator(
            batch_size=1,
            # The resolution of the voxel grid does not need to be known
            # to the locator object. It is easiest to fix the resolution of the locator.
            # In particular we fix it to (2,2,2) so that there is exactly one voxel of the
            # desired size. The locator object uses (z, y, x) convention for the grid_size,
            # and this module uses (x, y, z) convention so the order has to be reversed
            # (irrelevant in this case since they are all equal).
            # It is (2, 2, 2) because the VolumeLocator object behaves like
            # align_corners=True, which means that the points are in the corners of
            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
            grid_sizes=(2, 2, 2),
            # The locator object uses (x, y, z) convention for the
            # voxel size and translation.
            voxel_size=self.extents,
            volume_translation=self.translation,
            # `values()` is not guaranteed to be an iterator (it can be a plain
            # view object), so obtain one explicitly before calling next().
            device=next(iter(self.params.values())).device,
        )
        grid_values = self.voxel_grid.values_type(**self.params)
        # voxel grids operate with extra n_grids dimension, which we fix to one
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||
|
@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
||||
CPFactorizedVoxelGrid,
|
||||
FullResolutionVoxelGrid,
|
||||
VMFactorizedVoxelGrid,
|
||||
VoxelGridModule,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
expand_args_fields(FullResolutionVoxelGrid)
|
||||
expand_args_fields(CPFactorizedVoxelGrid)
|
||||
expand_args_fields(VMFactorizedVoxelGrid)
|
||||
expand_args_fields(VoxelGridModule)
|
||||
|
||||
def _interpolate_1D(
|
||||
self, points: torch.Tensor, vectors: torch.Tensor
|
||||
@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=10,
|
||||
n_components=3,
|
||||
)
|
||||
|
||||
def test_voxel_grid_module_location(self, n_times=10):
    """
    Checks that VoxelGridModule wires up its locator correctly, i.e. that
    it does not permute or rescale the coordinate axes.

    The voxel grids themselves are taken as verified for (x, y, z) points
    given in local coordinates. For `n_times` random grid extents, a random
    local point in [-1, 1) is evaluated directly on the wrapped voxel grid,
    and the same point converted to world coordinates
    (local * extents / 2) is evaluated through the VoxelGridModule.
    Both evaluations have to agree.
    """
    for _ in range(n_times):
        extents = tuple(torch.randint(1, 50, size=(3,)).tolist())
        module = VoxelGridModule(extents=extents)

        point_local = torch.rand(1, 3) * 2 - 1
        point_world = point_local * torch.tensor(extents) / 2
        values = module.voxel_grid.values_type(**module.params)

        # Value obtained through the module, queried in world coordinates.
        from_world = module(point_world)[0, 0]
        # Reference value from the wrapped grid, queried in local coordinates
        # (the grid expects an extra leading n_grids dimension).
        from_local = module.voxel_grid.evaluate_local(point_local[None], values)[
            0, 0, 0
        ]

        assert torch.allclose(from_world, from_local, rtol=0.0001)
|
||||
|
Loading…
x
Reference in New Issue
Block a user