mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Summary: Simple wrapper around voxel grids to make them a module Reviewed By: bottler Differential Revision: D38829762 fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
513 lines
21 KiB
Python
513 lines
21 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
This file contains classes that implement Voxel grids, both in their full resolution
|
|
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
|
|
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
|
|
TensoRF (https://arxiv.org/abs/2203.09517) paper.
|
|
|
|
In addition, the module VoxelGridModule implements a trainable instance of one of
|
|
these classes.
|
|
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import ClassVar, Dict, Optional, Tuple, Type
|
|
|
|
import torch
|
|
from pytorch3d.implicitron.tools.config import (
|
|
Configurable,
|
|
registry,
|
|
ReplaceableBase,
|
|
run_auto_creation,
|
|
)
|
|
from pytorch3d.structures.volumes import VolumeLocator
|
|
|
|
from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
|
|
|
|
|
@dataclass
class VoxelGridValuesBase:
    """
    Base class for the containers of tensors that a voxel grid is evaluated with.

    Each `VoxelGridBase` subclass names its container type in the `values_type`
    class variable. The container's fields are named after the keys returned by
    that grid's `get_shapes()` method.
    """

    pass
|
|
|
|
|
|
class VoxelGridBase(ReplaceableBase, torch.nn.Module):
    """
    Base class for all the voxel grid variants with added trilinear interpolation between
    voxels (for example if voxel (0.333, 1, 3) is queried that would return the result
    2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3]).

    Internally voxel grids are indexed by (features, x, y, z). If a queried point is not
    inside the voxel grid, the vector that is returned is determined by `padding`.

    Members:
        align_corners: parameter used in torch.nn.functional.grid_sample. For details go to
            https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.
            Default is True.
        padding: padding mode for outside grid values 'zeros' | 'border' | 'reflection'.
            Default is 'zeros'.
        mode: interpolation mode to calculate output values:
            'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
            Default: 'bilinear'. Note: mode='bicubic' supports only FullResolutionVoxelGrid.
            When mode='bilinear' and the input is 5-D, the interpolation mode used internally
            will actually be trilinear.
        n_features: number of dimensions of base feature vector. Determines how many features
            the grid returns.
        resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
    """

    align_corners: bool = True
    padding: str = "zeros"
    mode: str = "bilinear"
    n_features: int = 1
    resolution: Tuple[int, int, int] = (64, 64, 64)

    def __post_init__(self):
        # torch.nn.Module must be initialized before any submodules/parameters are set.
        super().__init__()

    def evaluate_world(
        self,
        points: torch.Tensor,
        grid_values: VoxelGridValuesBase,
        locator: VolumeLocator,
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the world coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3)
            grid_values: an object of type Class.values_type which has tensors as
                members which have shapes derived from the get_shapes() method
            locator: a VolumeLocator object, used to convert `points` from world
                coordinates to the local coordinates consumed by evaluate_local()
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        points_local = locator.world_to_local_coords(points)
        # pyre-ignore[29]
        return self.evaluate_local(points_local, grid_values)

    def evaluate_local(
        self, points: torch.Tensor, grid_values: VoxelGridValuesBase
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form
                (coordinates are in [-1, 1])
            grid_values: an object of type Class.values_type which has tensors
                as members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        raise NotImplementedError()

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Using the members of this class (e.g. `n_features`, `resolution`), this
        method returns the shapes of individual tensors needed to run the evaluate
        methods.

        Returns:
            a dictionary of needed shapes. To use the evaluate_local and evaluate_world
            methods, replace the shapes in the dictionary with tensors of those shapes
            and add the first 'batch' dimension. If the required shape is (a, b) and
            you want to have g grids then the tensor that replaces the shape should
            have the shape (g, a, b).
        """
        raise NotImplementedError()
|
|
|
|
|
|
@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
    """
    Tensors needed to evaluate a FullResolutionVoxelGrid.

    Members:
        voxel_grid: dense feature grid whose shape is the "voxel_grid" entry of
            FullResolutionVoxelGrid.get_shapes() with a leading n_grids (batch)
            dimension prepended.
    """

    voxel_grid: torch.Tensor
|
|
|
|
|
|
@registry.register
class FullResolutionVoxelGrid(VoxelGridBase):
    """
    Full resolution voxel grid equivalent to 4D tensor where shape is
    (features, width, height, depth), interpolated between voxels according to
    the `mode` member.
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = FullResolutionVoxelGridValues

    def evaluate_local(
        self, points: torch.Tensor, grid_values: FullResolutionVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form
                (coordinates are in [-1, 1])
            grid_values: an object of type values_type which has tensors as
                members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        return interpolate_volume(
            points,
            grid_values.voxel_grid,
            align_corners=self.align_corners,
            padding_mode=self.padding,
            mode=self.mode,
        )

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Returns:
            a dictionary with the single key "voxel_grid" mapped to
            (n_features, *resolution). A leading n_grids (batch) dimension has to be
            prepended to obtain the shape of the actual tensor.
        """
        return {"voxel_grid": (self.n_features, *self.resolution)}
|
|
|
|
|
|
@dataclass
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
    """
    Tensors needed to evaluate a CPFactorizedVoxelGrid. Their shapes are the entries
    of CPFactorizedVoxelGrid.get_shapes() with a leading n_grids (batch) dimension.

    Members:
        vector_components_x: component vectors along the x axis.
        vector_components_y: component vectors along the y axis.
        vector_components_z: component vectors along the z axis.
        basis_matrix: matrix mapping the components to n_features features. When
            None, the components are summed instead (single output feature).
    """

    vector_components_x: torch.Tensor
    vector_components_y: torch.Tensor
    vector_components_z: torch.Tensor
    basis_matrix: Optional[torch.Tensor] = None
|
|
|
|
|
|
@registry.register
class CPFactorizedVoxelGrid(VoxelGridBase):
    """
    Canonical Polyadic (CP/CANDECOMP/PARAFAC) Factorization factorizes the 3d grid into
    three vectors (x, y, z). For n_components=n, the 3d grid is a sum of n outer products
    (call it ⊗) of each vector type (x, y, z):

        3d_grid = x0 ⊗ y0 ⊗ z0 + x1 ⊗ y1 ⊗ z1 + ... + x(n-1) ⊗ y(n-1) ⊗ z(n-1)

    These tensors are passed in an object of CPFactorizedVoxelGridValues (here obj) as
    obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their
    shapes are `(n_grids, n_components, r)` where `r` is the relevant resolution.

    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This
    multiplication brings us to the desired "n_features" dimensionality. If
    matrix_reduction=False the elements of different components are summed together to
    create a (n_grids, n_points, 1) tensor. With some notation abuse, ignoring the
    interpolation operation, simplifying and denoting n_features as F, n_components as C
    and n_grids as G:

        3d_grid = (x ⊗ y ⊗ z) @ basis # GWHDC x GCF -> GWHDF

    The basis feature vectors are passed as obj.basis_matrix.

    Members:
        n_components: number of vector triplets, higher number gives better approximation.
        matrix_reduction: how to transform components. If matrix_reduction=True result
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied
            by the basis_matrix of shape (n_grids, n_components, n_features). If
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
            is summed along the rows to get (n_grids, n_points, 1).
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues

    n_components: int = 24
    matrix_reduction: bool = True

    def evaluate_local(
        self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form
                (coordinates are in [-1, 1])
            grid_values: a CPFactorizedVoxelGridValues object whose tensors have
                shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """

        def factor(i):
            # interpolate the component vectors of axis i at the i-th coordinate
            axis = ["x", "y", "z"][i]
            index = points[..., i, None]
            vector = getattr(grid_values, "vector_components_" + axis)
            return interpolate_line(
                index,
                vector,
                align_corners=self.align_corners,
                padding_mode=self.padding,
                mode=self.mode,
            )

        # collect points from all the vectors and multiply them out
        mult = factor(0) * factor(1) * factor(2)

        # reduce the result from
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
        if grid_values.basis_matrix is not None:
            # (n_grids, n_points, n_features) =
            # (n_grids, n_points, n_components) x (n_grids, n_components, n_features)
            return torch.bmm(mult, grid_values.basis_matrix)

        return mult.sum(dim=-1, keepdim=True)

    def get_shapes(self) -> Dict[str, Tuple[int, int]]:
        """
        Returns the shapes (without the leading n_grids dimension) of the tensors in
        CPFactorizedVoxelGridValues. "basis_matrix" is only included when
        matrix_reduction is True.

        Raises:
            ValueError: if matrix_reduction is False and n_features != 1, since the
                summed reduction can only produce a single feature.
        """
        if not self.matrix_reduction and self.n_features != 1:
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")

        shape_dict = {
            "vector_components_x": (self.n_components, self.resolution[0]),
            "vector_components_y": (self.n_components, self.resolution[1]),
            "vector_components_z": (self.n_components, self.resolution[2]),
        }
        if self.matrix_reduction:
            shape_dict["basis_matrix"] = (self.n_components, self.n_features)
        return shape_dict
|
|
|
|
|
|
@dataclass
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
    """
    Tensors needed to evaluate a VMFactorizedVoxelGrid. Their shapes are the entries
    of VMFactorizedVoxelGrid.get_shapes() with a leading n_grids (batch) dimension.

    Members:
        vector_components_x: vectors along x, paired with matrix_components_yz.
        vector_components_y: vectors along y, paired with matrix_components_xz.
        vector_components_z: vectors along z, paired with matrix_components_xy.
        matrix_components_xy: matrices over the xy plane.
        matrix_components_yz: matrices over the yz plane.
        matrix_components_xz: matrices over the xz plane.
        basis_matrix: matrix mapping all components to n_features features. When
            None, the components are summed instead (single output feature).
    """

    vector_components_x: torch.Tensor
    vector_components_y: torch.Tensor
    vector_components_z: torch.Tensor
    matrix_components_xy: torch.Tensor
    matrix_components_yz: torch.Tensor
    matrix_components_xz: torch.Tensor
    basis_matrix: Optional[torch.Tensor] = None
|
|
|
|
|
|
@registry.register
class VMFactorizedVoxelGrid(VoxelGridBase):
    """
    Implementation of Vector-Matrix Factorization of a tensor from
    https://arxiv.org/abs/2203.09517.

    Vector-Matrix Factorization factorizes the 3d grid into three matrices
    (xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid
    is a sum of the outer products (call it ⊗) of each matrix with its
    complementary vector:

        3d_grid = xy ⊗ z + xz ⊗ y + yz ⊗ x.

    These tensors are passed in a VMFactorizedVoxelGridValues object (here obj)
    as obj.matrix_components_xy, obj.matrix_components_xz, obj.vector_components_y, etc.

    Their shapes are `(n_grids, n_components, r0, r1)` for matrix_components and
    `(n_grids, n_components, r2)` for vector_components. Each of `r0, r1 and r2`
    corresponds to one resolution in (width, height and depth).

    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This
    multiplication brings us to the desired "n_features" dimensionality. If
    matrix_reduction=False the elements of different components are summed together to
    create a (n_grids, n_points, 1) tensor. With some notation abuse, ignoring the
    interpolation operation, simplifying and denoting n_features as F, n_components as C
    (which can differ for each dimension) and n_grids as G:

        3d_grid = concat((xy ⊗ z), (xz ⊗ y).permute(0, 2, 1),
                         (yz ⊗ x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF

    Members:
        n_components: total number of matrix vector pairs, this must be divisible by 3.
            Set this if you want to have equal representational power in all 3
            directions. You must specify either n_components or
            distribution_of_components, you cannot specify both.
        distribution_of_components: if you do not want equal representational power in
            all 3 directions specify a tuple of numbers of matrix_vector pairs for each
            coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must
            specify either n_components or distribution_of_components, you cannot
            specify both.
        matrix_reduction: how to transform components. If matrix_reduction=True result
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied
            by the basis_matrix of shape (n_grids, n_components, n_features). If
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
            is summed along the rows to get (n_grids, n_points, 1).
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = VMFactorizedVoxelGridValues

    n_components: Optional[int] = None
    distribution_of_components: Optional[Tuple[int, int, int]] = None
    matrix_reduction: bool = True

    def evaluate_local(
        self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form
                (coordinates are in [-1, 1])
            grid_values: a VMFactorizedVoxelGridValues object whose tensors have
                shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """

        def plane_times_line(plane_points, plane, line_points, line):
            # Interpolate a matrix component on two of the coordinates and multiply
            # it by the complementary vector component on the remaining coordinate.
            interpolation_kwargs = {
                "align_corners": self.align_corners,
                "padding_mode": self.padding,
                "mode": self.mode,
            }
            return interpolate_plane(
                plane_points, plane, **interpolation_kwargs
            ) * interpolate_line(line_points, line, **interpolation_kwargs)

        # collect points from matrices and vectors and multiply them
        a = plane_times_line(
            points[..., :2],
            grid_values.matrix_components_xy,
            points[..., 2:],
            grid_values.vector_components_z,
        )
        b = plane_times_line(
            points[..., [0, 2]],
            grid_values.matrix_components_xz,
            points[..., 1:2],
            grid_values.vector_components_y,
        )
        c = plane_times_line(
            points[..., 1:],
            grid_values.matrix_components_yz,
            points[..., :1],
            grid_values.vector_components_x,
        )
        feats = torch.cat((a, b, c), dim=-1)

        # reduce the result from
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
        if grid_values.basis_matrix is not None:
            # (n_grids, n_points, n_features) =
            # (n_grids, n_points, total_n_components) x
            # (n_grids, total_n_components, n_features)
            return torch.bmm(feats, grid_values.basis_matrix)
        return feats.sum(dim=-1, keepdim=True)

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Returns the shapes (without the leading n_grids dimension) of the tensors in
        VMFactorizedVoxelGridValues. "basis_matrix" is only included when
        matrix_reduction is True.

        Raises:
            ValueError: if matrix_reduction is False and n_features != 1, if neither
                or both of n_components and distribution_of_components are set, or if
                n_components is not divisible by 3.
        """
        if not self.matrix_reduction and self.n_features != 1:
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
        if self.distribution_of_components is None and self.n_components is None:
            raise ValueError(
                "You need to provide n_components or distribution_of_components"
            )
        if (
            self.distribution_of_components is not None
            and self.n_components is not None
        ):
            raise ValueError(
                "You cannot define n_components and distribution_of_components"
            )
        # pyre-ignore[58]
        if self.distribution_of_components is None and self.n_components % 3 != 0:
            raise ValueError("n_components must be divisible by 3")
        if self.distribution_of_components is None:
            # split the components equally between the three matrix-vector pairs
            # pyre-ignore[58]
            calculated_distribution_of_components = [
                self.n_components // 3 for _ in range(3)
            ]
        else:
            calculated_distribution_of_components = self.distribution_of_components

        # index 0 -> xy planes / z vectors, 1 -> yz planes / x vectors,
        # 2 -> xz planes / y vectors
        shape_dict = {
            "vector_components_x": (
                calculated_distribution_of_components[1],
                self.resolution[0],
            ),
            "vector_components_y": (
                calculated_distribution_of_components[2],
                self.resolution[1],
            ),
            "vector_components_z": (
                calculated_distribution_of_components[0],
                self.resolution[2],
            ),
            "matrix_components_xy": (
                calculated_distribution_of_components[0],
                self.resolution[0],
                self.resolution[1],
            ),
            "matrix_components_yz": (
                calculated_distribution_of_components[1],
                self.resolution[1],
                self.resolution[2],
            ),
            "matrix_components_xz": (
                calculated_distribution_of_components[2],
                self.resolution[0],
                self.resolution[2],
            ),
        }
        if self.matrix_reduction:
            shape_dict["basis_matrix"] = (
                sum(calculated_distribution_of_components),
                self.n_features,
            )

        return shape_dict
|
|
|
|
|
|
class VoxelGridModule(Configurable, torch.nn.Module):
    """
    A wrapper torch.nn.Module for the VoxelGrid classes, which
    contains parameters that are needed to train the VoxelGrid classes.

    Members:
        voxel_grid_class_type: The name of the class to use for voxel_grid,
            which must be available in the registry. Default FullResolutionVoxelGrid.
        voxel_grid: An instance of `VoxelGridBase`. This is the object which
            this class wraps.
        extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
            in world units. Default (1.0, 1.0, 1.0).
        translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
        init_std: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.1
        init_mean: Parameters are initialized using the gaussian distribution
            with mean=init_mean and std=init_std. Default 0.
    """

    voxel_grid_class_type: str = "FullResolutionVoxelGrid"
    voxel_grid: VoxelGridBase

    # The default must match the Tuple annotation (it used to be the scalar 1.0).
    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
    translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)

    init_std: float = 0.1
    init_mean: float = 0.0

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)
        n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
        shapes = self.voxel_grid.get_shapes()
        # Wrap explicitly in torch.nn.Parameter: some PyTorch versions of
        # ParameterDict only accept Parameter values, not plain tensors.
        params = {
            name: torch.nn.Parameter(
                torch.normal(
                    mean=torch.zeros((n_grids, *shape)) + self.init_mean,
                    std=self.init_std,
                )
            )
            for name, shape in shapes.items()
        }
        self.params = torch.nn.ParameterDict(params)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """
        Evaluates points in the world coordinate frame on the voxel_grid.

        Args:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_points, 3)
        Returns:
            torch.Tensor of shape (n_points, n_features)
        """
        # ParameterDict.values() is a dict view (not an iterator) in some PyTorch
        # versions, so it has to be wrapped in iter() before calling next().
        device = next(iter(self.params.values())).device
        locator = VolumeLocator(
            batch_size=1,
            # The resolution of the voxel grid does not need to be known
            # to the locator object. It is easiest to fix the resolution of the locator.
            # In particular we fix it to (2,2,2) so that there is exactly one voxel of the
            # desired size. The locator object uses (z, y, x) convention for the grid_size,
            # and this module uses (x, y, z) convention so the order has to be reversed
            # (irrelevant in this case since they are all equal).
            # It is (2, 2, 2) because the VolumeLocator object behaves like
            # align_corners=True, which means that the points are in the corners of
            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
            grid_sizes=(2, 2, 2),
            # The locator object uses (x, y, z) convention for the
            # voxel size and translation.
            voxel_size=self.extents,
            volume_translation=self.translation,
            device=device,
        )
        grid_values = self.voxel_grid.values_type(**self.params)
        # voxel grids operate with extra n_grids dimension, which we fix to one
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|