mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	voxel grids with interpolation
Summary: Added voxel grid classes from TensoRF, both in their factorized (CP and VM) and full form. Reviewed By: bottler Differential Revision: D38465419 fbshipit-source-id: 8b306338af58dc50ef47a682616022a0512c0047
This commit is contained in:
		
							parent
							
								
									af799facdd
								
							
						
					
					
						commit
						edee25a1e5
					
				@ -7,6 +7,8 @@
 | 
			
		||||
from typing import Callable, Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from pytorch3d.common.compat import prod
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
@ -88,3 +90,98 @@ def create_embeddings_for_implicit_function(
 | 
			
		||||
        embeds = broadcast_global_code(embeds, global_code)
 | 
			
		||||
 | 
			
		||||
    return embeds
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interpolate_line(
 | 
			
		||||
    points: torch.Tensor,
 | 
			
		||||
    source: torch.Tensor,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Linearly interpolates values of source grids. The first dimension of points represents
 | 
			
		||||
    number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
 | 
			
		||||
    dimension of argument source represents feature and ones after that the spatial
 | 
			
		||||
    dimension.
 | 
			
		||||
 | 
			
		||||
    Arguments:
 | 
			
		||||
        points: shape (n_grids, n_points, 1),
 | 
			
		||||
        source: tensor of shape (n_grids, features, width),
 | 
			
		||||
    Returns:
 | 
			
		||||
        interpolated tensor of shape (n_grids, n_points, features)
 | 
			
		||||
    """
 | 
			
		||||
    # To enable sampling of the source using the torch.functional.grid_sample
 | 
			
		||||
    # points need to have 2 coordinates.
 | 
			
		||||
    expansion = points.new_zeros(points.shape)
 | 
			
		||||
    points = torch.cat((points, expansion), dim=-1)
 | 
			
		||||
 | 
			
		||||
    source = source[:, :, None, :]
 | 
			
		||||
    points = points[:, :, None, :]
 | 
			
		||||
 | 
			
		||||
    out = F.grid_sample(
 | 
			
		||||
        grid=points,
 | 
			
		||||
        input=source,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    return out[:, :, :, 0].permute(0, 2, 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interpolate_plane(
 | 
			
		||||
    points: torch.Tensor,
 | 
			
		||||
    source: torch.Tensor,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Bilinearly interpolates values of source grids. The first dimension of points represents
 | 
			
		||||
    number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
 | 
			
		||||
    The first dimension of argument source represents feature and ones after that the
 | 
			
		||||
    spatial dimension.
 | 
			
		||||
 | 
			
		||||
    Arguments:
 | 
			
		||||
        points: shape (n_grids, n_points, 2),
 | 
			
		||||
        source: tensor of shape (n_grids, features, width, height),
 | 
			
		||||
    Returns:
 | 
			
		||||
        interpolated tensor of shape (n_grids, n_points, features)
 | 
			
		||||
    """
 | 
			
		||||
    # permuting because torch.nn.functional.grid_sample works with
 | 
			
		||||
    # (features, height, width) and not
 | 
			
		||||
    # (features, width, height)
 | 
			
		||||
    source = source.permute(0, 1, 3, 2)
 | 
			
		||||
    points = points[:, :, None, :]
 | 
			
		||||
 | 
			
		||||
    out = F.grid_sample(
 | 
			
		||||
        grid=points,
 | 
			
		||||
        input=source,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    return out[:, :, :, 0].permute(0, 2, 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interpolate_volume(
 | 
			
		||||
    points: torch.Tensor, source: torch.Tensor, **kwargs
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Interpolates values of source grids. The first dimension of points represents
 | 
			
		||||
    number of points and the second coordinates, for example
 | 
			
		||||
    [[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
 | 
			
		||||
    and ones after that the spatial dimension.
 | 
			
		||||
 | 
			
		||||
    Arguments:
 | 
			
		||||
        points: shape (n_grids, n_points, 3),
 | 
			
		||||
        source: tensor of shape (n_grids, features, width, height, depth),
 | 
			
		||||
    Returns:
 | 
			
		||||
        interpolated tensor of shape (n_grids, n_points, features)
 | 
			
		||||
    """
 | 
			
		||||
    if "mode" in kwargs and kwargs["mode"] == "trilinear":
 | 
			
		||||
        kwargs = kwargs.copy()
 | 
			
		||||
        kwargs["mode"] = "bilinear"
 | 
			
		||||
    # permuting because torch.nn.functional.grid_sample works with
 | 
			
		||||
    # (features, depth, height, width) and not (features, width, height, depth)
 | 
			
		||||
    source = source.permute(0, 1, 4, 3, 2)
 | 
			
		||||
    grid = points[:, :, None, None, :]
 | 
			
		||||
 | 
			
		||||
    out = F.grid_sample(
 | 
			
		||||
        grid=grid,
 | 
			
		||||
        input=source,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    return out[:, :, :, 0, 0].permute(0, 2, 1)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										428
									
								
								pytorch3d/implicitron/models/implicit_function/voxel_grid.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										428
									
								
								pytorch3d/implicitron/models/implicit_function/voxel_grid.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,428 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This file contains classes that implement Voxel grids, both in their full resolution
 | 
			
		||||
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
 | 
			
		||||
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
 | 
			
		||||
https://arxiv.org/abs/2203.09517.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import ClassVar, Dict, Optional, Tuple, Type
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
			
		||||
from pytorch3d.structures.volumes import VolumeLocator
 | 
			
		||||
 | 
			
		||||
from .utils import interpolate_line, interpolate_plane, interpolate_volume
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class VoxelGridValuesBase:
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Base class for all the voxel grid variants whith added trilinear interpolation between
 | 
			
		||||
    voxels (for example if voxel (0.333, 1, 3) is queried that would return the result
 | 
			
		||||
    2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3])
 | 
			
		||||
 | 
			
		||||
    Internally voxel grids are indexed by (features, x, y, z). If queried the point is not
 | 
			
		||||
    inside the voxel grid the vector that will be returned is determined by padding.
 | 
			
		||||
 | 
			
		||||
    Members:
 | 
			
		||||
        align_corners: parameter used in torch.functional.grid_sample. For details go to
 | 
			
		||||
            https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html by
 | 
			
		||||
            default is True
 | 
			
		||||
        padding: padding mode for outside grid values 'zeros' | 'border' | 'reflection'.
 | 
			
		||||
            Default is 'zeros'
 | 
			
		||||
        mode: interpolation mode to calculate output values :
 | 
			
		||||
            'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
 | 
			
		||||
            Default: 'bilinear' Note: mode='bicubic' supports only FullResolutionVoxelGrid.
 | 
			
		||||
            When mode='bilinear' and the input is 5-D, the interpolation mode used internally
 | 
			
		||||
            will actually be trilinear.
 | 
			
		||||
        n_features: number of dimensions of base feature vector. Determines how many features
 | 
			
		||||
            the grid returns.
 | 
			
		||||
        resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    align_corners: bool = True
 | 
			
		||||
    padding: str = "zeros"
 | 
			
		||||
    mode: str = "bilinear"
 | 
			
		||||
    n_features: int = 1
 | 
			
		||||
    resolution: Tuple[int, int, int] = (64, 64, 64)
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def evaluate_world(
 | 
			
		||||
        self,
 | 
			
		||||
        points: torch.Tensor,
 | 
			
		||||
        grid_values: VoxelGridValuesBase,
 | 
			
		||||
        locator: VolumeLocator,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Evaluates the voxel grid at points in the world coordinate frame.
 | 
			
		||||
        The interpolation type is determined by the `mode` member.
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            points (torch.Tensor): tensor of points that you want to query
 | 
			
		||||
                of a form (n_grids, n_points, 3)
 | 
			
		||||
            grid_values: an object of type Class.values_type which has tensors as
 | 
			
		||||
                members which have shapes derived from the get_shapes() method
 | 
			
		||||
            locator: a VolumeLocator object
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: shape (n_grids, n_points, n_features)
 | 
			
		||||
        """
 | 
			
		||||
        points_local = locator.world_to_local_coords(points)
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        return self.evaluate_local(points_local, grid_values)
 | 
			
		||||
 | 
			
		||||
    def evaluate_local(
 | 
			
		||||
        self, points: torch.Tensor, grid_values: VoxelGridValuesBase
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Evaluates the voxel grid at points in the local coordinate frame,
 | 
			
		||||
        The interpolation type is determined by the `mode` member.
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            points (torch.Tensor): tensor of points that you want to query
 | 
			
		||||
                of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
 | 
			
		||||
            grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
 | 
			
		||||
                as members which have shapes derived from the get_shapes() method
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: shape (n_grids, n_points, n_features)
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> Dict[str, Tuple]:
 | 
			
		||||
        """
 | 
			
		||||
        Using parameters from the __init__ method, this method returns the
 | 
			
		||||
        shapes of individual tensors needed to run the evaluate method.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
 | 
			
		||||
                replace the shapes in the dictionary with tensors of those shapes and add the
 | 
			
		||||
                first 'batch' dimension. If the required shape is (a, b) and you want to
 | 
			
		||||
                have g grids than the tensor that replaces the shape should have the
 | 
			
		||||
                shape (g, a, b).
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
 | 
			
		||||
    voxel_grid: torch.Tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class FullResolutionVoxelGrid(VoxelGridBase):
 | 
			
		||||
    """
 | 
			
		||||
    Full resolution voxel grid equivalent to 4D tensor where shape is
 | 
			
		||||
    (features, width, height, depth) with linear interpolation between voxels.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # the type of grid_values argument needed to run evaluate_local()
 | 
			
		||||
    values_type: ClassVar[Type[VoxelGridValuesBase]] = FullResolutionVoxelGridValues
 | 
			
		||||
 | 
			
		||||
    def evaluate_local(
 | 
			
		||||
        self, points: torch.Tensor, grid_values: FullResolutionVoxelGridValues
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Evaluates the voxel grid at points in the local coordinate frame,
 | 
			
		||||
        The interpolation type is determined by the `mode` member.
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            points (torch.Tensor): tensor of points that you want to query
 | 
			
		||||
                of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
 | 
			
		||||
            grid_values: an object of type values_type which has tensors as
 | 
			
		||||
                members which have shapes derived from the get_shapes() method
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: shape (n_grids, n_points, n_features)
 | 
			
		||||
        """
 | 
			
		||||
        return interpolate_volume(
 | 
			
		||||
            points,
 | 
			
		||||
            grid_values.voxel_grid,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> Dict[str, Tuple]:
 | 
			
		||||
        return {"voxel_grid": (self.n_features, *self.resolution)}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
 | 
			
		||||
    vector_components_x: torch.Tensor
 | 
			
		||||
    vector_components_y: torch.Tensor
 | 
			
		||||
    vector_components_z: torch.Tensor
 | 
			
		||||
    basis_matrix: Optional[torch.Tensor] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class CPFactorizedVoxelGrid(VoxelGridBase):
 | 
			
		||||
    """
 | 
			
		||||
    Canonical Polyadic (CP/CANDECOMP/PARAFAC) Factorization factorizes the 3d grid into three
 | 
			
		||||
    vectors (x, y, z). For n_components=n, the 3d grid is a sum of the two outer products
 | 
			
		||||
    (call it ⊗) of each vector type (x, y, z):
 | 
			
		||||
 | 
			
		||||
    3d_grid = x0 ⊗ y0 ⊗ z0 + x1 ⊗ y1 ⊗ z1 + ... + xn ⊗ yn ⊗ zn
 | 
			
		||||
 | 
			
		||||
    These tensors are passed in a object of CPFactorizedVoxelGridValues (here obj) as
 | 
			
		||||
    obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their shapes are
 | 
			
		||||
    `(n_components, r)` where `r` is the relevant resolution.
 | 
			
		||||
 | 
			
		||||
    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
 | 
			
		||||
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
 | 
			
		||||
    brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
 | 
			
		||||
    of different components are summed together to create (n_grids, n_components, 1) tensor.
 | 
			
		||||
    With some notation abuse, ignoring the interpolation operation, simplifying and denoting
 | 
			
		||||
    n_features as F, n_components as C and n_grids as G:
 | 
			
		||||
 | 
			
		||||
    3d_grid = (x ⊗ y ⊗ z) @ basis # GWHDC x GCF -> GWHDF
 | 
			
		||||
 | 
			
		||||
    The basis feature vectors are passed as obj.basis_matrix.
 | 
			
		||||
 | 
			
		||||
    Members:
 | 
			
		||||
        n_components: number of vector triplets, higher number gives better approximation.
 | 
			
		||||
        matrix_reduction: how to transform components. If matrix_reduction=True result
 | 
			
		||||
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
 | 
			
		||||
            basis_matrix of shape (n_grids, n_components, n_features). If
 | 
			
		||||
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
 | 
			
		||||
            is summed along the rows to get (n_grids, n_points, 1).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # the type of grid_values argument needed to run evaluate_local()
 | 
			
		||||
    values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues
 | 
			
		||||
 | 
			
		||||
    n_components: int = 24
 | 
			
		||||
    matrix_reduction: bool = True
 | 
			
		||||
 | 
			
		||||
    def evaluate_local(
 | 
			
		||||
        self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        def factor(i):
 | 
			
		||||
            axis = ["x", "y", "z"][i]
 | 
			
		||||
            index = points[..., i, None]
 | 
			
		||||
            vector = getattr(grid_values, "vector_components_" + axis)
 | 
			
		||||
            return interpolate_line(
 | 
			
		||||
                index,
 | 
			
		||||
                vector,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                padding_mode=self.padding,
 | 
			
		||||
                mode=self.mode,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # collect points from all the vectors and multipy them out
 | 
			
		||||
        mult = factor(0) * factor(1) * factor(2)
 | 
			
		||||
 | 
			
		||||
        # reduce the result from
 | 
			
		||||
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
 | 
			
		||||
        if grid_values.basis_matrix is not None:
 | 
			
		||||
            # (n_grids, n_points, n_features) =
 | 
			
		||||
            # (n_grids, n_points, total_n_components) x (total_n_components, n_features)
 | 
			
		||||
            return torch.bmm(mult, grid_values.basis_matrix)
 | 
			
		||||
 | 
			
		||||
        return mult.sum(axis=-1, keepdim=True)
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> Dict[str, Tuple[int, int]]:
 | 
			
		||||
        if self.matrix_reduction is False and self.n_features != 1:
 | 
			
		||||
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
 | 
			
		||||
 | 
			
		||||
        shape_dict = {
 | 
			
		||||
            "vector_components_x": (self.n_components, self.resolution[0]),
 | 
			
		||||
            "vector_components_y": (self.n_components, self.resolution[1]),
 | 
			
		||||
            "vector_components_z": (self.n_components, self.resolution[2]),
 | 
			
		||||
        }
 | 
			
		||||
        if self.matrix_reduction:
 | 
			
		||||
            shape_dict["basis_matrix"] = (self.n_components, self.n_features)
 | 
			
		||||
        return shape_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
 | 
			
		||||
    vector_components_x: torch.Tensor
 | 
			
		||||
    vector_components_y: torch.Tensor
 | 
			
		||||
    vector_components_z: torch.Tensor
 | 
			
		||||
    matrix_components_xy: torch.Tensor
 | 
			
		||||
    matrix_components_yz: torch.Tensor
 | 
			
		||||
    matrix_components_xz: torch.Tensor
 | 
			
		||||
    basis_matrix: Optional[torch.Tensor] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class VMFactorizedVoxelGrid(VoxelGridBase):
 | 
			
		||||
    """
 | 
			
		||||
    Implementation of Vector-Matrix Factorization of a tensor from
 | 
			
		||||
    https://arxiv.org/abs/2203.09517.
 | 
			
		||||
 | 
			
		||||
    Vector-Matrix Factorization factorizes the 3d grid into three matrices
 | 
			
		||||
    (xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid
 | 
			
		||||
    is a sum of the outer products (call it ⊗) of each matrix with its
 | 
			
		||||
    complementary vector:
 | 
			
		||||
 | 
			
		||||
    3d_grid = xy ⊗ z + xz ⊗ y + yz ⊗ x.
 | 
			
		||||
 | 
			
		||||
    These tensors are passed in a VMFactorizedVoxelGridValues object (here obj)
 | 
			
		||||
    as obj.matrix_components_xy, obj.matrix_components_xy, obj.vector_components_y, etc.
 | 
			
		||||
 | 
			
		||||
    Their shapes are `(n_grids, n_components, r0, r1)` for matrix_components and
 | 
			
		||||
    (n_grids, n_components, r2)` for vector_componenets. Each of `r0, r1 and r2` coresponds
 | 
			
		||||
    to one resolution in (width, height and depth).
 | 
			
		||||
 | 
			
		||||
    Each element of this sum has an extra dimension, which gets matrix-multiplied by an
 | 
			
		||||
    appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
 | 
			
		||||
    brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
 | 
			
		||||
    of different components are summed together to create (n_grids, n_components, 1) tensor.
 | 
			
		||||
    With some notation abuse, ignoring the interpolation operation, simplifying and denoting
 | 
			
		||||
    n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:
 | 
			
		||||
 | 
			
		||||
    3d_grid = concat((xy ⊗ z), (xz ⊗ y).permute(0, 2, 1),
 | 
			
		||||
                (yz ⊗ x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF
 | 
			
		||||
 | 
			
		||||
    Members:
 | 
			
		||||
        n_components: total number of matrix vector pairs, this must be divisible by 3. Set
 | 
			
		||||
            this if you want to have equal representational power in all 3 directions. You
 | 
			
		||||
            must specify either n_components or distribution_of_components, you cannot
 | 
			
		||||
            specify both.
 | 
			
		||||
        distribution_of_components: if you do not want equal representational power in
 | 
			
		||||
            all 3 directions specify a tuple of numbers of matrix_vector pairs for each
 | 
			
		||||
            coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
 | 
			
		||||
            either n_components or distribution_of_components, you cannot specify both.
 | 
			
		||||
        matrix_reduction: how to transform components. If matrix_reduction=True result
 | 
			
		||||
            matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
 | 
			
		||||
            the basis_matrix of shape (n_grids, n_components, n_features). If
 | 
			
		||||
            matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
 | 
			
		||||
            is summed along the rows to get (n_grids, n_points, 1).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # the type of grid_values argument needed to run evaluate_local()
 | 
			
		||||
    values_type: ClassVar[Type[VoxelGridValuesBase]] = VMFactorizedVoxelGridValues
 | 
			
		||||
 | 
			
		||||
    n_components: Optional[int] = None
 | 
			
		||||
    distribution_of_components: Optional[Tuple[int, int, int]] = None
 | 
			
		||||
    matrix_reduction: bool = True
 | 
			
		||||
 | 
			
		||||
    def evaluate_local(
 | 
			
		||||
        self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        # collect points from matrices and vectors and multiply them
 | 
			
		||||
        a = interpolate_plane(
 | 
			
		||||
            points[..., :2],
 | 
			
		||||
            grid_values.matrix_components_xy,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        ) * interpolate_line(
 | 
			
		||||
            points[..., 2:],
 | 
			
		||||
            grid_values.vector_components_z,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        )
 | 
			
		||||
        b = interpolate_plane(
 | 
			
		||||
            points[..., [0, 2]],
 | 
			
		||||
            grid_values.matrix_components_xz,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        ) * interpolate_line(
 | 
			
		||||
            points[..., 1:2],
 | 
			
		||||
            grid_values.vector_components_y,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        )
 | 
			
		||||
        c = interpolate_plane(
 | 
			
		||||
            points[..., 1:],
 | 
			
		||||
            grid_values.matrix_components_yz,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        ) * interpolate_line(
 | 
			
		||||
            points[..., :1],
 | 
			
		||||
            grid_values.vector_components_x,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding,
 | 
			
		||||
            mode=self.mode,
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[28]
 | 
			
		||||
        feats = torch.cat((a, b, c), axis=-1)
 | 
			
		||||
 | 
			
		||||
        # reduce the result from
 | 
			
		||||
        # (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
 | 
			
		||||
        if grid_values.basis_matrix is not None:
 | 
			
		||||
            # (n_grids, n_points, n_features) =
 | 
			
		||||
            # (n_grids, n_points, total_n_components) x
 | 
			
		||||
            #               (n_grids, total_n_components, n_features)
 | 
			
		||||
            return torch.bmm(feats, grid_values.basis_matrix)
 | 
			
		||||
        # pyre-ignore[28]
 | 
			
		||||
        return feats.sum(axis=-1, keepdim=True)
 | 
			
		||||
 | 
			
		||||
    def get_shapes(self) -> Dict[str, Tuple]:
 | 
			
		||||
        if self.matrix_reduction is False and self.n_features != 1:
 | 
			
		||||
            raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
 | 
			
		||||
        if self.distribution_of_components is None and self.n_components is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "You need to provide n_components or distribution_of_components"
 | 
			
		||||
            )
 | 
			
		||||
        if (
 | 
			
		||||
            self.distribution_of_components is not None
 | 
			
		||||
            and self.n_components is not None
 | 
			
		||||
        ):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "You cannot define n_components and distribution_of_components"
 | 
			
		||||
            )
 | 
			
		||||
        # pyre-ignore[58]
 | 
			
		||||
        if self.distribution_of_components is None and self.n_components % 3 != 0:
 | 
			
		||||
            raise ValueError("n_components must be divisible by 3")
 | 
			
		||||
        if self.distribution_of_components is None:
 | 
			
		||||
            # pyre-ignore[58]
 | 
			
		||||
            calculated_distribution_of_components = [
 | 
			
		||||
                self.n_components // 3 for _ in range(3)
 | 
			
		||||
            ]
 | 
			
		||||
        else:
 | 
			
		||||
            calculated_distribution_of_components = self.distribution_of_components
 | 
			
		||||
 | 
			
		||||
        shape_dict = {
 | 
			
		||||
            "vector_components_x": (
 | 
			
		||||
                calculated_distribution_of_components[1],
 | 
			
		||||
                self.resolution[0],
 | 
			
		||||
            ),
 | 
			
		||||
            "vector_components_y": (
 | 
			
		||||
                calculated_distribution_of_components[2],
 | 
			
		||||
                self.resolution[1],
 | 
			
		||||
            ),
 | 
			
		||||
            "vector_components_z": (
 | 
			
		||||
                calculated_distribution_of_components[0],
 | 
			
		||||
                self.resolution[2],
 | 
			
		||||
            ),
 | 
			
		||||
            "matrix_components_xy": (
 | 
			
		||||
                calculated_distribution_of_components[0],
 | 
			
		||||
                self.resolution[0],
 | 
			
		||||
                self.resolution[1],
 | 
			
		||||
            ),
 | 
			
		||||
            "matrix_components_yz": (
 | 
			
		||||
                calculated_distribution_of_components[1],
 | 
			
		||||
                self.resolution[1],
 | 
			
		||||
                self.resolution[2],
 | 
			
		||||
            ),
 | 
			
		||||
            "matrix_components_xz": (
 | 
			
		||||
                calculated_distribution_of_components[2],
 | 
			
		||||
                self.resolution[0],
 | 
			
		||||
                self.resolution[2],
 | 
			
		||||
            ),
 | 
			
		||||
        }
 | 
			
		||||
        if self.matrix_reduction:
 | 
			
		||||
            shape_dict["basis_matrix"] = (
 | 
			
		||||
                sum(calculated_distribution_of_components),
 | 
			
		||||
                self.n_features,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return shape_dict
 | 
			
		||||
							
								
								
									
										587
									
								
								tests/implicitron/test_voxel_grids.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										587
									
								
								tests/implicitron/test_voxel_grids.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,587 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.utils import (
 | 
			
		||||
    interpolate_line,
 | 
			
		||||
    interpolate_plane,
 | 
			
		||||
    interpolate_volume,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
 | 
			
		||||
    CPFactorizedVoxelGrid,
 | 
			
		||||
    FullResolutionVoxelGrid,
 | 
			
		||||
    VMFactorizedVoxelGrid,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.tools.config import expand_args_fields
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    Tests Voxel grids, tests them by setting all elements to zero (after retrieving
 | 
			
		||||
    they should also return zero) and by setting all of the elements to one and
 | 
			
		||||
    getting the result. Also tests the interpolation by 'manually' interpolating
 | 
			
		||||
    one by one sample and comparing with the batched implementation.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def test_my_code(self):
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    def get_random_normalized_points(
 | 
			
		||||
        self, n_grids, n_points, dimension=3
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        # create random query points
 | 
			
		||||
        return torch.rand(n_grids, n_points, dimension) * 2 - 1
 | 
			
		||||
 | 
			
		||||
    def _test_query_with_constant_init_cp(
 | 
			
		||||
        self,
 | 
			
		||||
        n_grids: int,
 | 
			
		||||
        n_features: int,
 | 
			
		||||
        n_components: int,
 | 
			
		||||
        resolution: Tuple[int],
 | 
			
		||||
        value: float = 1,
 | 
			
		||||
        n_points: int = 1,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # set everything to 'value' and do query for elementsthe result should
 | 
			
		||||
        # be of shape (n_grids, n_points, n_features) and be filled with n_components
 | 
			
		||||
        # * value
 | 
			
		||||
        grid = CPFactorizedVoxelGrid(
 | 
			
		||||
            resolution=resolution,
 | 
			
		||||
            n_components=n_components,
 | 
			
		||||
            n_features=n_features,
 | 
			
		||||
        )
 | 
			
		||||
        shapes = grid.get_shapes()
 | 
			
		||||
 | 
			
		||||
        params = grid.values_type(
 | 
			
		||||
            **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        assert torch.allclose(
 | 
			
		||||
            grid.evaluate_local(
 | 
			
		||||
                self.get_random_normalized_points(n_grids, n_points), params
 | 
			
		||||
            ),
 | 
			
		||||
            torch.ones(n_grids, n_points, n_features) * n_components * value,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _test_query_with_constant_init_vm(
 | 
			
		||||
        self,
 | 
			
		||||
        n_grids: int,
 | 
			
		||||
        n_features: int,
 | 
			
		||||
        resolution: Tuple[int],
 | 
			
		||||
        n_components: Optional[int] = None,
 | 
			
		||||
        distribution: Optional[Tuple[int]] = None,
 | 
			
		||||
        value: float = 1,
 | 
			
		||||
        n_points: int = 1,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # set everything to 'value' and do query for elements
 | 
			
		||||
        grid = VMFactorizedVoxelGrid(
 | 
			
		||||
            n_features=n_features,
 | 
			
		||||
            resolution=resolution,
 | 
			
		||||
            n_components=n_components,
 | 
			
		||||
            distribution_of_components=distribution,
 | 
			
		||||
        )
 | 
			
		||||
        shapes = grid.get_shapes()
 | 
			
		||||
        params = grid.values_type(
 | 
			
		||||
            **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_element = (
 | 
			
		||||
            n_components * value if distribution is None else sum(distribution) * value
 | 
			
		||||
        )
 | 
			
		||||
        assert torch.allclose(
 | 
			
		||||
            grid.evaluate_local(
 | 
			
		||||
                self.get_random_normalized_points(n_grids, n_points), params
 | 
			
		||||
            ),
 | 
			
		||||
            torch.ones(n_grids, n_points, n_features) * expected_element,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _test_query_with_constant_init_full(
 | 
			
		||||
        self,
 | 
			
		||||
        n_grids: int,
 | 
			
		||||
        n_features: int,
 | 
			
		||||
        resolution: Tuple[int],
 | 
			
		||||
        value: int = 1,
 | 
			
		||||
        n_points: int = 1,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # set everything to 'value' and do query for elements
 | 
			
		||||
        grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
 | 
			
		||||
        shapes = grid.get_shapes()
 | 
			
		||||
        params = grid.values_type(
 | 
			
		||||
            **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_element = value
 | 
			
		||||
        assert torch.allclose(
 | 
			
		||||
            grid.evaluate_local(
 | 
			
		||||
                self.get_random_normalized_points(n_grids, n_points), params
 | 
			
		||||
            ),
 | 
			
		||||
            torch.ones(n_grids, n_points, n_features) * expected_element,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_query_with_constant_init(self):
 | 
			
		||||
        with self.subTest("Full"):
 | 
			
		||||
            self._test_query_with_constant_init_full(
 | 
			
		||||
                n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("Full with 1 in dimensions"):
 | 
			
		||||
            self._test_query_with_constant_init_full(
 | 
			
		||||
                n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("CP"):
 | 
			
		||||
            self._test_query_with_constant_init_cp(
 | 
			
		||||
                n_grids=5,
 | 
			
		||||
                n_features=6,
 | 
			
		||||
                n_components=7,
 | 
			
		||||
                resolution=(3, 4, 5),
 | 
			
		||||
                n_points=2,
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("CP with 1 in dimensions"):
 | 
			
		||||
            self._test_query_with_constant_init_cp(
 | 
			
		||||
                n_grids=2,
 | 
			
		||||
                n_features=1,
 | 
			
		||||
                n_components=3,
 | 
			
		||||
                resolution=(3, 1, 1),
 | 
			
		||||
                n_points=4,
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("VM with symetric distribution"):
 | 
			
		||||
            self._test_query_with_constant_init_vm(
 | 
			
		||||
                n_grids=6,
 | 
			
		||||
                n_features=9,
 | 
			
		||||
                resolution=(2, 12, 2),
 | 
			
		||||
                n_components=12,
 | 
			
		||||
                n_points=3,
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("VM with distribution"):
 | 
			
		||||
            self._test_query_with_constant_init_vm(
 | 
			
		||||
                n_grids=5,
 | 
			
		||||
                n_features=1,
 | 
			
		||||
                resolution=(5, 9, 7),
 | 
			
		||||
                distribution=(33, 41, 1),
 | 
			
		||||
                n_points=7,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_query_with_zero_init(self):
 | 
			
		||||
        with self.subTest("Query testing with zero init CPFactorizedVoxelGrid"):
 | 
			
		||||
            self._test_query_with_constant_init_cp(
 | 
			
		||||
                n_grids=5,
 | 
			
		||||
                n_features=6,
 | 
			
		||||
                n_components=7,
 | 
			
		||||
                resolution=(3, 2, 5),
 | 
			
		||||
                n_points=3,
 | 
			
		||||
                value=0,
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
 | 
			
		||||
            self._test_query_with_constant_init_vm(
 | 
			
		||||
                n_grids=2,
 | 
			
		||||
                n_features=9,
 | 
			
		||||
                resolution=(2, 11, 3),
 | 
			
		||||
                n_components=3,
 | 
			
		||||
                n_points=3,
 | 
			
		||||
                value=0,
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
 | 
			
		||||
            self._test_query_with_constant_init_full(
 | 
			
		||||
                n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
        expand_args_fields(FullResolutionVoxelGrid)
 | 
			
		||||
        expand_args_fields(CPFactorizedVoxelGrid)
 | 
			
		||||
        expand_args_fields(VMFactorizedVoxelGrid)
 | 
			
		||||
 | 
			
		||||
    def _interpolate_1D(
 | 
			
		||||
        self, points: torch.Tensor, vectors: torch.Tensor
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        interpolate vector from points, which are (batch, 1) and individual point is in [-1, 1]
 | 
			
		||||
        """
 | 
			
		||||
        result = []
 | 
			
		||||
        _, _, width = vectors.shape
 | 
			
		||||
        # transform from [-1, 1] to [0, width-1]
 | 
			
		||||
        points = (points + 1) / 2 * (width - 1)
 | 
			
		||||
        for vector, row in zip(vectors, points):
 | 
			
		||||
            newrow = []
 | 
			
		||||
            for x in row:
 | 
			
		||||
                xf, xc = int(torch.floor(x)), int(torch.ceil(x))
 | 
			
		||||
                itemf, itemc = vector[:, xf], vector[:, xc]
 | 
			
		||||
                tmp = itemf * (xc - x) + itemc * (x - xf)
 | 
			
		||||
                newrow.append(tmp[None, None, :])
 | 
			
		||||
            result.append(torch.cat(newrow, dim=1))
 | 
			
		||||
        return torch.cat(result)
 | 
			
		||||
 | 
			
		||||
    def _interpolate_2D(
 | 
			
		||||
        self, points: torch.Tensor, matrices: torch.Tensor
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        interpolate matrix from points, which are (batch, 2) and individual point is in [-1, 1]
 | 
			
		||||
        """
 | 
			
		||||
        result = []
 | 
			
		||||
        n_grids, _, width, height = matrices.shape
 | 
			
		||||
        points = (points + 1) / 2 * (torch.tensor([[[width, height]]]) - 1)
 | 
			
		||||
        for matrix, row in zip(matrices, points):
 | 
			
		||||
            newrow = []
 | 
			
		||||
            for x, y in row:
 | 
			
		||||
                xf, xc = int(torch.floor(x)), int(torch.ceil(x))
 | 
			
		||||
                yf, yc = int(torch.floor(y)), int(torch.ceil(y))
 | 
			
		||||
                itemff, itemfc = matrix[:, xf, yf], matrix[:, xf, yc]
 | 
			
		||||
                itemcf, itemcc = matrix[:, xc, yf], matrix[:, xc, yc]
 | 
			
		||||
                itemf = itemff * (xc - x) + itemcf * (x - xf)
 | 
			
		||||
                itemc = itemfc * (xc - x) + itemcc * (x - xf)
 | 
			
		||||
                tmp = itemf * (yc - y) + itemc * (y - yf)
 | 
			
		||||
                newrow.append(tmp[None, None, :])
 | 
			
		||||
            result.append(torch.cat(newrow, dim=1))
 | 
			
		||||
        return torch.cat(result)
 | 
			
		||||
 | 
			
		||||
    def _interpolate_3D(
 | 
			
		||||
        self, points: torch.Tensor, tensors: torch.Tensor
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        interpolate tensors from points, which are (batch, 3) and individual point is in [-1, 1]
 | 
			
		||||
        """
 | 
			
		||||
        result = []
 | 
			
		||||
        _, _, width, height, depth = tensors.shape
 | 
			
		||||
        batch_normalized_points = (
 | 
			
		||||
            (points + 1) / 2 * (torch.tensor([[[width, height, depth]]]) - 1)
 | 
			
		||||
        )
 | 
			
		||||
        batch_points = points
 | 
			
		||||
 | 
			
		||||
        for tensor, points, normalized_points in zip(
 | 
			
		||||
            tensors, batch_points, batch_normalized_points
 | 
			
		||||
        ):
 | 
			
		||||
            newrow = []
 | 
			
		||||
            for (x, y, z), (_, _, nz) in zip(points, normalized_points):
 | 
			
		||||
                zf, zc = int(torch.floor(nz)), int(torch.ceil(nz))
 | 
			
		||||
                itemf = self._interpolate_2D(
 | 
			
		||||
                    points=torch.tensor([[[x, y]]]), matrices=tensor[None, :, :, :, zf]
 | 
			
		||||
                )
 | 
			
		||||
                itemc = self._interpolate_2D(
 | 
			
		||||
                    points=torch.tensor([[[x, y]]]), matrices=tensor[None, :, :, :, zc]
 | 
			
		||||
                )
 | 
			
		||||
                tmp = self._interpolate_1D(
 | 
			
		||||
                    points=torch.tensor([[[z]]]),
 | 
			
		||||
                    vectors=torch.cat((itemf, itemc), dim=1).permute(0, 2, 1),
 | 
			
		||||
                )
 | 
			
		||||
                newrow.append(tmp)
 | 
			
		||||
            result.append(torch.cat(newrow, dim=1))
 | 
			
		||||
        return torch.cat(result)
 | 
			
		||||
 | 
			
		||||
    def test_interpolation(self):
 | 
			
		||||
 | 
			
		||||
        with self.subTest("1D interpolation"):
 | 
			
		||||
            points = self.get_random_normalized_points(
 | 
			
		||||
                n_grids=4, n_points=5, dimension=1
 | 
			
		||||
            )
 | 
			
		||||
            vector = torch.randn(size=(4, 3, 2))
 | 
			
		||||
            assert torch.allclose(
 | 
			
		||||
                self._interpolate_1D(points, vector),
 | 
			
		||||
                interpolate_line(
 | 
			
		||||
                    points,
 | 
			
		||||
                    vector,
 | 
			
		||||
                    align_corners=True,
 | 
			
		||||
                    padding_mode="zeros",
 | 
			
		||||
                    mode="bilinear",
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
        with self.subTest("2D interpolation"):
 | 
			
		||||
            points = self.get_random_normalized_points(
 | 
			
		||||
                n_grids=4, n_points=5, dimension=2
 | 
			
		||||
            )
 | 
			
		||||
            matrix = torch.randn(size=(4, 2, 3, 5))
 | 
			
		||||
            assert torch.allclose(
 | 
			
		||||
                self._interpolate_2D(points, matrix),
 | 
			
		||||
                interpolate_plane(
 | 
			
		||||
                    points,
 | 
			
		||||
                    matrix,
 | 
			
		||||
                    align_corners=True,
 | 
			
		||||
                    padding_mode="zeros",
 | 
			
		||||
                    mode="bilinear",
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with self.subTest("3D interpolation"):
 | 
			
		||||
            points = self.get_random_normalized_points(
 | 
			
		||||
                n_grids=4, n_points=5, dimension=3
 | 
			
		||||
            )
 | 
			
		||||
            tensor = torch.randn(size=(4, 5, 2, 7, 2))
 | 
			
		||||
            assert torch.allclose(
 | 
			
		||||
                self._interpolate_3D(points, tensor),
 | 
			
		||||
                interpolate_volume(
 | 
			
		||||
                    points,
 | 
			
		||||
                    tensor,
 | 
			
		||||
                    align_corners=True,
 | 
			
		||||
                    padding_mode="zeros",
 | 
			
		||||
                    mode="bilinear",
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_floating_point_query(self):
 | 
			
		||||
        """
 | 
			
		||||
        test querying the voxel grids on some float positions
 | 
			
		||||
        """
 | 
			
		||||
        with self.subTest("FullResolution"):
 | 
			
		||||
            grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
 | 
			
		||||
            params = grid.values_type(**grid.get_shapes())
 | 
			
		||||
            params.voxel_grid = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [
 | 
			
		||||
                        [[[1, 3], [5, 7]], [[9, 11], [13, 15]]],
 | 
			
		||||
                        [[[2, 4], [6, 8]], [[10, 12], [14, 16]]],
 | 
			
		||||
                    ],
 | 
			
		||||
                    [
 | 
			
		||||
                        [[[17, 18], [19, 20]], [[21, 22], [23, 24]]],
 | 
			
		||||
                        [[[25, 26], [27, 28]], [[29, 30], [31, 32]]],
 | 
			
		||||
                    ],
 | 
			
		||||
                ],
 | 
			
		||||
                dtype=torch.float,
 | 
			
		||||
            )
 | 
			
		||||
            points = (
 | 
			
		||||
                torch.tensor(
 | 
			
		||||
                    [
 | 
			
		||||
                        [
 | 
			
		||||
                            [1, 0, 1],
 | 
			
		||||
                            [0.5, 1, 1],
 | 
			
		||||
                            [1 / 3, 1 / 3, 2 / 3],
 | 
			
		||||
                        ],
 | 
			
		||||
                        [
 | 
			
		||||
                            [0, 1, 1],
 | 
			
		||||
                            [0, 0.5, 1],
 | 
			
		||||
                            [1 / 4, 1 / 4, 3 / 4],
 | 
			
		||||
                        ],
 | 
			
		||||
                    ]
 | 
			
		||||
                )
 | 
			
		||||
                / torch.tensor([[1.0, 1, 1]])
 | 
			
		||||
                * 2
 | 
			
		||||
                - 1
 | 
			
		||||
            )
 | 
			
		||||
            expected_result = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[11, 12], [11, 12], [6.333333, 7.3333333]],
 | 
			
		||||
                    [[20, 28], [19, 27], [19.25, 27.25]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            assert torch.allclose(
 | 
			
		||||
                grid.evaluate_local(points, params),
 | 
			
		||||
                expected_result,
 | 
			
		||||
                rtol=0.00001,
 | 
			
		||||
            ), grid.evaluate_local(points, params)
 | 
			
		||||
        with self.subTest("CP"):
 | 
			
		||||
            grid = CPFactorizedVoxelGrid(
 | 
			
		||||
                n_features=1, resolution=(1, 1, 1), n_components=3
 | 
			
		||||
            )
 | 
			
		||||
            params = grid.values_type(**grid.get_shapes())
 | 
			
		||||
            params.vector_components_x = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[1, 2], [10.5, 20.5]],
 | 
			
		||||
                    [[10, 20], [2, 4]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.vector_components_y = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[3, 4, 5], [30.5, 40.5, 50.5]],
 | 
			
		||||
                    [[30, 40, 50], [1, 3, 5]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.vector_components_z = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[6, 7, 8, 9], [60.5, 70.5, 80.5, 90.5]],
 | 
			
		||||
                    [[60, 70, 80, 90], [6, 7, 8, 9]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.basis_matrix = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[2.0], [2.0]],
 | 
			
		||||
                    [[1.0], [2.0]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            points = (
 | 
			
		||||
                torch.tensor(
 | 
			
		||||
                    [
 | 
			
		||||
                        [
 | 
			
		||||
                            [0, 2, 2],
 | 
			
		||||
                            [1, 2, 0.25],
 | 
			
		||||
                            [0.5, 0.5, 1],
 | 
			
		||||
                            [1 / 3, 2 / 3, 2 + 1 / 3],
 | 
			
		||||
                        ],
 | 
			
		||||
                        [
 | 
			
		||||
                            [1, 0, 1],
 | 
			
		||||
                            [0.5, 2, 2],
 | 
			
		||||
                            [1, 0.5, 0.5],
 | 
			
		||||
                            [1 / 4, 3 / 4, 2 + 1 / 4],
 | 
			
		||||
                        ],
 | 
			
		||||
                    ]
 | 
			
		||||
                )
 | 
			
		||||
                / torch.tensor([[[1.0, 2, 3]]])
 | 
			
		||||
                * 2
 | 
			
		||||
                - 1
 | 
			
		||||
            )
 | 
			
		||||
            expected_result_matrix = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[85450.25], [130566.5], [77658.75], [86285.422]],
 | 
			
		||||
                    [[42056], [60240], [45604], [38775]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            expected_result_sum = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[42725.125], [65283.25], [38829.375], [43142.711]],
 | 
			
		||||
                    [[42028], [60120], [45552], [38723.4375]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            with self.subTest("CP with basis_matrix reduction"):
 | 
			
		||||
                assert torch.allclose(
 | 
			
		||||
                    grid.evaluate_local(points, params),
 | 
			
		||||
                    expected_result_matrix,
 | 
			
		||||
                    rtol=0.00001,
 | 
			
		||||
                )
 | 
			
		||||
            del params.basis_matrix
 | 
			
		||||
            with self.subTest("CP with sum reduction"):
 | 
			
		||||
                assert torch.allclose(
 | 
			
		||||
                    grid.evaluate_local(points, params),
 | 
			
		||||
                    expected_result_sum,
 | 
			
		||||
                    rtol=0.00001,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        with self.subTest("VM"):
 | 
			
		||||
            grid = VMFactorizedVoxelGrid(
 | 
			
		||||
                n_features=1, resolution=(1, 1, 1), n_components=3
 | 
			
		||||
            )
 | 
			
		||||
            params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
 | 
			
		||||
            params.matrix_components_xy = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
 | 
			
		||||
                    [[[35, 36], [37, 38]], [[39, 40], [41, 42]]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.matrix_components_xz = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[[7, 8], [9, 10]], [[25, 26], [27, 28.0]]],
 | 
			
		||||
                    [[[43, 44], [45, 46]], [[47, 48], [49, 50]]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.matrix_components_yz = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[[13, 14], [15, 16]], [[31, 32], [33, 34.0]]],
 | 
			
		||||
                    [[[51, 52], [53, 54]], [[55, 56], [57, 58.0]]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            params.vector_components_z = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[5, 6], [23, 24.0]],
 | 
			
		||||
                    [[59, 60], [61, 62]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.vector_components_y = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[11, 12], [29, 30.0]],
 | 
			
		||||
                    [[63, 64], [65, 66]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            params.vector_components_x = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[17, 18], [35, 36.0]],
 | 
			
		||||
                    [[67, 68], [69, 70.0]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            params.basis_matrix = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [2, 2, 2, 2, 2, 2.0],
 | 
			
		||||
                    [1, 2, 1, 2, 1, 2.0],
 | 
			
		||||
                ]
 | 
			
		||||
            )[:, :, None]
 | 
			
		||||
            points = (
 | 
			
		||||
                torch.tensor(
 | 
			
		||||
                    [
 | 
			
		||||
                        [
 | 
			
		||||
                            [1, 0, 1],
 | 
			
		||||
                            [0.5, 1, 1],
 | 
			
		||||
                            [1 / 3, 1 / 3, 2 / 3],
 | 
			
		||||
                        ],
 | 
			
		||||
                        [
 | 
			
		||||
                            [0, 1, 0],
 | 
			
		||||
                            [0, 0, 0],
 | 
			
		||||
                            [0, 1, 0],
 | 
			
		||||
                        ],
 | 
			
		||||
                    ]
 | 
			
		||||
                )
 | 
			
		||||
                / torch.tensor([[[1.0, 1, 1]]])
 | 
			
		||||
                * 2
 | 
			
		||||
                - 1
 | 
			
		||||
            )
 | 
			
		||||
            expected_result_matrix = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[5696], [5854], [5484.888]],
 | 
			
		||||
                    [[27377], [26649], [27377]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            expected_result_sum = torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [[2848], [2927], [2742.444]],
 | 
			
		||||
                    [[17902], [17420], [17902]],
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            with self.subTest("VM with basis_matrix reduction"):
 | 
			
		||||
                assert torch.allclose(
 | 
			
		||||
                    grid.evaluate_local(points, params),
 | 
			
		||||
                    expected_result_matrix,
 | 
			
		||||
                    rtol=0.00001,
 | 
			
		||||
                )
 | 
			
		||||
            del params.basis_matrix
 | 
			
		||||
            with self.subTest("VM with sum reduction"):
 | 
			
		||||
                assert torch.allclose(
 | 
			
		||||
                    grid.evaluate_local(points, params),
 | 
			
		||||
                    expected_result_sum,
 | 
			
		||||
                    rtol=0.0001,
 | 
			
		||||
                ), grid.evaluate_local(points, params)
 | 
			
		||||
 | 
			
		||||
    def test_forward_with_small_init_std(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test does the grid return small values if it is initialized with small
 | 
			
		||||
        mean and small standard deviation.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def test(cls, **kwargs):
 | 
			
		||||
            with self.subTest(cls.__name__):
 | 
			
		||||
                n_grids = 3
 | 
			
		||||
                grid = cls(**kwargs)
 | 
			
		||||
                shapes = grid.get_shapes()
 | 
			
		||||
                params = cls.values_type(
 | 
			
		||||
                    **{
 | 
			
		||||
                        k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
 | 
			
		||||
                        for k, shape in shapes.items()
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
                points = self.get_random_normalized_points(n_grids=n_grids, n_points=3)
 | 
			
		||||
                max_expected_result = torch.zeros((len(points), 10)) + 1e-2
 | 
			
		||||
                assert torch.all(
 | 
			
		||||
                    grid.evaluate_local(points, params) < max_expected_result
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        test(
 | 
			
		||||
            FullResolutionVoxelGrid,
 | 
			
		||||
            resolution=(4, 6, 9),
 | 
			
		||||
            n_features=10,
 | 
			
		||||
        )
 | 
			
		||||
        test(
 | 
			
		||||
            CPFactorizedVoxelGrid,
 | 
			
		||||
            resolution=(4, 6, 9),
 | 
			
		||||
            n_features=10,
 | 
			
		||||
            n_components=3,
 | 
			
		||||
        )
 | 
			
		||||
        test(
 | 
			
		||||
            VMFactorizedVoxelGrid,
 | 
			
		||||
            resolution=(4, 6, 9),
 | 
			
		||||
            n_features=10,
 | 
			
		||||
            n_components=3,
 | 
			
		||||
        )
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user