mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
VoxelGridModule
Summary: Simple wrapper around voxel grids to make them a module Reviewed By: bottler Differential Revision: D38829762 fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
This commit is contained in:
parent
6653f4400b
commit
24f5f4a3e7
@ -8,14 +8,23 @@
|
|||||||
This file contains classes that implement Voxel grids, both in their full resolution
|
This file contains classes that implement Voxel grids, both in their full resolution
|
||||||
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
|
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
|
||||||
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
|
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
|
||||||
https://arxiv.org/abs/2203.09517.
|
TensoRF (https://arxiv.org/abs/2203.09517) paper.
|
||||||
|
|
||||||
|
In addition, the module VoxelGridModule implements a trainable instance of one of
|
||||||
|
these classes.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
Configurable,
|
||||||
|
registry,
|
||||||
|
ReplaceableBase,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
from pytorch3d.structures.volumes import VolumeLocator
|
from pytorch3d.structures.volumes import VolumeLocator
|
||||||
|
|
||||||
from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
||||||
@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return shape_dict
|
return shape_dict
|
||||||
|
|
||||||
|
|
||||||
|
class VoxelGridModule(Configurable, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A wrapper torch.nn.Module for the VoxelGrid classes, which
|
||||||
|
contains parameters that are needed to train the VoxelGrid classes.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
voxel_grid_class_type: The name of the class to use for voxel_grid,
|
||||||
|
which must be available in the registry. Default FullResolutionVoxelGrid.
|
||||||
|
voxel_grid: An instance of `VoxelGridBase`. This is the object which
|
||||||
|
this class wraps.
|
||||||
|
extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
|
||||||
|
in world units.
|
||||||
|
translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
|
||||||
|
init_std: Parameters are initialized using the gaussian distribution
|
||||||
|
with mean=init_mean and std=init_std. Default 0.1
|
||||||
|
init_mean: Parameters are initialized using the gaussian distribution
|
||||||
|
with mean=init_mean and std=init_std. Default 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
||||||
|
voxel_grid: VoxelGridBase
|
||||||
|
|
||||||
|
extents: Tuple[float, float, float] = 1.0
|
||||||
|
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
|
|
||||||
|
init_std: float = 0.1
|
||||||
|
init_mean: float = 0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
|
||||||
|
shapes = self.voxel_grid.get_shapes()
|
||||||
|
params = {
|
||||||
|
name: torch.normal(
|
||||||
|
mean=torch.zeros((n_grids, *shape)) + self.init_mean,
|
||||||
|
std=self.init_std,
|
||||||
|
)
|
||||||
|
for name, shape in shapes.items()
|
||||||
|
}
|
||||||
|
self.params = torch.nn.ParameterDict(params)
|
||||||
|
|
||||||
|
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Evaluates points in the world coordinate frame on the voxel_grid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points (torch.Tensor): tensor of points that you want to query
|
||||||
|
of a form (n_points, 3)
|
||||||
|
Returns:
|
||||||
|
torch.Tensor of shape (n_points, n_features)
|
||||||
|
"""
|
||||||
|
locator = VolumeLocator(
|
||||||
|
batch_size=1,
|
||||||
|
# The resolution of the voxel grid does not need to be known
|
||||||
|
# to the locator object. It is easiest to fix the resolution of the locator.
|
||||||
|
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
|
||||||
|
# desired size. The locator object uses (z, y, x) convention for the grid_size,
|
||||||
|
# and this module uses (x, y, z) convention so the order has to be reversed
|
||||||
|
# (irrelevant in this case since they are all equal).
|
||||||
|
# It is (2, 2, 2) because the VolumeLocator object behaves like
|
||||||
|
# align_corners=True, which means that the points are in the corners of
|
||||||
|
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
|
||||||
|
grid_sizes=(2, 2, 2),
|
||||||
|
# The locator object uses (x, y, z) convention for the
|
||||||
|
# voxel size and translation.
|
||||||
|
voxel_size=self.extents,
|
||||||
|
volume_translation=self.translation,
|
||||||
|
device=next(self.params.values()).device,
|
||||||
|
)
|
||||||
|
grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
|
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||||
|
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||||
|
@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
|||||||
CPFactorizedVoxelGrid,
|
CPFactorizedVoxelGrid,
|
||||||
FullResolutionVoxelGrid,
|
FullResolutionVoxelGrid,
|
||||||
VMFactorizedVoxelGrid,
|
VMFactorizedVoxelGrid,
|
||||||
|
VoxelGridModule,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||||
@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
expand_args_fields(FullResolutionVoxelGrid)
|
expand_args_fields(FullResolutionVoxelGrid)
|
||||||
expand_args_fields(CPFactorizedVoxelGrid)
|
expand_args_fields(CPFactorizedVoxelGrid)
|
||||||
expand_args_fields(VMFactorizedVoxelGrid)
|
expand_args_fields(VMFactorizedVoxelGrid)
|
||||||
|
expand_args_fields(VoxelGridModule)
|
||||||
|
|
||||||
def _interpolate_1D(
|
def _interpolate_1D(
|
||||||
self, points: torch.Tensor, vectors: torch.Tensor
|
self, points: torch.Tensor, vectors: torch.Tensor
|
||||||
@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=10,
|
n_features=10,
|
||||||
n_components=3,
|
n_components=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_voxel_grid_module_location(self, n_times=10):
|
||||||
|
"""
|
||||||
|
This checks the module uses locator correctly etc..
|
||||||
|
|
||||||
|
If we know that voxel grids work for (x, y, z) in local coordinates
|
||||||
|
to test if the VoxelGridModule does not have permuted dimensions we
|
||||||
|
create local coordinates, pass them through verified voxelgrids and
|
||||||
|
compare the result with the result that we get when we convert
|
||||||
|
coordinates to world and pass them through the VoxelGridModule
|
||||||
|
"""
|
||||||
|
for _ in range(n_times):
|
||||||
|
extents = tuple(torch.randint(1, 50, size=(3,)).tolist())
|
||||||
|
|
||||||
|
grid = VoxelGridModule(extents=extents)
|
||||||
|
local_point = torch.rand(1, 3) * 2 - 1
|
||||||
|
world_point = local_point * torch.tensor(extents) / 2
|
||||||
|
grid_values = grid.voxel_grid.values_type(**grid.params)
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
grid(world_point)[0, 0],
|
||||||
|
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
|
||||||
|
rtol=0.0001,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user