voxel grids with interpolation

Summary: Added voxel grid classes from TensoRF, both in their factorized (CP and VM) and full form.

Reviewed By: bottler

Differential Revision: D38465419

fbshipit-source-id: 8b306338af58dc50ef47a682616022a0512c0047
This commit is contained in:
Darijan Gudelj 2022-08-23 07:22:41 -07:00 committed by Facebook GitHub Bot
parent af799facdd
commit edee25a1e5
3 changed files with 1112 additions and 0 deletions

View File

@ -7,6 +7,8 @@
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.renderer.cameras import CamerasBase
@ -88,3 +90,98 @@ def create_embeddings_for_implicit_function(
embeds = broadcast_global_code(embeds, global_code)
return embeds
def interpolate_line(
points: torch.Tensor,
source: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Linearly interpolates values of source grids. The first dimension of points represents
number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
dimension of argument source represents feature and ones after that the spatial
dimension.
Arguments:
points: shape (n_grids, n_points, 1),
source: tensor of shape (n_grids, features, width),
Returns:
interpolated tensor of shape (n_grids, n_points, features)
"""
# To enable sampling of the source using the torch.functional.grid_sample
# points need to have 2 coordinates.
expansion = points.new_zeros(points.shape)
points = torch.cat((points, expansion), dim=-1)
source = source[:, :, None, :]
points = points[:, :, None, :]
out = F.grid_sample(
grid=points,
input=source,
**kwargs,
)
return out[:, :, :, 0].permute(0, 2, 1)
def interpolate_plane(
points: torch.Tensor,
source: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Bilinearly interpolates values of source grids. The first dimension of points represents
number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
The first dimension of argument source represents feature and ones after that the
spatial dimension.
Arguments:
points: shape (n_grids, n_points, 2),
source: tensor of shape (n_grids, features, width, height),
Returns:
interpolated tensor of shape (n_grids, n_points, features)
"""
# permuting because torch.nn.functional.grid_sample works with
# (features, height, width) and not
# (features, width, height)
source = source.permute(0, 1, 3, 2)
points = points[:, :, None, :]
out = F.grid_sample(
grid=points,
input=source,
**kwargs,
)
return out[:, :, :, 0].permute(0, 2, 1)
def interpolate_volume(
points: torch.Tensor, source: torch.Tensor, **kwargs
) -> torch.Tensor:
"""
Interpolates values of source grids. The first dimension of points represents
number of points and the second coordinates, for example
[[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
and ones after that the spatial dimension.
Arguments:
points: shape (n_grids, n_points, 3),
source: tensor of shape (n_grids, features, width, height, depth),
Returns:
interpolated tensor of shape (n_grids, n_points, features)
"""
if "mode" in kwargs and kwargs["mode"] == "trilinear":
kwargs = kwargs.copy()
kwargs["mode"] = "bilinear"
# permuting because torch.nn.functional.grid_sample works with
# (features, depth, height, width) and not (features, width, height, depth)
source = source.permute(0, 1, 4, 3, 2)
grid = points[:, :, None, None, :]
out = F.grid_sample(
grid=grid,
input=source,
**kwargs,
)
return out[:, :, :, 0, 0].permute(0, 2, 1)

View File

@ -0,0 +1,428 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
This file contains classes that implement Voxel grids, both in their full resolution
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
https://arxiv.org/abs/2203.09517.
"""
from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.structures.volumes import VolumeLocator
from .utils import interpolate_line, interpolate_plane, interpolate_volume
@dataclass
class VoxelGridValuesBase:
pass
class VoxelGridBase(ReplaceableBase, torch.nn.Module):
"""
Base class for all the voxel grid variants whith added trilinear interpolation between
voxels (for example if voxel (0.333, 1, 3) is queried that would return the result
2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3])
Internally voxel grids are indexed by (features, x, y, z). If queried the point is not
inside the voxel grid the vector that will be returned is determined by padding.
Members:
align_corners: parameter used in torch.functional.grid_sample. For details go to
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html by
default is True
padding: padding mode for outside grid values 'zeros' | 'border' | 'reflection'.
Default is 'zeros'
mode: interpolation mode to calculate output values :
'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
Default: 'bilinear' Note: mode='bicubic' supports only FullResolutionVoxelGrid.
When mode='bilinear' and the input is 5-D, the interpolation mode used internally
will actually be trilinear.
n_features: number of dimensions of base feature vector. Determines how many features
the grid returns.
resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
"""
align_corners: bool = True
padding: str = "zeros"
mode: str = "bilinear"
n_features: int = 1
resolution: Tuple[int, int, int] = (64, 64, 64)
def __post_init__(self):
super().__init__()
def evaluate_world(
self,
points: torch.Tensor,
grid_values: VoxelGridValuesBase,
locator: VolumeLocator,
) -> torch.Tensor:
"""
Evaluates the voxel grid at points in the world coordinate frame.
The interpolation type is determined by the `mode` member.
Arguments:
points (torch.Tensor): tensor of points that you want to query
of a form (n_grids, n_points, 3)
grid_values: an object of type Class.values_type which has tensors as
members which have shapes derived from the get_shapes() method
locator: a VolumeLocator object
Returns:
torch.Tensor: shape (n_grids, n_points, n_features)
"""
points_local = locator.world_to_local_coords(points)
# pyre-ignore[29]
return self.evaluate_local(points_local, grid_values)
def evaluate_local(
self, points: torch.Tensor, grid_values: VoxelGridValuesBase
) -> torch.Tensor:
"""
Evaluates the voxel grid at points in the local coordinate frame,
The interpolation type is determined by the `mode` member.
Arguments:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
as members which have shapes derived from the get_shapes() method
Returns:
torch.Tensor: shape (n_grids, n_points, n_features)
"""
raise NotImplementedError()
def get_shapes(self) -> Dict[str, Tuple]:
"""
Using parameters from the __init__ method, this method returns the
shapes of individual tensors needed to run the evaluate method.
Returns:
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
replace the shapes in the dictionary with tensors of those shapes and add the
first 'batch' dimension. If the required shape is (a, b) and you want to
have g grids than the tensor that replaces the shape should have the
shape (g, a, b).
"""
raise NotImplementedError()
@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
voxel_grid: torch.Tensor
@registry.register
class FullResolutionVoxelGrid(VoxelGridBase):
"""
Full resolution voxel grid equivalent to 4D tensor where shape is
(features, width, height, depth) with linear interpolation between voxels.
"""
# the type of grid_values argument needed to run evaluate_local()
values_type: ClassVar[Type[VoxelGridValuesBase]] = FullResolutionVoxelGridValues
def evaluate_local(
self, points: torch.Tensor, grid_values: FullResolutionVoxelGridValues
) -> torch.Tensor:
"""
Evaluates the voxel grid at points in the local coordinate frame,
The interpolation type is determined by the `mode` member.
Arguments:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
grid_values: an object of type values_type which has tensors as
members which have shapes derived from the get_shapes() method
Returns:
torch.Tensor: shape (n_grids, n_points, n_features)
"""
return interpolate_volume(
points,
grid_values.voxel_grid,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
)
def get_shapes(self) -> Dict[str, Tuple]:
return {"voxel_grid": (self.n_features, *self.resolution)}
@dataclass
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
vector_components_x: torch.Tensor
vector_components_y: torch.Tensor
vector_components_z: torch.Tensor
basis_matrix: Optional[torch.Tensor] = None
@registry.register
class CPFactorizedVoxelGrid(VoxelGridBase):
"""
Canonical Polyadic (CP/CANDECOMP/PARAFAC) Factorization factorizes the 3d grid into three
vectors (x, y, z). For n_components=n, the 3d grid is a sum of the two outer products
(call it ) of each vector type (x, y, z):
3d_grid = x0 y0 z0 + x1 y1 z1 + ... + xn yn zn
These tensors are passed in a object of CPFactorizedVoxelGridValues (here obj) as
obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their shapes are
`(n_components, r)` where `r` is the relevant resolution.
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
of different components are summed together to create (n_grids, n_components, 1) tensor.
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
n_features as F, n_components as C and n_grids as G:
3d_grid = (x y z) @ basis # GWHDC x GCF -> GWHDF
The basis feature vectors are passed as obj.basis_matrix.
Members:
n_components: number of vector triplets, higher number gives better approximation.
matrix_reduction: how to transform components. If matrix_reduction=True result
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
is summed along the rows to get (n_grids, n_points, 1).
"""
# the type of grid_values argument needed to run evaluate_local()
values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues
n_components: int = 24
matrix_reduction: bool = True
def evaluate_local(
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
) -> torch.Tensor:
def factor(i):
axis = ["x", "y", "z"][i]
index = points[..., i, None]
vector = getattr(grid_values, "vector_components_" + axis)
return interpolate_line(
index,
vector,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
)
# collect points from all the vectors and multipy them out
mult = factor(0) * factor(1) * factor(2)
# reduce the result from
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
if grid_values.basis_matrix is not None:
# (n_grids, n_points, n_features) =
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
return torch.bmm(mult, grid_values.basis_matrix)
return mult.sum(axis=-1, keepdim=True)
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
if self.matrix_reduction is False and self.n_features != 1:
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
shape_dict = {
"vector_components_x": (self.n_components, self.resolution[0]),
"vector_components_y": (self.n_components, self.resolution[1]),
"vector_components_z": (self.n_components, self.resolution[2]),
}
if self.matrix_reduction:
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
return shape_dict
@dataclass
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
vector_components_x: torch.Tensor
vector_components_y: torch.Tensor
vector_components_z: torch.Tensor
matrix_components_xy: torch.Tensor
matrix_components_yz: torch.Tensor
matrix_components_xz: torch.Tensor
basis_matrix: Optional[torch.Tensor] = None
@registry.register
class VMFactorizedVoxelGrid(VoxelGridBase):
"""
Implementation of Vector-Matrix Factorization of a tensor from
https://arxiv.org/abs/2203.09517.
Vector-Matrix Factorization factorizes the 3d grid into three matrices
(xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid
is a sum of the outer products (call it ) of each matrix with its
complementary vector:
3d_grid = xy z + xz y + yz x.
These tensors are passed in a VMFactorizedVoxelGridValues object (here obj)
as obj.matrix_components_xy, obj.matrix_components_xy, obj.vector_components_y, etc.
Their shapes are `(n_grids, n_components, r0, r1)` for matrix_components and
(n_grids, n_components, r2)` for vector_componenets. Each of `r0, r1 and r2` coresponds
to one resolution in (width, height and depth).
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
of different components are summed together to create (n_grids, n_components, 1) tensor.
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:
3d_grid = concat((xy z), (xz y).permute(0, 2, 1),
(yz x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF
Members:
n_components: total number of matrix vector pairs, this must be divisible by 3. Set
this if you want to have equal representational power in all 3 directions. You
must specify either n_components or distribution_of_components, you cannot
specify both.
distribution_of_components: if you do not want equal representational power in
all 3 directions specify a tuple of numbers of matrix_vector pairs for each
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
either n_components or distribution_of_components, you cannot specify both.
matrix_reduction: how to transform components. If matrix_reduction=True result
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
the basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
is summed along the rows to get (n_grids, n_points, 1).
"""
# the type of grid_values argument needed to run evaluate_local()
values_type: ClassVar[Type[VoxelGridValuesBase]] = VMFactorizedVoxelGridValues
n_components: Optional[int] = None
distribution_of_components: Optional[Tuple[int, int, int]] = None
matrix_reduction: bool = True
def evaluate_local(
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
) -> torch.Tensor:
# collect points from matrices and vectors and multiply them
a = interpolate_plane(
points[..., :2],
grid_values.matrix_components_xy,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
) * interpolate_line(
points[..., 2:],
grid_values.vector_components_z,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
)
b = interpolate_plane(
points[..., [0, 2]],
grid_values.matrix_components_xz,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
) * interpolate_line(
points[..., 1:2],
grid_values.vector_components_y,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
)
c = interpolate_plane(
points[..., 1:],
grid_values.matrix_components_yz,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
) * interpolate_line(
points[..., :1],
grid_values.vector_components_x,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
)
# pyre-ignore[28]
feats = torch.cat((a, b, c), axis=-1)
# reduce the result from
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
if grid_values.basis_matrix is not None:
# (n_grids, n_points, n_features) =
# (n_grids, n_points, total_n_components) x
# (n_grids, total_n_components, n_features)
return torch.bmm(feats, grid_values.basis_matrix)
# pyre-ignore[28]
return feats.sum(axis=-1, keepdim=True)
def get_shapes(self) -> Dict[str, Tuple]:
if self.matrix_reduction is False and self.n_features != 1:
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
if self.distribution_of_components is None and self.n_components is None:
raise ValueError(
"You need to provide n_components or distribution_of_components"
)
if (
self.distribution_of_components is not None
and self.n_components is not None
):
raise ValueError(
"You cannot define n_components and distribution_of_components"
)
# pyre-ignore[58]
if self.distribution_of_components is None and self.n_components % 3 != 0:
raise ValueError("n_components must be divisible by 3")
if self.distribution_of_components is None:
# pyre-ignore[58]
calculated_distribution_of_components = [
self.n_components // 3 for _ in range(3)
]
else:
calculated_distribution_of_components = self.distribution_of_components
shape_dict = {
"vector_components_x": (
calculated_distribution_of_components[1],
self.resolution[0],
),
"vector_components_y": (
calculated_distribution_of_components[2],
self.resolution[1],
),
"vector_components_z": (
calculated_distribution_of_components[0],
self.resolution[2],
),
"matrix_components_xy": (
calculated_distribution_of_components[0],
self.resolution[0],
self.resolution[1],
),
"matrix_components_yz": (
calculated_distribution_of_components[1],
self.resolution[1],
self.resolution[2],
),
"matrix_components_xz": (
calculated_distribution_of_components[2],
self.resolution[0],
self.resolution[2],
),
}
if self.matrix_reduction:
shape_dict["basis_matrix"] = (
sum(calculated_distribution_of_components),
self.n_features,
)
return shape_dict

View File

@ -0,0 +1,587 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.models.implicit_function.utils import (
interpolate_line,
interpolate_plane,
interpolate_volume,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
VMFactorizedVoxelGrid,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from tests.common_testing import TestCaseMixin
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
"""
Tests Voxel grids, tests them by setting all elements to zero (after retrieving
they should also return zero) and by setting all of the elements to one and
getting the result. Also tests the interpolation by 'manually' interpolating
one by one sample and comparing with the batched implementation.
"""
def test_my_code(self):
return
def get_random_normalized_points(
self, n_grids, n_points, dimension=3
) -> torch.Tensor:
# create random query points
return torch.rand(n_grids, n_points, dimension) * 2 - 1
def _test_query_with_constant_init_cp(
self,
n_grids: int,
n_features: int,
n_components: int,
resolution: Tuple[int],
value: float = 1,
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elementsthe result should
# be of shape (n_grids, n_points, n_features) and be filled with n_components
# * value
grid = CPFactorizedVoxelGrid(
resolution=resolution,
n_components=n_components,
n_features=n_features,
)
shapes = grid.get_shapes()
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * n_components * value,
)
def _test_query_with_constant_init_vm(
self,
n_grids: int,
n_features: int,
resolution: Tuple[int],
n_components: Optional[int] = None,
distribution: Optional[Tuple[int]] = None,
value: float = 1,
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elements
grid = VMFactorizedVoxelGrid(
n_features=n_features,
resolution=resolution,
n_components=n_components,
distribution_of_components=distribution,
)
shapes = grid.get_shapes()
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
expected_element = (
n_components * value if distribution is None else sum(distribution) * value
)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * expected_element,
)
def _test_query_with_constant_init_full(
self,
n_grids: int,
n_features: int,
resolution: Tuple[int],
value: int = 1,
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elements
grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
shapes = grid.get_shapes()
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
expected_element = value
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * expected_element,
)
def test_query_with_constant_init(self):
with self.subTest("Full"):
self._test_query_with_constant_init_full(
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
)
with self.subTest("Full with 1 in dimensions"):
self._test_query_with_constant_init_full(
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
)
with self.subTest("CP"):
self._test_query_with_constant_init_cp(
n_grids=5,
n_features=6,
n_components=7,
resolution=(3, 4, 5),
n_points=2,
)
with self.subTest("CP with 1 in dimensions"):
self._test_query_with_constant_init_cp(
n_grids=2,
n_features=1,
n_components=3,
resolution=(3, 1, 1),
n_points=4,
)
with self.subTest("VM with symetric distribution"):
self._test_query_with_constant_init_vm(
n_grids=6,
n_features=9,
resolution=(2, 12, 2),
n_components=12,
n_points=3,
)
with self.subTest("VM with distribution"):
self._test_query_with_constant_init_vm(
n_grids=5,
n_features=1,
resolution=(5, 9, 7),
distribution=(33, 41, 1),
n_points=7,
)
def test_query_with_zero_init(self):
with self.subTest("Query testing with zero init CPFactorizedVoxelGrid"):
self._test_query_with_constant_init_cp(
n_grids=5,
n_features=6,
n_components=7,
resolution=(3, 2, 5),
n_points=3,
value=0,
)
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
self._test_query_with_constant_init_vm(
n_grids=2,
n_features=9,
resolution=(2, 11, 3),
n_components=3,
n_points=3,
value=0,
)
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
self._test_query_with_constant_init_full(
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
)
def setUp(self):
torch.manual_seed(42)
expand_args_fields(FullResolutionVoxelGrid)
expand_args_fields(CPFactorizedVoxelGrid)
expand_args_fields(VMFactorizedVoxelGrid)
def _interpolate_1D(
self, points: torch.Tensor, vectors: torch.Tensor
) -> torch.Tensor:
"""
interpolate vector from points, which are (batch, 1) and individual point is in [-1, 1]
"""
result = []
_, _, width = vectors.shape
# transform from [-1, 1] to [0, width-1]
points = (points + 1) / 2 * (width - 1)
for vector, row in zip(vectors, points):
newrow = []
for x in row:
xf, xc = int(torch.floor(x)), int(torch.ceil(x))
itemf, itemc = vector[:, xf], vector[:, xc]
tmp = itemf * (xc - x) + itemc * (x - xf)
newrow.append(tmp[None, None, :])
result.append(torch.cat(newrow, dim=1))
return torch.cat(result)
def _interpolate_2D(
self, points: torch.Tensor, matrices: torch.Tensor
) -> torch.Tensor:
"""
interpolate matrix from points, which are (batch, 2) and individual point is in [-1, 1]
"""
result = []
n_grids, _, width, height = matrices.shape
points = (points + 1) / 2 * (torch.tensor([[[width, height]]]) - 1)
for matrix, row in zip(matrices, points):
newrow = []
for x, y in row:
xf, xc = int(torch.floor(x)), int(torch.ceil(x))
yf, yc = int(torch.floor(y)), int(torch.ceil(y))
itemff, itemfc = matrix[:, xf, yf], matrix[:, xf, yc]
itemcf, itemcc = matrix[:, xc, yf], matrix[:, xc, yc]
itemf = itemff * (xc - x) + itemcf * (x - xf)
itemc = itemfc * (xc - x) + itemcc * (x - xf)
tmp = itemf * (yc - y) + itemc * (y - yf)
newrow.append(tmp[None, None, :])
result.append(torch.cat(newrow, dim=1))
return torch.cat(result)
def _interpolate_3D(
self, points: torch.Tensor, tensors: torch.Tensor
) -> torch.Tensor:
"""
interpolate tensors from points, which are (batch, 3) and individual point is in [-1, 1]
"""
result = []
_, _, width, height, depth = tensors.shape
batch_normalized_points = (
(points + 1) / 2 * (torch.tensor([[[width, height, depth]]]) - 1)
)
batch_points = points
for tensor, points, normalized_points in zip(
tensors, batch_points, batch_normalized_points
):
newrow = []
for (x, y, z), (_, _, nz) in zip(points, normalized_points):
zf, zc = int(torch.floor(nz)), int(torch.ceil(nz))
itemf = self._interpolate_2D(
points=torch.tensor([[[x, y]]]), matrices=tensor[None, :, :, :, zf]
)
itemc = self._interpolate_2D(
points=torch.tensor([[[x, y]]]), matrices=tensor[None, :, :, :, zc]
)
tmp = self._interpolate_1D(
points=torch.tensor([[[z]]]),
vectors=torch.cat((itemf, itemc), dim=1).permute(0, 2, 1),
)
newrow.append(tmp)
result.append(torch.cat(newrow, dim=1))
return torch.cat(result)
def test_interpolation(self):
with self.subTest("1D interpolation"):
points = self.get_random_normalized_points(
n_grids=4, n_points=5, dimension=1
)
vector = torch.randn(size=(4, 3, 2))
assert torch.allclose(
self._interpolate_1D(points, vector),
interpolate_line(
points,
vector,
align_corners=True,
padding_mode="zeros",
mode="bilinear",
),
)
with self.subTest("2D interpolation"):
points = self.get_random_normalized_points(
n_grids=4, n_points=5, dimension=2
)
matrix = torch.randn(size=(4, 2, 3, 5))
assert torch.allclose(
self._interpolate_2D(points, matrix),
interpolate_plane(
points,
matrix,
align_corners=True,
padding_mode="zeros",
mode="bilinear",
),
)
with self.subTest("3D interpolation"):
points = self.get_random_normalized_points(
n_grids=4, n_points=5, dimension=3
)
tensor = torch.randn(size=(4, 5, 2, 7, 2))
assert torch.allclose(
self._interpolate_3D(points, tensor),
interpolate_volume(
points,
tensor,
align_corners=True,
padding_mode="zeros",
mode="bilinear",
),
)
def test_floating_point_query(self):
"""
test querying the voxel grids on some float positions
"""
with self.subTest("FullResolution"):
grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
params = grid.values_type(**grid.get_shapes())
params.voxel_grid = torch.tensor(
[
[
[[[1, 3], [5, 7]], [[9, 11], [13, 15]]],
[[[2, 4], [6, 8]], [[10, 12], [14, 16]]],
],
[
[[[17, 18], [19, 20]], [[21, 22], [23, 24]]],
[[[25, 26], [27, 28]], [[29, 30], [31, 32]]],
],
],
dtype=torch.float,
)
points = (
torch.tensor(
[
[
[1, 0, 1],
[0.5, 1, 1],
[1 / 3, 1 / 3, 2 / 3],
],
[
[0, 1, 1],
[0, 0.5, 1],
[1 / 4, 1 / 4, 3 / 4],
],
]
)
/ torch.tensor([[1.0, 1, 1]])
* 2
- 1
)
expected_result = torch.tensor(
[
[[11, 12], [11, 12], [6.333333, 7.3333333]],
[[20, 28], [19, 27], [19.25, 27.25]],
]
)
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result,
rtol=0.00001,
), grid.evaluate_local(points, params)
with self.subTest("CP"):
grid = CPFactorizedVoxelGrid(
n_features=1, resolution=(1, 1, 1), n_components=3
)
params = grid.values_type(**grid.get_shapes())
params.vector_components_x = torch.tensor(
[
[[1, 2], [10.5, 20.5]],
[[10, 20], [2, 4]],
]
)
params.vector_components_y = torch.tensor(
[
[[3, 4, 5], [30.5, 40.5, 50.5]],
[[30, 40, 50], [1, 3, 5]],
]
)
params.vector_components_z = torch.tensor(
[
[[6, 7, 8, 9], [60.5, 70.5, 80.5, 90.5]],
[[60, 70, 80, 90], [6, 7, 8, 9]],
]
)
params.basis_matrix = torch.tensor(
[
[[2.0], [2.0]],
[[1.0], [2.0]],
]
)
points = (
torch.tensor(
[
[
[0, 2, 2],
[1, 2, 0.25],
[0.5, 0.5, 1],
[1 / 3, 2 / 3, 2 + 1 / 3],
],
[
[1, 0, 1],
[0.5, 2, 2],
[1, 0.5, 0.5],
[1 / 4, 3 / 4, 2 + 1 / 4],
],
]
)
/ torch.tensor([[[1.0, 2, 3]]])
* 2
- 1
)
expected_result_matrix = torch.tensor(
[
[[85450.25], [130566.5], [77658.75], [86285.422]],
[[42056], [60240], [45604], [38775]],
]
)
expected_result_sum = torch.tensor(
[
[[42725.125], [65283.25], [38829.375], [43142.711]],
[[42028], [60120], [45552], [38723.4375]],
]
)
with self.subTest("CP with basis_matrix reduction"):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_matrix,
rtol=0.00001,
)
del params.basis_matrix
with self.subTest("CP with sum reduction"):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_sum,
rtol=0.00001,
)
with self.subTest("VM"):
grid = VMFactorizedVoxelGrid(
n_features=1, resolution=(1, 1, 1), n_components=3
)
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
params.matrix_components_xy = torch.tensor(
[
[[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
[[[35, 36], [37, 38]], [[39, 40], [41, 42]]],
]
)
params.matrix_components_xz = torch.tensor(
[
[[[7, 8], [9, 10]], [[25, 26], [27, 28.0]]],
[[[43, 44], [45, 46]], [[47, 48], [49, 50]]],
]
)
params.matrix_components_yz = torch.tensor(
[
[[[13, 14], [15, 16]], [[31, 32], [33, 34.0]]],
[[[51, 52], [53, 54]], [[55, 56], [57, 58.0]]],
]
)
params.vector_components_z = torch.tensor(
[
[[5, 6], [23, 24.0]],
[[59, 60], [61, 62]],
]
)
params.vector_components_y = torch.tensor(
[
[[11, 12], [29, 30.0]],
[[63, 64], [65, 66]],
]
)
params.vector_components_x = torch.tensor(
[
[[17, 18], [35, 36.0]],
[[67, 68], [69, 70.0]],
]
)
params.basis_matrix = torch.tensor(
[
[2, 2, 2, 2, 2, 2.0],
[1, 2, 1, 2, 1, 2.0],
]
)[:, :, None]
points = (
torch.tensor(
[
[
[1, 0, 1],
[0.5, 1, 1],
[1 / 3, 1 / 3, 2 / 3],
],
[
[0, 1, 0],
[0, 0, 0],
[0, 1, 0],
],
]
)
/ torch.tensor([[[1.0, 1, 1]]])
* 2
- 1
)
expected_result_matrix = torch.tensor(
[
[[5696], [5854], [5484.888]],
[[27377], [26649], [27377]],
]
)
expected_result_sum = torch.tensor(
[
[[2848], [2927], [2742.444]],
[[17902], [17420], [17902]],
]
)
with self.subTest("VM with basis_matrix reduction"):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_matrix,
rtol=0.00001,
)
del params.basis_matrix
with self.subTest("VM with sum reduction"):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_sum,
rtol=0.0001,
), grid.evaluate_local(points, params)
def test_forward_with_small_init_std(self):
"""
Test does the grid return small values if it is initialized with small
mean and small standard deviation.
"""
def test(cls, **kwargs):
with self.subTest(cls.__name__):
n_grids = 3
grid = cls(**kwargs)
shapes = grid.get_shapes()
params = cls.values_type(
**{
k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
for k, shape in shapes.items()
}
)
points = self.get_random_normalized_points(n_grids=n_grids, n_points=3)
max_expected_result = torch.zeros((len(points), 10)) + 1e-2
assert torch.all(
grid.evaluate_local(points, params) < max_expected_result
)
test(
FullResolutionVoxelGrid,
resolution=(4, 6, 9),
n_features=10,
)
test(
CPFactorizedVoxelGrid,
resolution=(4, 6, 9),
n_features=10,
n_components=3,
)
test(
VMFactorizedVoxelGrid,
resolution=(4, 6, 9),
n_features=10,
n_components=3,
)