mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
voxel grids with interpolation
Summary: Added voxel grid classes from TensoRF, both in their factorized (CP and VM) and full form. Reviewed By: bottler Differential Revision: D38465419 fbshipit-source-id: 8b306338af58dc50ef47a682616022a0512c0047
This commit is contained in:
parent
af799facdd
commit
edee25a1e5
@ -7,6 +7,8 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.compat import prod
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
@ -88,3 +90,98 @@ def create_embeddings_for_implicit_function(
|
||||
embeds = broadcast_global_code(embeds, global_code)
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
def interpolate_line(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Linearly interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
    dimension of argument source represents feature and ones after that the spatial
    dimension.

    Arguments:
        points: shape (n_grids, n_points, 1),
        source: tensor of shape (n_grids, features, width),
        **kwargs: forwarded to torch.nn.functional.grid_sample (for example
            align_corners, padding_mode, mode). mode="trilinear" is accepted and
            treated as "bilinear", exactly as interpolate_volume does, so that one
            shared `mode` setting can be passed to all the interpolate_* helpers.
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    if kwargs.get("mode") == "trilinear":
        # copy instead of mutating the caller's dictionary
        kwargs = {**kwargs, "mode": "bilinear"}

    # To enable sampling of the source using the torch.functional.grid_sample
    # points need to have 2 coordinates. The added coordinate is 0, which addresses
    # the single row of the height-1 image built from `source` below.
    expansion = points.new_zeros(points.shape)
    grid = torch.cat((points, expansion), dim=-1)

    # source: (n_grids, features, 1, width); grid: (n_grids, n_points, 1, 2)
    source = source[:, :, None, :]
    grid = grid[:, :, None, :]

    out = F.grid_sample(
        grid=grid,
        input=source,
        **kwargs,
    )
    # out: (n_grids, features, n_points, 1) -> (n_grids, n_points, features)
    return out[:, :, :, 0].permute(0, 2, 1)
|
||||
|
||||
|
||||
def interpolate_plane(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Bilinearly interpolates values of source grids. The first dimension of points represents
    number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
    The first dimension of argument source represents feature and ones after that the
    spatial dimension.

    Arguments:
        points: shape (n_grids, n_points, 2),
        source: tensor of shape (n_grids, features, width, height),
        **kwargs: forwarded to torch.nn.functional.grid_sample (for example
            align_corners, padding_mode, mode). mode="trilinear" is accepted and
            treated as "bilinear", exactly as interpolate_volume does, so that one
            shared `mode` setting can be passed to all the interpolate_* helpers.
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    if kwargs.get("mode") == "trilinear":
        # copy instead of mutating the caller's dictionary
        kwargs = {**kwargs, "mode": "bilinear"}

    # permuting because torch.nn.functional.grid_sample works with
    # (features, height, width) and not
    # (features, width, height)
    source = source.permute(0, 1, 3, 2)
    # grid: (n_grids, n_points, 1, 2)
    grid = points[:, :, None, :]

    out = F.grid_sample(
        grid=grid,
        input=source,
        **kwargs,
    )
    # out: (n_grids, features, n_points, 1) -> (n_grids, n_points, features)
    return out[:, :, :, 0].permute(0, 2, 1)
|
||||
|
||||
|
||||
def interpolate_volume(
    points: torch.Tensor, source: torch.Tensor, **kwargs
) -> torch.Tensor:
    """
    Samples batched 3D feature volumes at the given points.

    Each point is an (x, y, z) triple; `source` stores features first and then the
    spatial axes in (width, height, depth) order.

    Arguments:
        points: shape (n_grids, n_points, 3),
        source: tensor of shape (n_grids, features, width, height, depth),
        **kwargs: forwarded to torch.nn.functional.grid_sample; mode="trilinear"
            is translated to "bilinear" before the call.
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    sampler_kwargs = kwargs
    if kwargs.get("mode") == "trilinear":
        sampler_kwargs = dict(kwargs)
        sampler_kwargs["mode"] = "bilinear"

    # grid_sample expects (features, depth, height, width) whereas the source is
    # laid out as (features, width, height, depth), hence the axis reversal
    volume = source.permute(0, 1, 4, 3, 2)
    sample_grid = points[:, :, None, None, :]

    sampled = F.grid_sample(input=volume, grid=sample_grid, **sampler_kwargs)
    # (n_grids, features, n_points, 1, 1) -> (n_grids, n_points, features)
    return sampled[:, :, :, 0, 0].permute(0, 2, 1)
|
||||
|
428
pytorch3d/implicitron/models/implicit_function/voxel_grid.py
Normal file
428
pytorch3d/implicitron/models/implicit_function/voxel_grid.py
Normal file
@ -0,0 +1,428 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
This file contains classes that implement voxel grids, both in their full-resolution
form and in factorized forms. Two factorized forms are implemented: tensor rank
decomposition, or CANDECOMP/PARAFAC (here CP), and Vector-Matrix (here VM) factorization
from https://arxiv.org/abs/2203.09517.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.structures.volumes import VolumeLocator
|
||||
|
||||
from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
||||
|
||||
|
||||
@dataclass
class VoxelGridValuesBase:
    """
    Empty common base for the dataclasses that hold the tensors a voxel grid is
    evaluated on. It declares no fields itself.
    """

    pass
|
||||
|
||||
|
||||
class VoxelGridBase(ReplaceableBase, torch.nn.Module):
    """
    Base class for all the voxel grid variants with added trilinear interpolation between
    voxels (for example if voxel (0.333, 1, 3) is queried that would return the result
    2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3])

    Internally voxel grids are indexed by (features, x, y, z). If the queried point is not
    inside the voxel grid the vector that will be returned is determined by padding.

    Members:
        align_corners: parameter used in torch.functional.grid_sample. For details go to
            https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html by
            default is True
        padding: padding mode for outside grid values 'zeros' | 'border' | 'reflection'.
            Default is 'zeros'
        mode: interpolation mode to calculate output values :
            'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
            Default: 'bilinear'
            When mode='bilinear' and the input is 5-D, the interpolation mode used internally
            will actually be trilinear.
            NOTE(review): this docstring originally stated that mode='bicubic' "supports
            only FullResolutionVoxelGrid". The torch.nn.functional.grid_sample
            documentation says bicubic supports only 4-D input, while
            FullResolutionVoxelGrid samples a 5-D input - confirm which grid variants
            actually accept 'bicubic'.
        n_features: number of dimensions of base feature vector. Determines how many features
            the grid returns.
        resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
    """

    align_corners: bool = True
    padding: str = "zeros"
    mode: str = "bilinear"
    n_features: int = 1
    resolution: Tuple[int, int, int] = (64, 64, 64)

    def __post_init__(self):
        # initialises the torch.nn.Module part of the object
        super().__init__()

    def evaluate_world(
        self,
        points: torch.Tensor,
        grid_values: VoxelGridValuesBase,
        locator: VolumeLocator,
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the world coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3)
            grid_values: an object of type Class.values_type which has tensors as
                members which have shapes derived from the get_shapes() method
            locator: a VolumeLocator object, used to convert the world coordinates
                to the grid's local coordinates
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        points_local = locator.world_to_local_coords(points)
        # pyre-ignore[29]
        return self.evaluate_local(points_local, grid_values)

    def evaluate_local(
        self, points: torch.Tensor, grid_values: VoxelGridValuesBase
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame,
        The interpolation type is determined by the `mode` member.
        Must be implemented by subclasses.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form
                (coordinates are in [-1, 1])
            grid_values: an object of type Class.values_type which has tensors
                as members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        raise NotImplementedError()

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Using parameters from the __init__ method, this method returns the
        shapes of individual tensors needed to run the evaluate method.
        Must be implemented by subclasses.

        Returns:
            a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
            replace the shapes in the dictionary with tensors of those shapes and add the
            first 'batch' dimension. If the required shape is (a, b) and you want to
            have g grids then the tensor that replaces the shape should have the
            shape (g, a, b).
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
    """
    Tensor container consumed by FullResolutionVoxelGrid.evaluate_local.
    """

    # dense grid; per FullResolutionVoxelGrid.get_shapes plus the leading batch
    # dimension its shape is (n_grids, n_features, width, height, depth)
    voxel_grid: torch.Tensor
|
||||
|
||||
|
||||
@registry.register
class FullResolutionVoxelGrid(VoxelGridBase):
    """
    Dense voxel grid: a 4D tensor of shape (features, width, height, depth) per grid,
    queried with interpolation between neighbouring voxels.
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = FullResolutionVoxelGridValues

    def evaluate_local(
        self, points: torch.Tensor, grid_values: FullResolutionVoxelGridValues
    ) -> torch.Tensor:
        """
        Samples the dense grid at points given in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): query points of shape (n_grids, n_points, 3),
                normalized so that coordinates are in [-1, 1]
            grid_values: an object of type values_type whose `voxel_grid` member
                has the shape given by get_shapes() with a leading n_grids dimension
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        sampling_options = {
            "align_corners": self.align_corners,
            "padding_mode": self.padding,
            "mode": self.mode,
        }
        dense_grid = grid_values.voxel_grid
        return interpolate_volume(points, dense_grid, **sampling_options)

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Returns the per-grid shape of the `voxel_grid` tensor:
        (n_features, width, height, depth).
        """
        voxel_grid_shape = (self.n_features,) + tuple(self.resolution)
        return {"voxel_grid": voxel_grid_shape}
|
||||
|
||||
|
||||
@dataclass
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
    """
    Tensor container consumed by CPFactorizedVoxelGrid.evaluate_local.
    Shapes below follow CPFactorizedVoxelGrid.get_shapes with a leading
    n_grids (batch) dimension added.
    """

    # (n_grids, n_components, resolution[0])
    vector_components_x: torch.Tensor
    # (n_grids, n_components, resolution[1])
    vector_components_y: torch.Tensor
    # (n_grids, n_components, resolution[2])
    vector_components_z: torch.Tensor
    # (n_grids, n_components, n_features); when None the components are summed
    # instead of being multiplied by a basis matrix
    basis_matrix: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@registry.register
class CPFactorizedVoxelGrid(VoxelGridBase):
    """
    Canonical Polyadic (CP/CANDECOMP/PARAFAC) Factorization factorizes the 3d grid into three
    vectors (x, y, z). For n_components=n, the 3d grid is a sum over the n components of
    the outer product (call it ⊗) of the three vectors (x, y, z) of that component:

        3d_grid = x0 ⊗ y0 ⊗ z0 + x1 ⊗ y1 ⊗ z1 + ... + x(n-1) ⊗ y(n-1) ⊗ z(n-1)

    These tensors are passed in a object of CPFactorizedVoxelGridValues (here obj) as
    obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their shapes are
    `(n_grids, n_components, r)` where `r` is the relevant resolution.

    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
    brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
    of different components are summed together to create (n_grids, n_points, 1) tensor.
    With some notation abuse, ignoring the interpolation operation, simplifying and denoting
    n_features as F, n_components as C and n_grids as G:

        3d_grid = (x ⊗ y ⊗ z) @ basis # GWHDC x GCF -> GWHDF

    The basis feature vectors are passed as obj.basis_matrix.

    Members:
        n_components: number of vector triplets, higher number gives better approximation.
        matrix_reduction: how to transform components. If matrix_reduction=True result
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
            basis_matrix of shape (n_grids, n_components, n_features). If
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
            is summed along the rows to get (n_grids, n_points, 1).
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues

    n_components: int = 24
    matrix_reduction: bool = True

    def evaluate_local(
        self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the factorized grid at points in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): query points of shape (n_grids, n_points, 3),
                normalized so that coordinates are in [-1, 1]
            grid_values: an object of type values_type which has tensors as
                members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features) when
                grid_values.basis_matrix is given, otherwise (n_grids, n_points, 1)
        """

        def factor(i: int) -> torch.Tensor:
            # interpolate the i-th axis' vectors at the i-th coordinate of every
            # point; the result has shape (n_grids, n_points, n_components)
            axis = ["x", "y", "z"][i]
            index = points[..., i, None]
            vector = getattr(grid_values, "vector_components_" + axis)
            return interpolate_line(
                index,
                vector,
                align_corners=self.align_corners,
                padding_mode=self.padding,
                mode=self.mode,
            )

        # collect points from all the vectors and multiply them out
        mult = factor(0) * factor(1) * factor(2)

        # reduce the result from
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
        if grid_values.basis_matrix is not None:
            # (n_grids, n_points, n_features) =
            # (n_grids, n_points, n_components) x (n_grids, n_components, n_features)
            return torch.bmm(mult, grid_values.basis_matrix)

        # `dim` is the keyword documented by torch
        return mult.sum(dim=-1, keepdim=True)

    def get_shapes(self) -> Dict[str, Tuple[int, int]]:
        """
        Returns the per-grid shapes of the tensors needed by evaluate_local.
        "basis_matrix" is included only when matrix_reduction is True.

        Raises:
            ValueError: if matrix_reduction is False while n_features != 1, because
                summing the components can only produce a single feature.
        """
        if self.matrix_reduction is False and self.n_features != 1:
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")

        shape_dict = {
            "vector_components_x": (self.n_components, self.resolution[0]),
            "vector_components_y": (self.n_components, self.resolution[1]),
            "vector_components_z": (self.n_components, self.resolution[2]),
        }
        if self.matrix_reduction:
            shape_dict["basis_matrix"] = (self.n_components, self.n_features)
        return shape_dict
|
||||
|
||||
|
||||
@dataclass
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
    """
    Tensor container consumed by VMFactorizedVoxelGrid.evaluate_local.
    Shapes below follow VMFactorizedVoxelGrid.get_shapes with a leading
    n_grids (batch) dimension added; each vector has the same number of
    components as the matrix it is multiplied with.
    """

    # (n_grids, n_yz_components, resolution[0]); paired with matrix_components_yz
    vector_components_x: torch.Tensor
    # (n_grids, n_xz_components, resolution[1]); paired with matrix_components_xz
    vector_components_y: torch.Tensor
    # (n_grids, n_xy_components, resolution[2]); paired with matrix_components_xy
    vector_components_z: torch.Tensor
    # (n_grids, n_xy_components, resolution[0], resolution[1])
    matrix_components_xy: torch.Tensor
    # (n_grids, n_yz_components, resolution[1], resolution[2])
    matrix_components_yz: torch.Tensor
    # (n_grids, n_xz_components, resolution[0], resolution[2])
    matrix_components_xz: torch.Tensor
    # (n_grids, total number of components, n_features); when None the components
    # are summed instead of being multiplied by a basis matrix
    basis_matrix: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@registry.register
class VMFactorizedVoxelGrid(VoxelGridBase):
    """
    Implementation of Vector-Matrix Factorization of a tensor from
    https://arxiv.org/abs/2203.09517.

    Vector-Matrix Factorization factorizes the 3d grid into three matrices
    (xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid
    is a sum of the outer products (call it ⊗) of each matrix with its
    complementary vector:

        3d_grid = xy ⊗ z + xz ⊗ y + yz ⊗ x.

    These tensors are passed in a VMFactorizedVoxelGridValues object (here obj)
    as obj.matrix_components_xy, obj.matrix_components_yz, obj.vector_components_y, etc.

    Their shapes are `(n_grids, n_components, r0, r1)` for matrix_components and
    `(n_grids, n_components, r2)` for vector_components. Each of `r0, r1 and r2` corresponds
    to one resolution in (width, height and depth).

    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
    brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
    of different components are summed together to create (n_grids, n_points, 1) tensor.
    With some notation abuse, ignoring the interpolation operation, simplifying and denoting
    n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:

        3d_grid = concat((xy ⊗ z), (xz ⊗ y).permute(0, 2, 1),
            (yz ⊗ x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF

    Members:
        n_components: total number of matrix vector pairs, this must be divisible by 3. Set
            this if you want to have equal representational power in all 3 directions. You
            must specify either n_components or distribution_of_components, you cannot
            specify both.
        distribution_of_components: if you do not want equal representational power in
            all 3 directions specify a tuple of numbers of matrix_vector pairs for each
            coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
            either n_components or distribution_of_components, you cannot specify both.
        matrix_reduction: how to transform components. If matrix_reduction=True result
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
            the basis_matrix of shape (n_grids, n_components, n_features). If
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
            is summed along the rows to get (n_grids, n_points, 1).
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = VMFactorizedVoxelGridValues

    n_components: Optional[int] = None
    distribution_of_components: Optional[Tuple[int, int, int]] = None
    matrix_reduction: bool = True

    def evaluate_local(
        self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the factorized grid at points in the local coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): query points of shape (n_grids, n_points, 3),
                normalized so that coordinates are in [-1, 1]
            grid_values: an object of type values_type which has tensors as
                members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features) when
                grid_values.basis_matrix is given, otherwise (n_grids, n_points, 1)
        """
        # identical sampling options are used for every plane and line lookup
        sampling_kwargs = {
            "align_corners": self.align_corners,
            "padding_mode": self.padding,
            "mode": self.mode,
        }

        # collect points from matrices and vectors and multiply them;
        # each plane is paired with the vector of the axis it does not cover
        a = interpolate_plane(
            points[..., :2],
            grid_values.matrix_components_xy,
            **sampling_kwargs,
        ) * interpolate_line(
            points[..., 2:],
            grid_values.vector_components_z,
            **sampling_kwargs,
        )
        b = interpolate_plane(
            points[..., [0, 2]],
            grid_values.matrix_components_xz,
            **sampling_kwargs,
        ) * interpolate_line(
            points[..., 1:2],
            grid_values.vector_components_y,
            **sampling_kwargs,
        )
        c = interpolate_plane(
            points[..., 1:],
            grid_values.matrix_components_yz,
            **sampling_kwargs,
        ) * interpolate_line(
            points[..., :1],
            grid_values.vector_components_x,
            **sampling_kwargs,
        )
        # concatenation order is (xy, xz, yz); `dim` is the keyword documented by torch
        feats = torch.cat((a, b, c), dim=-1)

        # reduce the result from
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
        if grid_values.basis_matrix is not None:
            # (n_grids, n_points, n_features) =
            # (n_grids, n_points, total_n_components) x
            # (n_grids, total_n_components, n_features)
            return torch.bmm(feats, grid_values.basis_matrix)
        return feats.sum(dim=-1, keepdim=True)

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Returns the per-grid shapes of the tensors needed by evaluate_local.
        "basis_matrix" is included only when matrix_reduction is True.

        Raises:
            ValueError: if matrix_reduction is False while n_features != 1; if neither
                or both of n_components and distribution_of_components are set; or if
                n_components is used and is not divisible by 3.
        """
        if self.matrix_reduction is False and self.n_features != 1:
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
        if self.distribution_of_components is None and self.n_components is None:
            raise ValueError(
                "You need to provide n_components or distribution_of_components"
            )
        if (
            self.distribution_of_components is not None
            and self.n_components is not None
        ):
            raise ValueError(
                "You cannot define n_components and distribution_of_components"
            )
        # pyre-ignore[58]
        if self.distribution_of_components is None and self.n_components % 3 != 0:
            raise ValueError("n_components must be divisible by 3")
        if self.distribution_of_components is None:
            # split the components evenly between the xy, yz and xz pairs
            # pyre-ignore[58]
            calculated_distribution_of_components = [
                self.n_components // 3 for _ in range(3)
            ]
        else:
            calculated_distribution_of_components = self.distribution_of_components

        # distribution is ordered (xy, yz, xz); every vector shares the component
        # count of the plane it is multiplied with (z with xy, x with yz, y with xz)
        n_xy, n_yz, n_xz = (
            calculated_distribution_of_components[0],
            calculated_distribution_of_components[1],
            calculated_distribution_of_components[2],
        )
        shape_dict = {
            "vector_components_x": (n_yz, self.resolution[0]),
            "vector_components_y": (n_xz, self.resolution[1]),
            "vector_components_z": (n_xy, self.resolution[2]),
            "matrix_components_xy": (n_xy, self.resolution[0], self.resolution[1]),
            "matrix_components_yz": (n_yz, self.resolution[1], self.resolution[2]),
            "matrix_components_xz": (n_xz, self.resolution[0], self.resolution[2]),
        }
        if self.matrix_reduction:
            shape_dict["basis_matrix"] = (
                sum(calculated_distribution_of_components),
                self.n_features,
            )

        return shape_dict
|
587
tests/implicitron/test_voxel_grids.py
Normal file
587
tests/implicitron/test_voxel_grids.py
Normal file
@ -0,0 +1,587 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import unittest
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
interpolate_line,
|
||||
interpolate_plane,
|
||||
interpolate_volume,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
||||
CPFactorizedVoxelGrid,
|
||||
FullResolutionVoxelGrid,
|
||||
VMFactorizedVoxelGrid,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
"""
|
||||
Tests Voxel grids, tests them by setting all elements to zero (after retrieving
|
||||
they should also return zero) and by setting all of the elements to one and
|
||||
getting the result. Also tests the interpolation by 'manually' interpolating
|
||||
one by one sample and comparing with the batched implementation.
|
||||
"""
|
||||
|
||||
def test_my_code(self):
    # Placeholder test: returns immediately and asserts nothing.
    return
|
||||
|
||||
def get_random_normalized_points(
    self, n_grids, n_points, dimension=3
) -> torch.Tensor:
    """
    Draws uniformly random query points with every coordinate in [-1, 1).

    Returns:
        tensor of shape (n_grids, n_points, dimension)
    """
    unit_samples = torch.rand(n_grids, n_points, dimension)
    # stretch [0, 1) to [-1, 1)
    return unit_samples * 2 - 1
|
||||
|
||||
def _test_query_with_constant_init_cp(
    self,
    n_grids: int,
    n_features: int,
    n_components: int,
    resolution: Tuple[int, int, int],
    value: float = 1,
    n_points: int = 1,
) -> None:
    """
    Fills every tensor of a CPFactorizedVoxelGrid with `value` and checks the query
    result at random points.

    Each of the three interpolated vectors yields `value`, their product is
    value**3 per component, and the basis matrix (also filled with `value`)
    sums n_components of them, so every output element must equal
    n_components * value**4. For value in {0, 1} this coincides with the
    previously used n_components * value.
    """
    grid = CPFactorizedVoxelGrid(
        resolution=resolution,
        n_components=n_components,
        n_features=n_features,
    )
    shapes = grid.get_shapes()

    params = grid.values_type(
        **{
            name: torch.ones(n_grids, *shape) * value
            for name, shape in shapes.items()
        }
    )

    expected_element = n_components * value**4
    assert torch.allclose(
        grid.evaluate_local(
            self.get_random_normalized_points(n_grids, n_points), params
        ),
        torch.ones(n_grids, n_points, n_features) * expected_element,
    )
|
||||
|
||||
def _test_query_with_constant_init_vm(
    self,
    n_grids: int,
    n_features: int,
    resolution: Tuple[int, int, int],
    n_components: Optional[int] = None,
    distribution: Optional[Tuple[int, int, int]] = None,
    value: float = 1,
    n_points: int = 1,
) -> None:
    """
    Fills every tensor of a VMFactorizedVoxelGrid with `value` and checks the query
    result at random points.

    Each plane-times-vector product yields value**2 per component and the basis
    matrix (also filled with `value`) sums all the components, so every output
    element must equal total_components * value**3. For value in {0, 1} this
    coincides with the previously used total_components * value.
    """
    grid = VMFactorizedVoxelGrid(
        n_features=n_features,
        resolution=resolution,
        n_components=n_components,
        distribution_of_components=distribution,
    )
    shapes = grid.get_shapes()
    params = grid.values_type(
        **{
            name: torch.ones(n_grids, *shape) * value
            for name, shape in shapes.items()
        }
    )

    total_components = n_components if distribution is None else sum(distribution)
    expected_element = total_components * value**3
    assert torch.allclose(
        grid.evaluate_local(
            self.get_random_normalized_points(n_grids, n_points), params
        ),
        torch.ones(n_grids, n_points, n_features) * expected_element,
    )
|
||||
|
||||
def _test_query_with_constant_init_full(
    self,
    n_grids: int,
    n_features: int,
    resolution: Tuple[int, int, int],
    value: float = 1,
    n_points: int = 1,
) -> None:
    """
    Fills a FullResolutionVoxelGrid with `value` and checks that querying it at
    random points returns `value` everywhere, since interpolating a constant
    grid gives that constant.
    """
    grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
    shapes = grid.get_shapes()
    params = grid.values_type(
        **{
            name: torch.ones(n_grids, *shape) * value
            for name, shape in shapes.items()
        }
    )

    expected_element = value
    assert torch.allclose(
        grid.evaluate_local(
            self.get_random_normalized_points(n_grids, n_points), params
        ),
        torch.ones(n_grids, n_points, n_features) * expected_element,
    )
|
||||
|
||||
def test_query_with_constant_init(self):
    """
    Runs the constant-initialisation query check (default value of 1) for the
    full, CP and VM grids, each in its own sub-test.
    """
    scenarios = (
        (
            "Full",
            self._test_query_with_constant_init_full,
            {"n_grids": 5, "n_features": 6, "resolution": (3, 4, 5), "n_points": 3},
        ),
        (
            "Full with 1 in dimensions",
            self._test_query_with_constant_init_full,
            {"n_grids": 5, "n_features": 1, "resolution": (33, 41, 1), "n_points": 4},
        ),
        (
            "CP",
            self._test_query_with_constant_init_cp,
            {
                "n_grids": 5,
                "n_features": 6,
                "n_components": 7,
                "resolution": (3, 4, 5),
                "n_points": 2,
            },
        ),
        (
            "CP with 1 in dimensions",
            self._test_query_with_constant_init_cp,
            {
                "n_grids": 2,
                "n_features": 1,
                "n_components": 3,
                "resolution": (3, 1, 1),
                "n_points": 4,
            },
        ),
        (
            "VM with symetric distribution",
            self._test_query_with_constant_init_vm,
            {
                "n_grids": 6,
                "n_features": 9,
                "resolution": (2, 12, 2),
                "n_components": 12,
                "n_points": 3,
            },
        ),
        (
            "VM with distribution",
            self._test_query_with_constant_init_vm,
            {
                "n_grids": 5,
                "n_features": 1,
                "resolution": (5, 9, 7),
                "distribution": (33, 41, 1),
                "n_points": 7,
            },
        ),
    )
    for label, check, arguments in scenarios:
        with self.subTest(label):
            check(**arguments)
|
||||
|
||||
def test_query_with_zero_init(self):
    """
    Runs the constant-initialisation query check with value=0 for the CP, VM and
    full grids, each in its own sub-test.
    """
    scenarios = (
        (
            "Query testing with zero init CPFactorizedVoxelGrid",
            self._test_query_with_constant_init_cp,
            {
                "n_grids": 5,
                "n_features": 6,
                "n_components": 7,
                "resolution": (3, 2, 5),
                "n_points": 3,
                "value": 0,
            },
        ),
        (
            "Query testing with zero init VMFactorizedVoxelGrid",
            self._test_query_with_constant_init_vm,
            {
                "n_grids": 2,
                "n_features": 9,
                "resolution": (2, 11, 3),
                "n_components": 3,
                "n_points": 3,
                "value": 0,
            },
        ),
        (
            "Query testing with zero init FullResolutionVoxelGrid",
            self._test_query_with_constant_init_full,
            {
                "n_grids": 4,
                "n_features": 2,
                "resolution": (3, 3, 5),
                "n_points": 3,
                "value": 0,
            },
        ),
    )
    for label, check, arguments in scenarios:
        with self.subTest(label):
            check(**arguments)
|
||||
|
||||
def setUp(self):
    """
    Seeds torch's random generator and calls expand_args_fields on the three
    voxel grid classes so the tests can construct them with keyword arguments.
    """
    torch.manual_seed(42)
    grid_classes = (
        FullResolutionVoxelGrid,
        CPFactorizedVoxelGrid,
        VMFactorizedVoxelGrid,
    )
    for grid_class in grid_classes:
        expand_args_fields(grid_class)
|
||||
|
||||
def _interpolate_1D(
|
||||
self, points: torch.Tensor, vectors: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
interpolate vector from points, which are (batch, 1) and individual point is in [-1, 1]
|
||||
"""
|
||||
result = []
|
||||
_, _, width = vectors.shape
|
||||
# transform from [-1, 1] to [0, width-1]
|
||||
points = (points + 1) / 2 * (width - 1)
|
||||
for vector, row in zip(vectors, points):
|
||||
newrow = []
|
||||
for x in row:
|
||||
xf, xc = int(torch.floor(x)), int(torch.ceil(x))
|
||||
itemf, itemc = vector[:, xf], vector[:, xc]
|
||||
tmp = itemf * (xc - x) + itemc * (x - xf)
|
||||
newrow.append(tmp[None, None, :])
|
||||
result.append(torch.cat(newrow, dim=1))
|
||||
return torch.cat(result)
|
||||
|
||||
def _interpolate_2D(
|
||||
self, points: torch.Tensor, matrices: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
interpolate matrix from points, which are (batch, 2) and individual point is in [-1, 1]
|
||||
"""
|
||||
result = []
|
||||
n_grids, _, width, height = matrices.shape
|
||||
points = (points + 1) / 2 * (torch.tensor([[[width, height]]]) - 1)
|
||||
for matrix, row in zip(matrices, points):
|
||||
newrow = []
|
||||
for x, y in row:
|
||||
xf, xc = int(torch.floor(x)), int(torch.ceil(x))
|
||||
yf, yc = int(torch.floor(y)), int(torch.ceil(y))
|
||||
itemff, itemfc = matrix[:, xf, yf], matrix[:, xf, yc]
|
||||
itemcf, itemcc = matrix[:, xc, yf], matrix[:, xc, yc]
|
||||
itemf = itemff * (xc - x) + itemcf * (x - xf)
|
||||
itemc = itemfc * (xc - x) + itemcc * (x - xf)
|
||||
tmp = itemf * (yc - y) + itemc * (y - yf)
|
||||
newrow.append(tmp[None, None, :])
|
||||
result.append(torch.cat(newrow, dim=1))
|
||||
return torch.cat(result)
|
||||
|
||||
def _interpolate_3D(
    self, points: torch.Tensor, tensors: torch.Tensor
) -> torch.Tensor:
    """
    Reference (loop based) trilinear interpolation used to check interpolate_volume.

    Arguments:
        points: shape (n_grids, n_points, 3), every coordinate in [-1, 1]
        tensors: shape (n_grids, features, width, height, depth)
    Returns:
        tensor of shape (n_grids, n_points, features)
    """
    result = []
    _, _, _, _, depth = tensors.shape
    # Only z is converted to a cell position here: x and y are passed to
    # _interpolate_2D in their [-1, 1] form because it rescales them itself.
    z_positions = (points[..., 2] + 1) / 2 * (depth - 1)

    for tensor, grid_points, grid_z in zip(tensors, points, z_positions):
        newrow = []
        for point, nz in zip(grid_points, grid_z):
            z0 = int(torch.floor(nz))
            z1 = min(z0 + 1, depth - 1)
            # Blend the two neighbouring z slices with the fractional cell
            # position. Previously the raw z in [-1, 1] was re-interpolated over
            # the two slices, which is only correct when depth == 2.
            wz = nz - z0
            xy = point[None, None, :2]
            below = self._interpolate_2D(
                points=xy, matrices=tensor[None, :, :, :, z0]
            )
            above = self._interpolate_2D(
                points=xy, matrices=tensor[None, :, :, :, z1]
            )
            newrow.append(below * (1 - wz) + above * wz)
        result.append(torch.cat(newrow, dim=1))
    return torch.cat(result)
|
||||
|
||||
def test_interpolation(self):
    """
    Compare interpolate_line, interpolate_plane and interpolate_volume against
    the loop-based reference implementations of this class on random inputs.
    """
    sampling_options = {
        "align_corners": True,
        "padding_mode": "zeros",
        "mode": "bilinear",
    }
    # (subtest name, point dimension, source shape, reference, function under test)
    cases = (
        ("1D interpolation", 1, (4, 3, 2), self._interpolate_1D, interpolate_line),
        (
            "2D interpolation",
            2,
            (4, 2, 3, 5),
            self._interpolate_2D,
            interpolate_plane,
        ),
        (
            "3D interpolation",
            3,
            (4, 5, 2, 7, 2),
            self._interpolate_3D,
            interpolate_volume,
        ),
    )
    for name, dimension, source_shape, reference, under_test in cases:
        with self.subTest(name):
            # Points are drawn before the source so the random stream is
            # consumed in the same order for every case.
            points = self.get_random_normalized_points(
                n_grids=4, n_points=5, dimension=dimension
            )
            source = torch.randn(size=source_shape)
            expected = reference(points, source)
            actual = under_test(points, source, **sampling_options)
            assert torch.allclose(expected, actual)
|
||||
|
||||
def test_floating_point_query(self):
    """
    Evaluate each voxel grid type at fractional positions and compare the
    output with precomputed expected values.
    """
    with self.subTest("FullResolution"):
        grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
        params = grid.values_type(**grid.get_shapes())
        params.voxel_grid = torch.tensor(
            [
                [
                    [[[1, 3], [5, 7]], [[9, 11], [13, 15]]],
                    [[[2, 4], [6, 8]], [[10, 12], [14, 16]]],
                ],
                [
                    [[[17, 18], [19, 20]], [[21, 22], [23, 24]]],
                    [[[25, 26], [27, 28]], [[29, 30], [31, 32]]],
                ],
            ],
            dtype=torch.float,
        )
        # Positions are written on a [0, 1] scale per axis and then mapped
        # to [-1, 1] by the division / * 2 - 1 below.
        positions = torch.tensor(
            [
                [[1, 0, 1], [0.5, 1, 1], [1 / 3, 1 / 3, 2 / 3]],
                [[0, 1, 1], [0, 0.5, 1], [1 / 4, 1 / 4, 3 / 4]],
            ]
        )
        points = positions / torch.tensor([[1.0, 1, 1]]) * 2 - 1
        expected_result = torch.tensor(
            [
                [[11, 12], [11, 12], [6.333333, 7.3333333]],
                [[20, 28], [19, 27], [19.25, 27.25]],
            ]
        )

        actual = grid.evaluate_local(points, params)
        assert torch.allclose(actual, expected_result, rtol=0.00001), actual

    with self.subTest("CP"):
        grid = CPFactorizedVoxelGrid(
            n_features=1, resolution=(1, 1, 1), n_components=3
        )
        params = grid.values_type(**grid.get_shapes())
        params.vector_components_x = torch.tensor(
            [[[1, 2], [10.5, 20.5]], [[10, 20], [2, 4]]]
        )
        params.vector_components_y = torch.tensor(
            [[[3, 4, 5], [30.5, 40.5, 50.5]], [[30, 40, 50], [1, 3, 5]]]
        )
        params.vector_components_z = torch.tensor(
            [
                [[6, 7, 8, 9], [60.5, 70.5, 80.5, 90.5]],
                [[60, 70, 80, 90], [6, 7, 8, 9]],
            ]
        )
        params.basis_matrix = torch.tensor([[[2.0], [2.0]], [[1.0], [2.0]]])
        # The divisor (1, 2, 3) is one less than the lengths (2, 3, 4) of the
        # x, y and z component vectors set above.
        positions = torch.tensor(
            [
                [[0, 2, 2], [1, 2, 0.25], [0.5, 0.5, 1], [1 / 3, 2 / 3, 2 + 1 / 3]],
                [[1, 0, 1], [0.5, 2, 2], [1, 0.5, 0.5], [1 / 4, 3 / 4, 2 + 1 / 4]],
            ]
        )
        points = positions / torch.tensor([[[1.0, 2, 3]]]) * 2 - 1
        expected_result_matrix = torch.tensor(
            [
                [[85450.25], [130566.5], [77658.75], [86285.422]],
                [[42056], [60240], [45604], [38775]],
            ]
        )
        expected_result_sum = torch.tensor(
            [
                [[42725.125], [65283.25], [38829.375], [43142.711]],
                [[42028], [60120], [45552], [38723.4375]],
            ]
        )
        with self.subTest("CP with basis_matrix reduction"):
            actual = grid.evaluate_local(points, params)
            assert torch.allclose(actual, expected_result_matrix, rtol=0.00001)
        # With basis_matrix removed the grid is expected to reduce by
        # summation, as the next subtest's name indicates.
        del params.basis_matrix
        with self.subTest("CP with sum reduction"):
            actual = grid.evaluate_local(points, params)
            assert torch.allclose(actual, expected_result_sum, rtol=0.00001)

    with self.subTest("VM"):
        grid = VMFactorizedVoxelGrid(
            n_features=1, resolution=(1, 1, 1), n_components=3
        )
        params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
        params.matrix_components_xy = torch.tensor(
            [
                [[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
                [[[35, 36], [37, 38]], [[39, 40], [41, 42]]],
            ]
        )
        params.matrix_components_xz = torch.tensor(
            [
                [[[7, 8], [9, 10]], [[25, 26], [27, 28.0]]],
                [[[43, 44], [45, 46]], [[47, 48], [49, 50]]],
            ]
        )
        params.matrix_components_yz = torch.tensor(
            [
                [[[13, 14], [15, 16]], [[31, 32], [33, 34.0]]],
                [[[51, 52], [53, 54]], [[55, 56], [57, 58.0]]],
            ]
        )

        params.vector_components_z = torch.tensor(
            [[[5, 6], [23, 24.0]], [[59, 60], [61, 62]]]
        )
        params.vector_components_y = torch.tensor(
            [[[11, 12], [29, 30.0]], [[63, 64], [65, 66]]]
        )
        params.vector_components_x = torch.tensor(
            [[[17, 18], [35, 36.0]], [[67, 68], [69, 70.0]]]
        )

        basis_rows = torch.tensor(
            [
                [2, 2, 2, 2, 2, 2.0],
                [1, 2, 1, 2, 1, 2.0],
            ]
        )
        # Trailing singleton axis appended via indexing with None.
        params.basis_matrix = basis_rows[:, :, None]
        positions = torch.tensor(
            [
                [[1, 0, 1], [0.5, 1, 1], [1 / 3, 1 / 3, 2 / 3]],
                [[0, 1, 0], [0, 0, 0], [0, 1, 0]],
            ]
        )
        points = positions / torch.tensor([[[1.0, 1, 1]]]) * 2 - 1
        expected_result_matrix = torch.tensor(
            [
                [[5696], [5854], [5484.888]],
                [[27377], [26649], [27377]],
            ]
        )
        expected_result_sum = torch.tensor(
            [
                [[2848], [2927], [2742.444]],
                [[17902], [17420], [17902]],
            ]
        )
        with self.subTest("VM with basis_matrix reduction"):
            actual = grid.evaluate_local(points, params)
            assert torch.allclose(actual, expected_result_matrix, rtol=0.00001)
        # Same pattern as for CP: drop basis_matrix before the sum subtest.
        del params.basis_matrix
        with self.subTest("VM with sum reduction"):
            actual = grid.evaluate_local(points, params)
            assert torch.allclose(actual, expected_result_sum, rtol=0.0001), actual
|
||||
|
||||
def test_forward_with_small_init_std(self):
    """
    Check that every grid type returns values of small magnitude when its
    parameters are sampled from a zero-mean normal distribution with a small
    standard deviation.
    """
    n_grids = 3
    n_points = 3
    max_abs_value = 1e-2

    def check_grid(cls, **kwargs):
        with self.subTest(cls.__name__):
            grid = cls(**kwargs)
            # One independently sampled parameter tensor per entry of
            # get_shapes(), with a leading n_grids axis.
            params = cls.values_type(
                **{
                    name: torch.normal(
                        mean=torch.zeros(n_grids, *shape), std=0.0001
                    )
                    for name, shape in grid.get_shapes().items()
                }
            )
            points = self.get_random_normalized_points(
                n_grids=n_grids, n_points=n_points
            )
            values = grid.evaluate_local(points, params)
            # Bound the magnitude, not just the signed value, so that large
            # negative outputs fail too. A scalar bound avoids depending on
            # n_grids == n_points for broadcasting.
            assert torch.all(values.abs() < max_abs_value), values.abs().max()

    check_grid(
        FullResolutionVoxelGrid,
        resolution=(4, 6, 9),
        n_features=10,
    )
    check_grid(
        CPFactorizedVoxelGrid,
        resolution=(4, 6, 9),
        n_features=10,
        n_components=3,
    )
    check_grid(
        VMFactorizedVoxelGrid,
        resolution=(4, 6, 9),
        n_features=10,
        n_components=3,
    )
|
Loading…
x
Reference in New Issue
Block a user