import torch
import torch.nn.functional as F


def _grid_sample_kwargs(kwargs: dict) -> dict:
    """
    Translates the interpolation options into ones understood by
    torch.nn.functional.grid_sample.

    grid_sample has no "trilinear" mode: its "bilinear" mode interpolates
    linearly along every spatial axis of the input (and is documented to be
    trilinear for 5-D inputs), so "trilinear" is mapped onto it. Every other
    option is passed through untouched.

    Arguments:
        kwargs: keyword arguments intended for grid_sample.
    Returns:
        the keyword arguments to pass to grid_sample; a new dictionary is
        created whenever a value has to change so the caller's one is not mutated.
    """
    if kwargs.get("mode") == "trilinear":
        return {**kwargs, "mode": "bilinear"}
    return kwargs


def interpolate_line(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Linearly interpolates values of source grids. The first dimension of points
    indexes the grid, the second one the point and the last one holds the single
    coordinate, for example [[[x0], [x1], ...]]. The first dimension of source
    indexes the grid, the second one the feature and the last one is the spatial
    dimension.

    Arguments:
        points: shape (n_grids, n_points, 1),
        source: tensor of shape (n_grids, features, width),
        **kwargs: forwarded to torch.nn.functional.grid_sample (align_corners,
            padding_mode, mode). mode="trilinear" is accepted and treated as
            linear interpolation.
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    kwargs = _grid_sample_kwargs(kwargs)
    # grid_sample needs at least two spatial coordinates, so a constant second
    # coordinate of 0 is appended and the source gets a height of 1.
    grid = torch.cat((points, torch.zeros_like(points)), dim=-1)[:, :, None, :]

    out = F.grid_sample(
        input=source[:, :, None, :],
        grid=grid,
        **kwargs,
    )
    # (n_grids, features, n_points, 1) -> (n_grids, n_points, features)
    return out[:, :, :, 0].permute(0, 2, 1)


def interpolate_plane(
    points: torch.Tensor,
    source: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Bilinearly interpolates values of source grids. The first dimension of points
    indexes the grid, the second one the point and the last one holds the
    coordinates, for example [[[x0, y0], [x1, y1], ...]]. The first dimension of
    source indexes the grid, the second one the feature and the ones after that
    are the spatial dimensions.

    Arguments:
        points: shape (n_grids, n_points, 2),
        source: tensor of shape (n_grids, features, width, height),
        **kwargs: forwarded to torch.nn.functional.grid_sample (align_corners,
            padding_mode, mode). mode="trilinear" is accepted and treated as
            bilinear interpolation.
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    kwargs = _grid_sample_kwargs(kwargs)
    # permuting because torch.nn.functional.grid_sample works with
    # (features, height, width) and not (features, width, height)
    out = F.grid_sample(
        input=source.permute(0, 1, 3, 2),
        grid=points[:, :, None, :],
        **kwargs,
    )
    # (n_grids, features, n_points, 1) -> (n_grids, n_points, features)
    return out[:, :, :, 0].permute(0, 2, 1)


def interpolate_volume(
    points: torch.Tensor, source: torch.Tensor, **kwargs
) -> torch.Tensor:
    """
    Interpolates values of source grids. The first dimension of points indexes
    the grid, the second one the point and the last one holds the coordinates,
    for example [[[x0, y0, z0], [x1, y1, z1], ...]]. The first dimension of source
    indexes the grid, the second one the feature and the ones after that are the
    spatial dimensions.

    Arguments:
        points: shape (n_grids, n_points, 3),
        source: tensor of shape (n_grids, features, width, height, depth),
        **kwargs: forwarded to torch.nn.functional.grid_sample (align_corners,
            padding_mode, mode). mode="trilinear" is translated to the
            "bilinear" mode of grid_sample, which is trilinear for 5-D inputs.
    Returns:
        interpolated tensor of shape (n_grids, n_points, features)
    """
    kwargs = _grid_sample_kwargs(kwargs)
    # permuting because torch.nn.functional.grid_sample works with
    # (features, depth, height, width) and not (features, width, height, depth)
    out = F.grid_sample(
        input=source.permute(0, 1, 4, 3, 2),
        grid=points[:, :, None, None, :],
        **kwargs,
    )
    # (n_grids, features, n_points, 1, 1) -> (n_grids, n_points, features)
    return out[:, :, :, 0, 0].permute(0, 2, 1)
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, Optional, Tuple, Type

import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.structures.volumes import VolumeLocator

from .utils import interpolate_line, interpolate_plane, interpolate_volume


@dataclass
class VoxelGridValuesBase:
    """
    Base class of the containers holding the tensors a voxel grid is evaluated on.
    """

    pass


class VoxelGridBase(ReplaceableBase, torch.nn.Module):
    """
    Base class for all the voxel grid variants with added trilinear interpolation between
    voxels (for example if voxel (0.333, 1, 3) is queried that would return the result
    2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3])

    Internally voxel grids are indexed by (features, x, y, z). If the queried point is not
    inside the voxel grid the vector that will be returned is determined by padding.

    Members:
        align_corners: parameter used in torch.functional.grid_sample. For details go to
            https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.
            Default is True
        padding: padding mode for outside grid values 'zeros' | 'border' | 'reflection'.
            Default is 'zeros'
        mode: interpolation mode to calculate output values :
            'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
            Default: 'bilinear'. 'trilinear' is a synonym of 'bilinear': in both cases the
            interpolation is linear along every sampled axis, which for the 5-D input of
            FullResolutionVoxelGrid is trilinear interpolation.
            Note: torch.nn.functional.grid_sample implements 'bicubic' only for 4-D
            inputs, i.e. for the vectors and matrices sampled by the factorized grids and
            not for the 5-D volume of FullResolutionVoxelGrid.
        n_features: number of dimensions of base feature vector. Determines how many features
            the grid returns.
        resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
    """

    align_corners: bool = True
    padding: str = "zeros"
    mode: str = "bilinear"
    n_features: int = 1
    resolution: Tuple[int, int, int] = (64, 64, 64)

    def __post_init__(self):
        super().__init__()

    def _grid_sample_kwargs(self) -> Dict[str, Any]:
        """
        Returns the keyword arguments shared by all the interpolation calls of this grid.

        grid_sample does not know a 'trilinear' mode, its 'bilinear' mode is the one which
        interpolates linearly along all the sampled axes, hence the translation.
        """
        mode = "bilinear" if self.mode == "trilinear" else self.mode
        return {
            "align_corners": self.align_corners,
            "padding_mode": self.padding,
            "mode": mode,
        }

    def evaluate_world(
        self,
        points: torch.Tensor,
        grid_values: VoxelGridValuesBase,
        locator: VolumeLocator,
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the world coordinate frame.
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3)
            grid_values: an object of type Class.values_type which has tensors as
                members which have shapes derived from the get_shapes() method
            locator: a VolumeLocator object
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        points_local = locator.world_to_local_coords(points)
        # pyre-ignore[29]
        return self.evaluate_local(points_local, grid_values)

    def evaluate_local(
        self, points: torch.Tensor, grid_values: VoxelGridValuesBase
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame,
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form (coordinates
                are in [-1, 1])
            grid_values: an object of type Class.values_type which has tensors
                as members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        raise NotImplementedError()

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Using parameters from the __init__ method, this method returns the
        shapes of individual tensors needed to run the evaluate method.

        Returns:
            a dictionary of needed shapes. To use the evaluate_local and evaluate_world
            methods replace the shapes in the dictionary with tensors of those shapes and
            add the first 'batch' dimension. If the required shape is (a, b) and you want
            to have g grids then the tensor that replaces the shape should have the
            shape (g, a, b).
        """
        raise NotImplementedError()


@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
    voxel_grid: torch.Tensor


@registry.register
class FullResolutionVoxelGrid(VoxelGridBase):
    """
    Full resolution voxel grid equivalent to 4D tensor where shape is
    (features, width, height, depth) with linear interpolation between voxels.
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = FullResolutionVoxelGridValues

    def evaluate_local(
        self, points: torch.Tensor, grid_values: FullResolutionVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the voxel grid at points in the local coordinate frame,
        The interpolation type is determined by the `mode` member.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form (coordinates
                are in [-1, 1])
            grid_values: an object of type values_type which has tensors as
                members which have shapes derived from the get_shapes() method
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        return interpolate_volume(
            points, grid_values.voxel_grid, **self._grid_sample_kwargs()
        )

    def get_shapes(self) -> Dict[str, Tuple]:
        return {"voxel_grid": (self.n_features, *self.resolution)}


@dataclass
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
    vector_components_x: torch.Tensor
    vector_components_y: torch.Tensor
    vector_components_z: torch.Tensor
    basis_matrix: Optional[torch.Tensor] = None


@registry.register
class CPFactorizedVoxelGrid(VoxelGridBase):
    """
    Canonical Polyadic (CP/CANDECOMP/PARAFAC) Factorization factorizes the 3d grid into three
    vectors (x, y, z). For n_components=n, the 3d grid is a sum of the two outer products
    (call it ⊗) of each vector type (x, y, z):

    3d_grid = x0 ⊗ y0 ⊗ z0 + x1 ⊗ y1 ⊗ z1 + ... + xn ⊗ yn ⊗ zn

    These tensors are passed in an object of CPFactorizedVoxelGridValues (here obj) as
    obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their shapes
    are `(n_grids, n_components, r)` where `r` is the relevant resolution.

    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
    brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
    of different components are summed together to create (n_grids, n_points, 1) tensor.
    With some notation abuse, ignoring the interpolation operation, simplifying and denoting
    n_features as F, n_components as C and n_grids as G:

    3d_grid = (x ⊗ y ⊗ z) @ basis # GWHDC x GCF -> GWHDF

    The basis feature vectors are passed as obj.basis_matrix.

    Members:
        n_components: number of vector triplets, higher number gives better approximation.
        matrix_reduction: how to transform components. If matrix_reduction=True result
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
            basis_matrix of shape (n_grids, n_components, n_features). If
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
            is summed along the rows to get (n_grids, n_points, 1).
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues

    n_components: int = 24
    matrix_reduction: bool = True

    def evaluate_local(
        self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the factorized grid at points in the local coordinate frame.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form (coordinates
                are in [-1, 1])
            grid_values: the vectors (and optionally the basis matrix) of the grids.
                The sum reduction is used when grid_values.basis_matrix is None.
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        sampler_kwargs = self._grid_sample_kwargs()

        # one (n_grids, n_points, n_components) tensor per axis
        factors = [
            interpolate_line(
                points[..., i, None],
                getattr(grid_values, "vector_components_" + axis),
                **sampler_kwargs,
            )
            for i, axis in enumerate("xyz")
        ]
        # collect points from all the vectors and multiply them out
        mult = factors[0] * factors[1] * factors[2]

        # reduce the result from
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
        if grid_values.basis_matrix is not None:
            # (n_grids, n_points, n_features) =
            # (n_grids, n_points, n_components) x (n_grids, n_components, n_features)
            return torch.bmm(mult, grid_values.basis_matrix)

        return mult.sum(dim=-1, keepdim=True)

    def get_shapes(self) -> Dict[str, Tuple[int, int]]:
        """
        Returns the shapes (without the leading n_grids dimension) of the tensors
        of CPFactorizedVoxelGridValues.

        Raises:
            ValueError: if matrix_reduction=False is combined with n_features != 1.
        """
        if not self.matrix_reduction and self.n_features != 1:
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")

        shape_dict = {
            "vector_components_" + axis: (self.n_components, axis_resolution)
            for axis, axis_resolution in zip("xyz", self.resolution)
        }
        if self.matrix_reduction:
            shape_dict["basis_matrix"] = (self.n_components, self.n_features)
        return shape_dict


@dataclass
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
    vector_components_x: torch.Tensor
    vector_components_y: torch.Tensor
    vector_components_z: torch.Tensor
    matrix_components_xy: torch.Tensor
    matrix_components_yz: torch.Tensor
    matrix_components_xz: torch.Tensor
    basis_matrix: Optional[torch.Tensor] = None


@registry.register
class VMFactorizedVoxelGrid(VoxelGridBase):
    """
    Implementation of Vector-Matrix Factorization of a tensor from
    https://arxiv.org/abs/2203.09517.

    Vector-Matrix Factorization factorizes the 3d grid into three matrices
    (xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid
    is a sum of the outer products (call it ⊗) of each matrix with its
    complementary vector:

    3d_grid = xy ⊗ z + xz ⊗ y + yz ⊗ x.

    These tensors are passed in a VMFactorizedVoxelGridValues object (here obj)
    as obj.matrix_components_xy, obj.matrix_components_xz, obj.vector_components_y, etc.

    Their shapes are `(n_grids, n_components, r0, r1)` for matrix_components and
    `(n_grids, n_components, r2)` for vector_components. Each of `r0, r1 and r2` corresponds
    to one resolution in (width, height and depth).

    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
    brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
    of different components are summed together to create (n_grids, n_points, 1) tensor.
    With some notation abuse, ignoring the interpolation operation, simplifying and denoting
    n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:

    3d_grid = concat((xy ⊗ z), (xz ⊗ y).permute(0, 2, 1),
                    (yz ⊗ x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF

    Members:
        n_components: total number of matrix vector pairs, this must be divisible by 3. Set
            this if you want to have equal representational power in all 3 directions. You
            must specify either n_components or distribution_of_components, you cannot
            specify both.
        distribution_of_components: if you do not want equal representational power in
            all 3 directions specify a tuple of numbers of matrix_vector pairs for each
            coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
            either n_components or distribution_of_components, you cannot specify both.
        matrix_reduction: how to transform components. If matrix_reduction=True result
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
            the basis_matrix of shape (n_grids, n_components, n_features). If
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
            is summed along the rows to get (n_grids, n_points, 1).
    """

    # the type of grid_values argument needed to run evaluate_local()
    values_type: ClassVar[Type[VoxelGridValuesBase]] = VMFactorizedVoxelGridValues

    n_components: Optional[int] = None
    distribution_of_components: Optional[Tuple[int, int, int]] = None
    matrix_reduction: bool = True

    def evaluate_local(
        self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
    ) -> torch.Tensor:
        """
        Evaluates the factorized grid at points in the local coordinate frame.

        Arguments:
            points (torch.Tensor): tensor of points that you want to query
                of a form (n_grids, n_points, 3), in a normalized form (coordinates
                are in [-1, 1])
            grid_values: the matrices and vectors (and optionally the basis matrix)
                of the grids. The sum reduction is used when
                grid_values.basis_matrix is None.
        Returns:
            torch.Tensor: shape (n_grids, n_points, n_features)
        """
        sampler_kwargs = self._grid_sample_kwargs()

        def matrix_times_vector(matrix_axes, matrices, vector_axis, vectors):
            # samples a matrix at two of the coordinates and its complementary
            # vector at the remaining one, and multiplies the two samples
            return interpolate_plane(
                points[..., matrix_axes], matrices, **sampler_kwargs
            ) * interpolate_line(points[..., vector_axis], vectors, **sampler_kwargs)

        # collect points from matrices and vectors and multiply them
        a = matrix_times_vector(
            [0, 1],
            grid_values.matrix_components_xy,
            [2],
            grid_values.vector_components_z,
        )
        b = matrix_times_vector(
            [0, 2],
            grid_values.matrix_components_xz,
            [1],
            grid_values.vector_components_y,
        )
        c = matrix_times_vector(
            [1, 2],
            grid_values.matrix_components_yz,
            [0],
            grid_values.vector_components_x,
        )
        # the order (xy, xz, yz) is the order of the rows of the basis matrix
        feats = torch.cat((a, b, c), dim=-1)

        # reduce the result from
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
        if grid_values.basis_matrix is not None:
            # (n_grids, n_points, n_features) =
            # (n_grids, n_points, total_n_components) x
            # (n_grids, total_n_components, n_features)
            return torch.bmm(feats, grid_values.basis_matrix)
        return feats.sum(dim=-1, keepdim=True)

    def get_shapes(self) -> Dict[str, Tuple]:
        """
        Returns the shapes (without the leading n_grids dimension) of the tensors
        of VMFactorizedVoxelGridValues.

        Raises:
            ValueError: if matrix_reduction=False is combined with n_features != 1,
                if neither or both of n_components and distribution_of_components
                are set, or if n_components is not divisible by 3.
        """
        if not self.matrix_reduction and self.n_features != 1:
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")

        if self.distribution_of_components is None:
            if self.n_components is None:
                raise ValueError(
                    "You need to provide n_components or distribution_of_components"
                )
            if self.n_components % 3 != 0:
                raise ValueError("n_components must be divisible by 3")
            n_xy = n_yz = n_xz = self.n_components // 3
        else:
            if self.n_components is not None:
                raise ValueError(
                    "You cannot define n_components and distribution_of_components"
                )
            n_xy, n_yz, n_xz = self.distribution_of_components

        width, height, depth = self.resolution
        # every vector has as many components as the matrix it is multiplied with
        shape_dict = {
            "vector_components_x": (n_yz, width),
            "vector_components_y": (n_xz, height),
            "vector_components_z": (n_xy, depth),
            "matrix_components_xy": (n_xy, width, height),
            "matrix_components_yz": (n_yz, height, depth),
            "matrix_components_xz": (n_xz, width, depth),
        }
        if self.matrix_reduction:
            shape_dict["basis_matrix"] = (n_xy + n_yz + n_xz, self.n_features)

        return shape_dict
import unittest
from typing import Optional, Tuple

import torch

from pytorch3d.implicitron.models.implicit_function.utils import (
    interpolate_line,
    interpolate_plane,
    interpolate_volume,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
    CPFactorizedVoxelGrid,
    FullResolutionVoxelGrid,
    VMFactorizedVoxelGrid,
)

from pytorch3d.implicitron.tools.config import expand_args_fields
from tests.common_testing import TestCaseMixin


class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
    """
    Tests Voxel grids, tests them by setting all elements to zero (after retrieving
    they should also return zero) and by setting all of the elements to one and
    getting the result. Also tests the interpolation by 'manually' interpolating
    one by one sample and comparing with the batched implementation.
    """

    def setUp(self):
        torch.manual_seed(42)
        expand_args_fields(FullResolutionVoxelGrid)
        expand_args_fields(CPFactorizedVoxelGrid)
        expand_args_fields(VMFactorizedVoxelGrid)

    def get_random_normalized_points(
        self, n_grids, n_points, dimension=3
    ) -> torch.Tensor:
        # create random query points with coordinates in [-1, 1)
        return torch.rand(n_grids, n_points, dimension) * 2 - 1

    def _assert_constant_query(
        self,
        grid,
        n_grids: int,
        n_points: int,
        n_features: int,
        value: float,
        expected_element: float,
    ) -> None:
        # fill every tensor the grid needs with 'value', query random points and
        # expect a (n_grids, n_points, n_features) result filled with
        # 'expected_element'
        params = grid.values_type(
            **{
                name: torch.ones(n_grids, *shape) * value
                for name, shape in grid.get_shapes().items()
            }
        )
        points = self.get_random_normalized_points(n_grids, n_points)
        self.assertTrue(
            torch.allclose(
                grid.evaluate_local(points, params),
                torch.ones(n_grids, n_points, n_features) * expected_element,
            )
        )

    def _test_query_with_constant_init_cp(
        self,
        n_grids: int,
        n_features: int,
        n_components: int,
        resolution: Tuple[int, int, int],
        value: float = 1,
        n_points: int = 1,
    ) -> None:
        # set everything to 'value' and do query for elements, the result should
        # be of shape (n_grids, n_points, n_features) and be filled with
        # n_components * value. NOTE: 'value' multiplies itself in the factorized
        # grid, so this expectation only holds for value 0 or 1.
        grid = CPFactorizedVoxelGrid(
            resolution=resolution,
            n_components=n_components,
            n_features=n_features,
        )
        self._assert_constant_query(
            grid, n_grids, n_points, n_features, value, n_components * value
        )

    def _test_query_with_constant_init_vm(
        self,
        n_grids: int,
        n_features: int,
        resolution: Tuple[int, int, int],
        n_components: Optional[int] = None,
        distribution: Optional[Tuple[int, int, int]] = None,
        value: float = 1,
        n_points: int = 1,
    ) -> None:
        # set everything to 'value' and do query for elements; as for the CP grid
        # the expectation only holds for value 0 or 1
        grid = VMFactorizedVoxelGrid(
            n_features=n_features,
            resolution=resolution,
            n_components=n_components,
            distribution_of_components=distribution,
        )
        total_components = n_components if distribution is None else sum(distribution)
        self._assert_constant_query(
            grid, n_grids, n_points, n_features, value, total_components * value
        )

    def _test_query_with_constant_init_full(
        self,
        n_grids: int,
        n_features: int,
        resolution: Tuple[int, int, int],
        value: float = 1,
        n_points: int = 1,
    ) -> None:
        # set everything to 'value' and do query for elements
        grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
        self._assert_constant_query(grid, n_grids, n_points, n_features, value, value)

    def test_query_with_constant_init(self):
        with self.subTest("Full"):
            self._test_query_with_constant_init_full(
                n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
            )
        with self.subTest("Full with 1 in dimensions"):
            self._test_query_with_constant_init_full(
                n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
            )
        with self.subTest("CP"):
            self._test_query_with_constant_init_cp(
                n_grids=5,
                n_features=6,
                n_components=7,
                resolution=(3, 4, 5),
                n_points=2,
            )
        with self.subTest("CP with 1 in dimensions"):
            self._test_query_with_constant_init_cp(
                n_grids=2,
                n_features=1,
                n_components=3,
                resolution=(3, 1, 1),
                n_points=4,
            )
        with self.subTest("VM with symetric distribution"):
            self._test_query_with_constant_init_vm(
                n_grids=6,
                n_features=9,
                resolution=(2, 12, 2),
                n_components=12,
                n_points=3,
            )
        with self.subTest("VM with distribution"):
            self._test_query_with_constant_init_vm(
                n_grids=5,
                n_features=1,
                resolution=(5, 9, 7),
                distribution=(33, 41, 1),
                n_points=7,
            )

    def test_query_with_zero_init(self):
        with self.subTest("Query testing with zero init CPFactorizedVoxelGrid"):
            self._test_query_with_constant_init_cp(
                n_grids=5,
                n_features=6,
                n_components=7,
                resolution=(3, 2, 5),
                n_points=3,
                value=0,
            )
        with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
            self._test_query_with_constant_init_vm(
                n_grids=2,
                n_features=9,
                resolution=(2, 11, 3),
                n_components=3,
                n_points=3,
                value=0,
            )
        with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
            self._test_query_with_constant_init_full(
                n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
            )

    @staticmethod
    def _neighbours(coordinate: torch.Tensor, size: int):
        """
        For a coordinate in [0, size - 1] returns the indices of the two samples
        surrounding it and the interpolation weight of the upper one. The upper
        index is clamped so that a coordinate lying exactly on a sample (the
        last one included) gets the whole weight of that sample, instead of the
        floor == ceil case in which both weights would be zero.
        """
        lower = int(torch.floor(coordinate))
        upper = min(lower + 1, size - 1)
        return lower, upper, coordinate - lower

    def _interpolate_1D(
        self, points: torch.Tensor, vectors: torch.Tensor
    ) -> torch.Tensor:
        """
        interpolate vectors of shape (n_grids, features, width) at points of shape
        (n_grids, n_points, 1), an individual coordinate is in [-1, 1]
        """
        result = []
        _, _, width = vectors.shape
        # transform from [-1, 1] to [0, width-1]
        points = (points + 1) / 2 * (width - 1)
        for vector, row in zip(vectors, points):
            newrow = []
            for x in row:
                xf, xc, weight = self._neighbours(x, width)
                tmp = vector[:, xf] * (1 - weight) + vector[:, xc] * weight
                newrow.append(tmp[None, None, :])
            result.append(torch.cat(newrow, dim=1))
        return torch.cat(result)

    def _interpolate_2D(
        self, points: torch.Tensor, matrices: torch.Tensor
    ) -> torch.Tensor:
        """
        interpolate matrices of shape (n_grids, features, width, height) at points of
        shape (n_grids, n_points, 2), an individual coordinate is in [-1, 1]
        """
        result = []
        _, _, width, height = matrices.shape
        # transform from [-1, 1] to [0, width-1] and [0, height-1]
        points = (points + 1) / 2 * (torch.tensor([[[width, height]]]) - 1)
        for matrix, row in zip(matrices, points):
            newrow = []
            for x, y in row:
                xf, xc, weight_x = self._neighbours(x, width)
                yf, yc, weight_y = self._neighbours(y, height)
                # interpolate along x on the two neighbouring rows, then along y
                itemf = matrix[:, xf, yf] * (1 - weight_x) + matrix[:, xc, yf] * weight_x
                itemc = matrix[:, xf, yc] * (1 - weight_x) + matrix[:, xc, yc] * weight_x
                tmp = itemf * (1 - weight_y) + itemc * weight_y
                newrow.append(tmp[None, None, :])
            result.append(torch.cat(newrow, dim=1))
        return torch.cat(result)

    def _interpolate_3D(
        self, points: torch.Tensor, tensors: torch.Tensor
    ) -> torch.Tensor:
        """
        interpolate tensors of shape (n_grids, features, width, height, depth) at
        points of shape (n_grids, n_points, 3), an individual coordinate is in [-1, 1]
        """
        result = []
        _, _, _, _, depth = tensors.shape
        # z coordinates transformed from [-1, 1] to [0, depth-1]
        batch_z = (points[..., 2] + 1) / 2 * (depth - 1)

        for tensor, grid_points, grid_z in zip(tensors, points, batch_z):
            newrow = []
            for (x, y, _), nz in zip(grid_points, grid_z):
                zf, zc, weight_z = self._neighbours(nz, depth)
                xy = torch.stack((x, y)).reshape(1, 1, 2)
                # interpolate inside the two neighbouring xy slices...
                itemf = self._interpolate_2D(
                    points=xy, matrices=tensor[None, :, :, :, zf]
                )
                itemc = self._interpolate_2D(
                    points=xy, matrices=tensor[None, :, :, :, zc]
                )
                # ...and then between them; the position between the two slices
                # has to be sent back to [-1, 1] for _interpolate_1D
                tmp = self._interpolate_1D(
                    points=(2 * weight_z - 1).reshape(1, 1, 1),
                    vectors=torch.cat((itemf, itemc), dim=1).permute(0, 2, 1),
                )
                newrow.append(tmp)
            result.append(torch.cat(newrow, dim=1))
        return torch.cat(result)

    def test_interpolation(self):
        sampler_kwargs = {
            "align_corners": True,
            "padding_mode": "zeros",
            "mode": "bilinear",
        }

        with self.subTest("1D interpolation"):
            points = self.get_random_normalized_points(
                n_grids=4, n_points=5, dimension=1
            )
            vector = torch.randn(size=(4, 3, 2))
            self.assertTrue(
                torch.allclose(
                    self._interpolate_1D(points, vector),
                    interpolate_line(points, vector, **sampler_kwargs),
                )
            )
        with self.subTest("2D interpolation"):
            points = self.get_random_normalized_points(
                n_grids=4, n_points=5, dimension=2
            )
            matrix = torch.randn(size=(4, 2, 3, 5))
            self.assertTrue(
                torch.allclose(
                    self._interpolate_2D(points, matrix),
                    interpolate_plane(points, matrix, **sampler_kwargs),
                )
            )

        with self.subTest("3D interpolation"):
            points = self.get_random_normalized_points(
                n_grids=4, n_points=5, dimension=3
            )
            tensor = torch.randn(size=(4, 5, 2, 7, 2))
            self.assertTrue(
                torch.allclose(
                    self._interpolate_3D(points, tensor),
                    interpolate_volume(points, tensor, **sampler_kwargs),
                )
            )

    def test_floating_point_query(self):
        """
        test querying the voxel grids on some float positions
        """
        with self.subTest("FullResolution"):
            grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
            params = grid.values_type(**grid.get_shapes())
            params.voxel_grid = torch.tensor(
                [
                    [
                        [[[1, 3], [5, 7]], [[9, 11], [13, 15]]],
                        [[[2, 4], [6, 8]], [[10, 12], [14, 16]]],
                    ],
                    [
                        [[[17, 18], [19, 20]], [[21, 22], [23, 24]]],
                        [[[25, 26], [27, 28]], [[29, 30], [31, 32]]],
                    ],
                ],
                dtype=torch.float,
            )
            # points are given in voxel units and normalized to [-1, 1]
            points = (
                torch.tensor(
                    [
                        [
                            [1, 0, 1],
                            [0.5, 1, 1],
                            [1 / 3, 1 / 3, 2 / 3],
                        ],
                        [
                            [0, 1, 1],
                            [0, 0.5, 1],
                            [1 / 4, 1 / 4, 3 / 4],
                        ],
                    ]
                )
                / torch.tensor([[1.0, 1, 1]])
                * 2
                - 1
            )
            expected_result = torch.tensor(
                [
                    [[11, 12], [11, 12], [6.333333, 7.3333333]],
                    [[20, 28], [19, 27], [19.25, 27.25]],
                ]
            )

            result = grid.evaluate_local(points, params)
            self.assertTrue(
                torch.allclose(result, expected_result, rtol=0.00001), result
            )
        with self.subTest("CP"):
            grid = CPFactorizedVoxelGrid(
                n_features=1, resolution=(1, 1, 1), n_components=3
            )
            params = grid.values_type(**grid.get_shapes())
            params.vector_components_x = torch.tensor(
                [
                    [[1, 2], [10.5, 20.5]],
                    [[10, 20], [2, 4]],
                ]
            )
            params.vector_components_y = torch.tensor(
                [
                    [[3, 4, 5], [30.5, 40.5, 50.5]],
                    [[30, 40, 50], [1, 3, 5]],
                ]
            )
            params.vector_components_z = torch.tensor(
                [
                    [[6, 7, 8, 9], [60.5, 70.5, 80.5, 90.5]],
                    [[60, 70, 80, 90], [6, 7, 8, 9]],
                ]
            )
            params.basis_matrix = torch.tensor(
                [
                    [[2.0], [2.0]],
                    [[1.0], [2.0]],
                ]
            )
            points = (
                torch.tensor(
                    [
                        [
                            [0, 2, 2],
                            [1, 2, 0.25],
                            [0.5, 0.5, 1],
                            [1 / 3, 2 / 3, 2 + 1 / 3],
                        ],
                        [
                            [1, 0, 1],
                            [0.5, 2, 2],
                            [1, 0.5, 0.5],
                            [1 / 4, 3 / 4, 2 + 1 / 4],
                        ],
                    ]
                )
                / torch.tensor([[[1.0, 2, 3]]])
                * 2
                - 1
            )
            expected_result_matrix = torch.tensor(
                [
                    [[85450.25], [130566.5], [77658.75], [86285.422]],
                    [[42056], [60240], [45604], [38775]],
                ]
            )
            expected_result_sum = torch.tensor(
                [
                    [[42725.125], [65283.25], [38829.375], [43142.711]],
                    [[42028], [60120], [45552], [38723.4375]],
                ]
            )
            with self.subTest("CP with basis_matrix reduction"):
                self.assertTrue(
                    torch.allclose(
                        grid.evaluate_local(points, params),
                        expected_result_matrix,
                        rtol=0.00001,
                    )
                )
            # without a basis matrix the components are summed
            params.basis_matrix = None
            with self.subTest("CP with sum reduction"):
                self.assertTrue(
                    torch.allclose(
                        grid.evaluate_local(points, params),
                        expected_result_sum,
                        rtol=0.00001,
                    )
                )

        with self.subTest("VM"):
            grid = VMFactorizedVoxelGrid(
                n_features=1, resolution=(1, 1, 1), n_components=3
            )
            params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
            params.matrix_components_xy = torch.tensor(
                [
                    [[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
                    [[[35, 36], [37, 38]], [[39, 40], [41, 42]]],
                ]
            )
            params.matrix_components_xz = torch.tensor(
                [
                    [[[7, 8], [9, 10]], [[25, 26], [27, 28.0]]],
                    [[[43, 44], [45, 46]], [[47, 48], [49, 50]]],
                ]
            )
            params.matrix_components_yz = torch.tensor(
                [
                    [[[13, 14], [15, 16]], [[31, 32], [33, 34.0]]],
                    [[[51, 52], [53, 54]], [[55, 56], [57, 58.0]]],
                ]
            )

            params.vector_components_z = torch.tensor(
                [
                    [[5, 6], [23, 24.0]],
                    [[59, 60], [61, 62]],
                ]
            )
            params.vector_components_y = torch.tensor(
                [
                    [[11, 12], [29, 30.0]],
                    [[63, 64], [65, 66]],
                ]
            )
            params.vector_components_x = torch.tensor(
                [
                    [[17, 18], [35, 36.0]],
                    [[67, 68], [69, 70.0]],
                ]
            )

            params.basis_matrix = torch.tensor(
                [
                    [2, 2, 2, 2, 2, 2.0],
                    [1, 2, 1, 2, 1, 2.0],
                ]
            )[:, :, None]
            points = (
                torch.tensor(
                    [
                        [
                            [1, 0, 1],
                            [0.5, 1, 1],
                            [1 / 3, 1 / 3, 2 / 3],
                        ],
                        [
                            [0, 1, 0],
                            [0, 0, 0],
                            [0, 1, 0],
                        ],
                    ]
                )
                / torch.tensor([[[1.0, 1, 1]]])
                * 2
                - 1
            )
            expected_result_matrix = torch.tensor(
                [
                    [[5696], [5854], [5484.888]],
                    [[27377], [26649], [27377]],
                ]
            )
            expected_result_sum = torch.tensor(
                [
                    [[2848], [2927], [2742.444]],
                    [[17902], [17420], [17902]],
                ]
            )
            with self.subTest("VM with basis_matrix reduction"):
                self.assertTrue(
                    torch.allclose(
                        grid.evaluate_local(points, params),
                        expected_result_matrix,
                        rtol=0.00001,
                    )
                )
            # without a basis matrix the components are summed
            params.basis_matrix = None
            with self.subTest("VM with sum reduction"):
                result = grid.evaluate_local(points, params)
                self.assertTrue(
                    torch.allclose(result, expected_result_sum, rtol=0.0001), result
                )

    def test_forward_with_small_init_std(self):
        """
        Test does the grid return small values if it is initialized with small
        mean and small standard deviation.
        """

        def test(cls, **kwargs):
            with self.subTest(cls.__name__):
                n_grids = 3
                grid = cls(**kwargs)
                shapes = grid.get_shapes()
                params = cls.values_type(
                    **{
                        k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
                        for k, shape in shapes.items()
                    }
                )
                points = self.get_random_normalized_points(n_grids=n_grids, n_points=3)
                # the magnitude is what has to be small: large negative values
                # must fail the test as well
                result = grid.evaluate_local(points, params)
                self.assertTrue(torch.all(result.abs() < 1e-2))

        test(
            FullResolutionVoxelGrid,
            resolution=(4, 6, 9),
            n_features=10,
        )
        test(
            CPFactorizedVoxelGrid,
            resolution=(4, 6, 9),
            n_features=10,
            n_components=3,
        )
        test(
            VMFactorizedVoxelGrid,
            resolution=(4, 6, 9),
            n_features=10,
            n_components=3,
        )