mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	voxel_grid_implicit_function
Reviewed By: shapovalov Differential Revision: D40622304 fbshipit-source-id: 277515a55c46d9b8300058b439526539a7fe00a0
This commit is contained in:
		
							parent
							
								
									611aba9a20
								
							
						
					
					
						commit
						74754bbf17
					
				@ -394,6 +394,168 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
        in_features: 256
 | 
			
		||||
        out_features: 3
 | 
			
		||||
        ray_dir_in_camera_coords: false
 | 
			
		||||
    implicit_function_VoxelGridImplicitFunction_args:
 | 
			
		||||
      harmonic_embedder_xyz_density_args:
 | 
			
		||||
        n_harmonic_functions: 6
 | 
			
		||||
        omega_0: 1.0
 | 
			
		||||
        logspace: true
 | 
			
		||||
        append_input: true
 | 
			
		||||
      harmonic_embedder_xyz_color_args:
 | 
			
		||||
        n_harmonic_functions: 6
 | 
			
		||||
        omega_0: 1.0
 | 
			
		||||
        logspace: true
 | 
			
		||||
        append_input: true
 | 
			
		||||
      harmonic_embedder_dir_color_args:
 | 
			
		||||
        n_harmonic_functions: 6
 | 
			
		||||
        omega_0: 1.0
 | 
			
		||||
        logspace: true
 | 
			
		||||
        append_input: true
 | 
			
		||||
      decoder_density_class_type: MLPDecoder
 | 
			
		||||
      decoder_color_class_type: MLPDecoder
 | 
			
		||||
      use_multiple_streams: true
 | 
			
		||||
      xyz_ray_dir_in_camera_coords: false
 | 
			
		||||
      scaffold_calculating_epochs: []
 | 
			
		||||
      scaffold_resolution:
 | 
			
		||||
      - 128
 | 
			
		||||
      - 128
 | 
			
		||||
      - 128
 | 
			
		||||
      scaffold_empty_space_threshold: 0.001
 | 
			
		||||
      scaffold_occupancy_chunk_size: 'inf'
 | 
			
		||||
      scaffold_max_pool_kernel_size: 3
 | 
			
		||||
      scaffold_filter_points: true
 | 
			
		||||
      volume_cropping_epochs: []
 | 
			
		||||
      voxel_grid_density_args:
 | 
			
		||||
        voxel_grid_class_type: FullResolutionVoxelGrid
 | 
			
		||||
        extents:
 | 
			
		||||
        - 2.0
 | 
			
		||||
        - 2.0
 | 
			
		||||
        - 2.0
 | 
			
		||||
        translation:
 | 
			
		||||
        - 0.0
 | 
			
		||||
        - 0.0
 | 
			
		||||
        - 0.0
 | 
			
		||||
        init_std: 0.1
 | 
			
		||||
        init_mean: 0.0
 | 
			
		||||
        hold_voxel_grid_as_parameters: true
 | 
			
		||||
        param_groups: {}
 | 
			
		||||
        voxel_grid_CPFactorizedVoxelGrid_args:
 | 
			
		||||
          align_corners: true
 | 
			
		||||
          padding: zeros
 | 
			
		||||
          mode: bilinear
 | 
			
		||||
          n_features: 1
 | 
			
		||||
          resolution_changes:
 | 
			
		||||
            0:
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
          n_components: 24
 | 
			
		||||
          basis_matrix: true
 | 
			
		||||
        voxel_grid_FullResolutionVoxelGrid_args:
 | 
			
		||||
          align_corners: true
 | 
			
		||||
          padding: zeros
 | 
			
		||||
          mode: bilinear
 | 
			
		||||
          n_features: 1
 | 
			
		||||
          resolution_changes:
 | 
			
		||||
            0:
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
        voxel_grid_VMFactorizedVoxelGrid_args:
 | 
			
		||||
          align_corners: true
 | 
			
		||||
          padding: zeros
 | 
			
		||||
          mode: bilinear
 | 
			
		||||
          n_features: 1
 | 
			
		||||
          resolution_changes:
 | 
			
		||||
            0:
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
          n_components: null
 | 
			
		||||
          distribution_of_components: null
 | 
			
		||||
          basis_matrix: true
 | 
			
		||||
      voxel_grid_color_args:
 | 
			
		||||
        voxel_grid_class_type: FullResolutionVoxelGrid
 | 
			
		||||
        extents:
 | 
			
		||||
        - 2.0
 | 
			
		||||
        - 2.0
 | 
			
		||||
        - 2.0
 | 
			
		||||
        translation:
 | 
			
		||||
        - 0.0
 | 
			
		||||
        - 0.0
 | 
			
		||||
        - 0.0
 | 
			
		||||
        init_std: 0.1
 | 
			
		||||
        init_mean: 0.0
 | 
			
		||||
        hold_voxel_grid_as_parameters: true
 | 
			
		||||
        param_groups: {}
 | 
			
		||||
        voxel_grid_CPFactorizedVoxelGrid_args:
 | 
			
		||||
          align_corners: true
 | 
			
		||||
          padding: zeros
 | 
			
		||||
          mode: bilinear
 | 
			
		||||
          n_features: 1
 | 
			
		||||
          resolution_changes:
 | 
			
		||||
            0:
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
          n_components: 24
 | 
			
		||||
          basis_matrix: true
 | 
			
		||||
        voxel_grid_FullResolutionVoxelGrid_args:
 | 
			
		||||
          align_corners: true
 | 
			
		||||
          padding: zeros
 | 
			
		||||
          mode: bilinear
 | 
			
		||||
          n_features: 1
 | 
			
		||||
          resolution_changes:
 | 
			
		||||
            0:
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
        voxel_grid_VMFactorizedVoxelGrid_args:
 | 
			
		||||
          align_corners: true
 | 
			
		||||
          padding: zeros
 | 
			
		||||
          mode: bilinear
 | 
			
		||||
          n_features: 1
 | 
			
		||||
          resolution_changes:
 | 
			
		||||
            0:
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
            - 128
 | 
			
		||||
          n_components: null
 | 
			
		||||
          distribution_of_components: null
 | 
			
		||||
          basis_matrix: true
 | 
			
		||||
      decoder_density_ElementwiseDecoder_args:
 | 
			
		||||
        scale: 1.0
 | 
			
		||||
        shift: 0.0
 | 
			
		||||
        operation: IDENTITY
 | 
			
		||||
      decoder_density_MLPDecoder_args:
 | 
			
		||||
        param_groups: {}
 | 
			
		||||
        network_args:
 | 
			
		||||
          n_layers: 8
 | 
			
		||||
          output_dim: 256
 | 
			
		||||
          skip_dim: 39
 | 
			
		||||
          hidden_dim: 256
 | 
			
		||||
          input_skips:
 | 
			
		||||
          - 5
 | 
			
		||||
          skip_affine_trans: false
 | 
			
		||||
          last_layer_bias_init: null
 | 
			
		||||
          last_activation: RELU
 | 
			
		||||
          use_xavier_init: true
 | 
			
		||||
      decoder_color_ElementwiseDecoder_args:
 | 
			
		||||
        scale: 1.0
 | 
			
		||||
        shift: 0.0
 | 
			
		||||
        operation: IDENTITY
 | 
			
		||||
      decoder_color_MLPDecoder_args:
 | 
			
		||||
        param_groups: {}
 | 
			
		||||
        network_args:
 | 
			
		||||
          n_layers: 8
 | 
			
		||||
          output_dim: 256
 | 
			
		||||
          skip_dim: 39
 | 
			
		||||
          hidden_dim: 256
 | 
			
		||||
          input_skips:
 | 
			
		||||
          - 5
 | 
			
		||||
          skip_affine_trans: false
 | 
			
		||||
          last_layer_bias_init: null
 | 
			
		||||
          last_activation: RELU
 | 
			
		||||
          use_xavier_init: true
 | 
			
		||||
    view_metrics_ViewMetrics_args: {}
 | 
			
		||||
    regularization_metrics_RegularizationMetrics_args: {}
 | 
			
		||||
optimizer_factory_ImplicitronOptimizerFactory_args:
 | 
			
		||||
 | 
			
		||||
@ -52,6 +52,9 @@ from .implicit_function.scene_representation_networks import (  # noqa
 | 
			
		||||
    SRNHyperNetImplicitFunction,
 | 
			
		||||
    SRNImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
from .implicit_function.voxel_grid_implicit_function import (  # noqa
 | 
			
		||||
    VoxelGridImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .renderer.base import (
 | 
			
		||||
    BaseRenderer,
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,616 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from dataclasses import fields
 | 
			
		||||
from typing import Callable, Dict, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
 | 
			
		||||
    DecoderFunctionBase,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    enable_get_default_args,
 | 
			
		||||
    get_default_args_field,
 | 
			
		||||
    registry,
 | 
			
		||||
    run_auto_creation,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer import ray_bundle_to_ray_points
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
 | 
			
		||||
 | 
			
		||||
enable_get_default_args(HarmonicEmbedding)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
# pyre-ignore[13]
 | 
			
		||||
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    This implicit function consists of two streams, one for the density calculation and one
 | 
			
		||||
    for the color calculation. Each of these streams has three main parts:
 | 
			
		||||
        1) Voxel grids:
 | 
			
		||||
            They take the (x, y, z) position and return the embedding of that point.
 | 
			
		||||
            These components are replaceable, you can make your own or choose one of
 | 
			
		||||
            several options.
 | 
			
		||||
        2) Harmonic embeddings:
 | 
			
		||||
            Convert each feature into series of 'harmonic features', feature is passed through
 | 
			
		||||
            sine and cosine functions. Input is of shape [minibatch, ..., D] output
 | 
			
		||||
            [minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D]. Appends
 | 
			
		||||
            input by default. If you want it to behave like identity, put n_harmonic_functions=0
 | 
			
		||||
            and append_input=True.
 | 
			
		||||
        3) Decoding functions:
 | 
			
		||||
            The decoder is an instance of the DecoderFunctionBase and converts the embedding
 | 
			
		||||
            of a spatial location to density/color. Examples are Identity which returns its
 | 
			
		||||
            input and the MLP which uses fully connected nerual network to transform the input.
 | 
			
		||||
            These components are replaceable, you can make your own or choose from
 | 
			
		||||
            several options.
 | 
			
		||||
 | 
			
		||||
    Calculating density is done in three steps:
 | 
			
		||||
        1) Evaluating the voxel grid on points
 | 
			
		||||
        2) Embedding the outputs with harmonic embedding
 | 
			
		||||
        3) Passing through the Density decoder
 | 
			
		||||
 | 
			
		||||
    To calculate the color we need the embedding and the viewing direction, it has five steps:
 | 
			
		||||
        1) Transforming the viewing direction with camera
 | 
			
		||||
        2) Evaluating the voxel grid on points
 | 
			
		||||
        3) Embedding the outputs with harmonic embedding
 | 
			
		||||
        4) Embedding the normalized direction with harmonic embedding
 | 
			
		||||
        5) Passing everything through the Color decoder
 | 
			
		||||
 | 
			
		||||
    If using the Implicitron configuration system the input_dim to the decoding functions will
 | 
			
		||||
    be set to the output_dim of the Harmonic embeddings.
 | 
			
		||||
 | 
			
		||||
    A speed up comes from using the scaffold, a low resolution voxel grid.
 | 
			
		||||
    The scaffold is referenced as "binary occupancy grid mask" in TensoRF paper and "AlphaMask"
 | 
			
		||||
    in official TensoRF implementation.
 | 
			
		||||
    The scaffold is used in:
 | 
			
		||||
        1) filtering points in empty space
 | 
			
		||||
            - controlled by `scaffold_filter_points` boolean. If set to True, points for which
 | 
			
		||||
                scaffold predicts that are in empty space will return 0 density and
 | 
			
		||||
                (0, 0, 0) color.
 | 
			
		||||
        2) calculating the bounding box of an object and cropping the voxel grids
 | 
			
		||||
            - controlled by `volume_cropping_epochs`.
 | 
			
		||||
            - at those epochs the implicit function will find the bounding box of an object
 | 
			
		||||
                inside it and crop density and color grids. Cropping of the voxel grids means
 | 
			
		||||
                preserving only voxel values that are inside the bounding box and changing the
 | 
			
		||||
                resolution to match the original, while preserving the new cropped location in
 | 
			
		||||
                world coordinates.
 | 
			
		||||
 | 
			
		||||
    The scaffold has to exist before attempting filtering and cropping, and is created on
 | 
			
		||||
    `scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
 | 
			
		||||
    the point in the center of it evaluates to greater than `scaffold_empty_space_threshold`.
 | 
			
		||||
    3D max pooling is performed on the densities of the points in 3D.
 | 
			
		||||
    Scaffold features are off by default.
 | 
			
		||||
 | 
			
		||||
    Members:
 | 
			
		||||
        voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation
 | 
			
		||||
        voxel_grid_color   (VoxelGridBase): voxel grid to use for color   estimation
 | 
			
		||||
 | 
			
		||||
        harmonic_embedder_xyz_density (HarmonicEmbedder): Function to transform the outputs of
 | 
			
		||||
            the voxel_grid_density
 | 
			
		||||
        harmonic_embedder_xyz_color (HarmonicEmbedder): Function to transform the outputs of
 | 
			
		||||
            the voxel_grid_color for density
 | 
			
		||||
        harmonic_embedder_dir_color (HarmonicEmbedder): Function to transform the outputs of
 | 
			
		||||
            the voxel_grid_color for color
 | 
			
		||||
 | 
			
		||||
        decoder_density (DecoderFunctionBase): decoder function to use for density estimation
 | 
			
		||||
        color_density   (DecoderFunctionBase): decoder function to use for color   estimation
 | 
			
		||||
 | 
			
		||||
        use_multiple_streams (bool): if you want the density and color calculations to run on
 | 
			
		||||
            different cuda streams set this to True. Default True.
 | 
			
		||||
        xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in
 | 
			
		||||
            camera coordinates. Default False.
 | 
			
		||||
 | 
			
		||||
        voxel_grid_scaffold (VoxelGridModule): which holds the scaffold. Extents and
 | 
			
		||||
            translation of it are set to those of voxel_grid_density.
 | 
			
		||||
        scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the
 | 
			
		||||
            scaffold. (The scaffold will be created automatically at the beginning of
 | 
			
		||||
            the calculation.)
 | 
			
		||||
        scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
 | 
			
		||||
            voxel grid which stores scaffold
 | 
			
		||||
        scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
 | 
			
		||||
            this it will be considered as empty space and the scaffold at that point would
 | 
			
		||||
            evaluate as empty space.
 | 
			
		||||
        scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
 | 
			
		||||
            at the same time. To calculate the scaffold we need to query `get_density()` at
 | 
			
		||||
            every voxel, this calculation can be split into scaffold depth number of xy plane
 | 
			
		||||
            calculations if you want the lowest memory usage, one calculation to calculate the
 | 
			
		||||
            whole scaffold, but with higher memory footprint or any other number of planes.
 | 
			
		||||
            Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'.
 | 
			
		||||
        scaffold_max_pool_kernel_size (int): Size of the pooling region to use when
 | 
			
		||||
            calculating the scaffold. Defaults to 3.
 | 
			
		||||
        scaffold_filter_points (bool): If set to True the points will be filtered using
 | 
			
		||||
            `self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density
 | 
			
		||||
            and (0, 0, 0) color. The points which were not evaluated as empty space will be
 | 
			
		||||
            passed through the steps outlined above.
 | 
			
		||||
        volume_cropping_epochs: on which epochs to crop the voxel grids to fit the object's
 | 
			
		||||
            bounding box. Scaffold has to be calculated before cropping.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # ---- voxel grid for density
 | 
			
		||||
    voxel_grid_density: VoxelGridModule
 | 
			
		||||
 | 
			
		||||
    # ---- voxel grid for color
 | 
			
		||||
    voxel_grid_color: VoxelGridModule
 | 
			
		||||
 | 
			
		||||
    # ---- harmonic embeddings density
 | 
			
		||||
    harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field(
 | 
			
		||||
        HarmonicEmbedding
 | 
			
		||||
    )
 | 
			
		||||
    harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field(
 | 
			
		||||
        HarmonicEmbedding
 | 
			
		||||
    )
 | 
			
		||||
    harmonic_embedder_dir_color_args: DictConfig = get_default_args_field(
 | 
			
		||||
        HarmonicEmbedding
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # ---- decoder function for density
 | 
			
		||||
    decoder_density_class_type: str = "MLPDecoder"
 | 
			
		||||
    decoder_density: DecoderFunctionBase
 | 
			
		||||
 | 
			
		||||
    # ---- decoder function for color
 | 
			
		||||
    decoder_color_class_type: str = "MLPDecoder"
 | 
			
		||||
    decoder_color: DecoderFunctionBase
 | 
			
		||||
 | 
			
		||||
    # ---- cuda streams
 | 
			
		||||
    use_multiple_streams: bool = True
 | 
			
		||||
 | 
			
		||||
    # ---- camera
 | 
			
		||||
    xyz_ray_dir_in_camera_coords: bool = False
 | 
			
		||||
 | 
			
		||||
    # --- scaffold
 | 
			
		||||
    # voxel_grid_scaffold: VoxelGridModule
 | 
			
		||||
    scaffold_calculating_epochs: Tuple[int, ...] = ()
 | 
			
		||||
    scaffold_resolution: Tuple[int, int, int] = (128, 128, 128)
 | 
			
		||||
    scaffold_empty_space_threshold: float = 0.001
 | 
			
		||||
    scaffold_occupancy_chunk_size: Union[str, int] = "inf"
 | 
			
		||||
    scaffold_max_pool_kernel_size: int = 3
 | 
			
		||||
    scaffold_filter_points: bool = True
 | 
			
		||||
 | 
			
		||||
    # --- cropping
 | 
			
		||||
    volume_cropping_epochs: Tuple[int, ...] = ()
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.harmonic_embedder_xyz_density = HarmonicEmbedding(
 | 
			
		||||
            **self.harmonic_embedder_xyz_density_args
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.harmonic_embedder_xyz_color = HarmonicEmbedding(
 | 
			
		||||
            **self.harmonic_embedder_xyz_color_args
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.harmonic_embedder_dir_color = HarmonicEmbedding(
 | 
			
		||||
            **self.harmonic_embedder_dir_color_args
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self._scaffold_ready = False
 | 
			
		||||
        if type(self.scaffold_occupancy_chunk_size) != int:
 | 
			
		||||
            if self.scaffold_occupancy_chunk_size != "inf":
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "`scaffold_occupancy_chunk_size` has to be int or 'inf'."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        ray_bundle: ImplicitronRayBundle,
 | 
			
		||||
        fun_viewpool=None,
 | 
			
		||||
        camera: Optional[CamerasBase] = None,
 | 
			
		||||
        global_code=None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
 | 
			
		||||
        """
 | 
			
		||||
        The forward function accepts the parametrizations of 3D points sampled along
 | 
			
		||||
        projection rays. The forward pass is responsible for attaching a 3D vector
 | 
			
		||||
        and a 1D scalar representing the point's RGB color and opacity respectively.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            ray_bundle: An ImplicitronRayBundle object containing the following variables:
 | 
			
		||||
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
 | 
			
		||||
                    origins of the sampling rays in world coords.
 | 
			
		||||
                directions: A tensor of shape `(minibatch, ..., 3)`
 | 
			
		||||
                    containing the direction vectors of sampling rays in world coords.
 | 
			
		||||
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
 | 
			
		||||
                    containing the lengths at which the rays are sampled.
 | 
			
		||||
            fun_viewpool: an optional callback with the signature
 | 
			
		||||
                    fun_fiewpool(points) -> pooled_features
 | 
			
		||||
                where points is a [N_TGT x N x 3] tensor of world coords,
 | 
			
		||||
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
 | 
			
		||||
                of the features pooled from the context images.
 | 
			
		||||
            camera: A camera model which will be used to transform the viewing
 | 
			
		||||
                directions
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
 | 
			
		||||
                denoting the opacitiy of each ray point.
 | 
			
		||||
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
 | 
			
		||||
                denoting the color of each ray point.
 | 
			
		||||
        """
 | 
			
		||||
        # ########## convert the ray parametrizations to world coordinates ########## #
 | 
			
		||||
        # points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3]
 | 
			
		||||
        # pyre-ignore[6]
 | 
			
		||||
        points = ray_bundle_to_ray_points(ray_bundle)
 | 
			
		||||
        directions = ray_bundle.directions.reshape(-1, 3)
 | 
			
		||||
        input_shape = points.shape
 | 
			
		||||
        points = points.view(-1, 3)
 | 
			
		||||
 | 
			
		||||
        # ########## filter the points using the scaffold ########## #
 | 
			
		||||
        if self._scaffold_ready and self.scaffold_filter_points:
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
 | 
			
		||||
            points = points[non_empty_points]
 | 
			
		||||
            directions = directions[non_empty_points]
 | 
			
		||||
            if len(points) == 0:
 | 
			
		||||
                warnings.warn(
 | 
			
		||||
                    "The scaffold has filtered all the points."
 | 
			
		||||
                    "The voxel grids and decoding functions will not be run."
 | 
			
		||||
                )
 | 
			
		||||
                return (
 | 
			
		||||
                    points.new_zeros((*input_shape[:-1], 1)),
 | 
			
		||||
                    points.new_zeros((*input_shape[:-1], 3)),
 | 
			
		||||
                    {},
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # ########## calculate color and density ########## #
 | 
			
		||||
        rays_densities, rays_colors = self.calculate_density_and_color(
 | 
			
		||||
            points, directions, camera
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not (self._scaffold_ready and self.scaffold_filter_points):
 | 
			
		||||
            return (
 | 
			
		||||
                rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])),
 | 
			
		||||
                rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])),
 | 
			
		||||
                {},
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # ########## merge scaffold calculated points ########## #
 | 
			
		||||
        # Create a zeroed tensor corresponding to a point with density=0 and fill it
 | 
			
		||||
        # with calculated density for points which are not in empty space. Do the
 | 
			
		||||
        # same for color
 | 
			
		||||
        rays_densities_combined = rays_densities.new_zeros(
 | 
			
		||||
            (math.prod(input_shape[:-1]), rays_densities.shape[-1])
 | 
			
		||||
        )
 | 
			
		||||
        rays_colors_combined = rays_colors.new_zeros(
 | 
			
		||||
            (math.prod(input_shape[:-1]), rays_colors.shape[-1])
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[61]
 | 
			
		||||
        rays_densities_combined[non_empty_points] = rays_densities
 | 
			
		||||
        # pyre-ignore[61]
 | 
			
		||||
        rays_colors_combined[non_empty_points] = rays_colors
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])),
 | 
			
		||||
            rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])),
 | 
			
		||||
            {},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def calculate_density_and_color(
 | 
			
		||||
        self,
 | 
			
		||||
        points: torch.Tensor,
 | 
			
		||||
        directions: torch.Tensor,
 | 
			
		||||
        camera: Optional[CamerasBase] = None,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates density and color at `points`.
 | 
			
		||||
        If enabled use cuda streams.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            points: points at which to calculate density and color.
 | 
			
		||||
                Tensor of shape [..., 3].
 | 
			
		||||
            directions: from which directions are the points viewed
 | 
			
		||||
                Tensor of shape [..., 3].
 | 
			
		||||
            camera: A camera model which will be used to transform the viewing
 | 
			
		||||
                directions
 | 
			
		||||
        Returns:
 | 
			
		||||
               Tuple of color (tensor of shape [..., 3]) and density
 | 
			
		||||
                (tensor of shape [..., 1])
 | 
			
		||||
        """
 | 
			
		||||
        if self.use_multiple_streams and points.is_cuda:
 | 
			
		||||
            current_stream = torch.cuda.current_stream(points.device)
 | 
			
		||||
            other_stream = torch.cuda.Stream(points.device)
 | 
			
		||||
            other_stream.wait_stream(current_stream)
 | 
			
		||||
 | 
			
		||||
            with torch.cuda.stream(other_stream):
 | 
			
		||||
                # rays_densities.shape =
 | 
			
		||||
                # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
 | 
			
		||||
                rays_densities = self.get_density(points)
 | 
			
		||||
 | 
			
		||||
            # rays_colors.shape =
 | 
			
		||||
            # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
 | 
			
		||||
            rays_colors = self.get_color(points, camera, directions)
 | 
			
		||||
 | 
			
		||||
            current_stream.wait_stream(other_stream)
 | 
			
		||||
        else:
 | 
			
		||||
            # Same calculation as above, just serial.
 | 
			
		||||
            rays_densities = self.get_density(points)
 | 
			
		||||
            rays_colors = self.get_color(points, camera, directions)
 | 
			
		||||
        return rays_densities, rays_colors
 | 
			
		||||
 | 
			
		||||
    def get_density(self, points: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates density at points:
 | 
			
		||||
            1) Evaluates the voxel grid on points
 | 
			
		||||
            2) Embeds the outputs with harmonic embedding
 | 
			
		||||
            3) Passes everything through the Density decoder
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            points: tensor of shape [..., 3]
 | 
			
		||||
                where the last dimension is the points in the (x, y, z)
 | 
			
		||||
        Returns:
 | 
			
		||||
            calculated densities of shape [..., density_dim], `density_dim` is the
 | 
			
		||||
                feature dimensionality which `decoder_density` returns
 | 
			
		||||
        """
 | 
			
		||||
        embeds_density = self.voxel_grid_density(points)
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
 | 
			
		||||
        # shape = [..., density_dim]
 | 
			
		||||
        return self.decoder_density(harmonic_embedding_density)
 | 
			
		||||
 | 
			
		||||
    def get_color(
 | 
			
		||||
        self,
 | 
			
		||||
        points: torch.Tensor,
 | 
			
		||||
        camera: Optional[CamerasBase],
 | 
			
		||||
        directions: torch.Tensor,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates color at points using the viewing direction:
 | 
			
		||||
            1) Transforms the viewing direction with camera
 | 
			
		||||
            2) Evaluates the voxel grid on points
 | 
			
		||||
            3) Embeds the outputs with harmonic embedding
 | 
			
		||||
            4) Embeds the normalized direction with harmonic embedding
 | 
			
		||||
            5) Passes everything through the Color decoder
 | 
			
		||||
        Args:
 | 
			
		||||
            points: tensor of shape (..., 3)
 | 
			
		||||
                where the last dimension is the points in the (x, y, z)
 | 
			
		||||
            camera: A camera model which will be used to transform the viewing
 | 
			
		||||
                directions
 | 
			
		||||
            directions: A tensor of shape `(..., 3)`
 | 
			
		||||
                containing the direction vectors of sampling rays in world coords.
 | 
			
		||||
        """
 | 
			
		||||
        # ########## transform direction ########## #
 | 
			
		||||
        if self.xyz_ray_dir_in_camera_coords:
 | 
			
		||||
            if camera is None:
 | 
			
		||||
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
 | 
			
		||||
            directions = directions @ camera.R
 | 
			
		||||
 | 
			
		||||
        # ########## get voxel grid output ########## #
 | 
			
		||||
        # embeds_color.shape = [..., pts_per_ray, n_features]
 | 
			
		||||
        embeds_color = self.voxel_grid_color(points)
 | 
			
		||||
 | 
			
		||||
        # ########## embed with the harmonic function ########## #
 | 
			
		||||
        # Obtain the harmonic embedding of the voxel grid output.
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
 | 
			
		||||
 | 
			
		||||
        # Normalize the ray_directions to unit l2 norm.
 | 
			
		||||
        rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
 | 
			
		||||
        # Obtain the harmonic embedding of the normalized ray directions.
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        harmonic_embedding_dir = self.harmonic_embedder_dir_color(
 | 
			
		||||
            rays_directions_normed
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        n_rays = directions.shape[0]
 | 
			
		||||
        points_per_ray: int = points.shape[0] // n_rays
 | 
			
		||||
 | 
			
		||||
        harmonic_embedding_dir = torch.repeat_interleave(
 | 
			
		||||
            harmonic_embedding_dir, points_per_ray, dim=0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # total color embedding is concatenation of the harmonic embedding of voxel grid
 | 
			
		||||
        # output and harmonic embedding of the normalized direction
 | 
			
		||||
        total_color_embedding = torch.cat(
 | 
			
		||||
            (harmonic_embedding_color, harmonic_embedding_dir), dim=-1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # ########## evaluate color with the decoding function ########## #
 | 
			
		||||
        # rays_colors.shape = [..., pts_per_ray, 3] in [0-1]
 | 
			
		||||
        return self.decoder_color(total_color_embedding)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def allows_multiple_passes() -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Returns True as this implicit function allows
 | 
			
		||||
        multiple passes. Overridden from ImplicitFunctionBase.
 | 
			
		||||
        """
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
 | 
			
		||||
        """
 | 
			
		||||
        Method which expresses interest in subscribing to optimization epoch updates.
 | 
			
		||||
        This implicit function subscribes to epochs to calculate the scaffold and to
 | 
			
		||||
        crop voxel grids, so this method combines wanted epochs and wraps their callbacks.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of epochs on which to call a callable and callable to be called on
 | 
			
		||||
                particular epoch. The callable returns True if parameter change has
 | 
			
		||||
                happened else False and it must be supplied with one argument, epoch.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def callback(epoch) -> bool:
 | 
			
		||||
            change = False
 | 
			
		||||
            if epoch in self.scaffold_calculating_epochs:
 | 
			
		||||
                change = self._get_scaffold(epoch)
 | 
			
		||||
            if epoch in self.volume_cropping_epochs:
 | 
			
		||||
                change = self._crop(epoch) or change
 | 
			
		||||
            return change
 | 
			
		||||
 | 
			
		||||
        # remove duplicates
 | 
			
		||||
        call_epochs = list(
 | 
			
		||||
            set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
 | 
			
		||||
        )
 | 
			
		||||
        return call_epochs, callback
 | 
			
		||||
 | 
			
		||||
    def _crop(self, epoch: int) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Finds the bounding box of an object represented in the scaffold and crops
 | 
			
		||||
        density and color voxel grids to match that bounding box. If density of the
 | 
			
		||||
        scaffold is 0 everywhere (there is no object in it) no change will
 | 
			
		||||
        happen.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            epoch: ignored
 | 
			
		||||
        Returns:
 | 
			
		||||
            True (indicating that parameter change has happened) if there is
 | 
			
		||||
            an object inside, else False.
 | 
			
		||||
        """
 | 
			
		||||
        # find bounding box
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
 | 
			
		||||
        assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
 | 
			
		||||
        non_zero_idxs = torch.nonzero(occupancy)
 | 
			
		||||
        if len(non_zero_idxs) == 0:
 | 
			
		||||
            return False
 | 
			
		||||
        min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0])
 | 
			
		||||
        max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
 | 
			
		||||
        min_point, max_point = points[min_indices], points[max_indices]
 | 
			
		||||
 | 
			
		||||
        # crop the voxel grids
 | 
			
		||||
        self.voxel_grid_density.crop_self(min_point, max_point)
 | 
			
		||||
        self.voxel_grid_color.crop_self(min_point, max_point)
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def _get_scaffold(self, epoch: int) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Creates a low resolution grid which is used to filter points that are in empty
 | 
			
		||||
        space.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            epoch: epoch on which it is called, ignored inside method
 | 
			
		||||
        Returns:
 | 
			
		||||
             Always False: Modifies `self.voxel_grid_scaffold` member.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        planes = []
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
 | 
			
		||||
 | 
			
		||||
        chunk_size = (
 | 
			
		||||
            self.scaffold_occupancy_chunk_size
 | 
			
		||||
            if type(self.scaffold_occupancy_chunk_size) == int
 | 
			
		||||
            else points.shape[-1]
 | 
			
		||||
        )
 | 
			
		||||
        for k in range(0, points.shape[-1], chunk_size):
 | 
			
		||||
            points_in_planes = points[..., k : k + chunk_size]
 | 
			
		||||
            planes.append(self.get_density(points_in_planes)[..., 0])
 | 
			
		||||
 | 
			
		||||
        density_cube = torch.cat(planes, dim=-1)
 | 
			
		||||
        density_cube = torch.nn.functional.max_pool3d(
 | 
			
		||||
            density_cube[None, None],
 | 
			
		||||
            kernel_size=self.scaffold_max_pool_kernel_size,
 | 
			
		||||
            padding=self.scaffold_max_pool_kernel_size // 2,
 | 
			
		||||
            stride=1,
 | 
			
		||||
        )
 | 
			
		||||
        occupancy_cube = density_cube > self.scaffold_empty_space_threshold
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self._scaffold_ready = True
 | 
			
		||||
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        args.pop("input_dim", None)
 | 
			
		||||
 | 
			
		||||
    def create_decoder_density_impl(self, type, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Decoding functions come after harmonic embedding and voxel grid. In order to not
 | 
			
		||||
        calculate the input dimension of the decoder in the config file this function
 | 
			
		||||
        calculates the required input dimension and sets the input dimension of the
 | 
			
		||||
        decoding function to this value.
 | 
			
		||||
        """
 | 
			
		||||
        grid_args = self.voxel_grid_density_args
 | 
			
		||||
        # pyre-ignore[6]
 | 
			
		||||
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
 | 
			
		||||
 | 
			
		||||
        embedder_args = self.harmonic_embedder_xyz_density_args
 | 
			
		||||
        input_dim = HarmonicEmbedding.get_output_dim_static(
 | 
			
		||||
            grid_output_dim,
 | 
			
		||||
            embedder_args["n_harmonic_functions"],
 | 
			
		||||
            embedder_args["append_input"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        cls = registry.get(DecoderFunctionBase, type)
 | 
			
		||||
        need_input_dim = any(field.name == "input_dim" for field in fields(cls))
 | 
			
		||||
        if need_input_dim:
 | 
			
		||||
            self.decoder_density = cls(input_dim=input_dim, **args)
 | 
			
		||||
        else:
 | 
			
		||||
            self.decoder_density = cls(**args)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        args.pop("input_dim", None)
 | 
			
		||||
 | 
			
		||||
    def create_decoder_color_impl(self, type, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Decoding functions come after harmonic embedding and voxel grid. In order to not
 | 
			
		||||
        calculate the input dimension of the decoder in the config file this function
 | 
			
		||||
        calculates the required input dimension and sets the input dimension of the
 | 
			
		||||
        decoding function to this value.
 | 
			
		||||
        """
 | 
			
		||||
        grid_args = self.voxel_grid_color_args
 | 
			
		||||
        # pyre-ignore[6]
 | 
			
		||||
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
 | 
			
		||||
 | 
			
		||||
        embedder_args = self.harmonic_embedder_xyz_color_args
 | 
			
		||||
        input_dim0 = HarmonicEmbedding.get_output_dim_static(
 | 
			
		||||
            grid_output_dim,
 | 
			
		||||
            embedder_args["n_harmonic_functions"],
 | 
			
		||||
            embedder_args["append_input"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        dir_dim = 3
 | 
			
		||||
        embedder_args = self.harmonic_embedder_dir_color_args
 | 
			
		||||
        input_dim1 = HarmonicEmbedding.get_output_dim_static(
 | 
			
		||||
            dir_dim,
 | 
			
		||||
            embedder_args["n_harmonic_functions"],
 | 
			
		||||
            embedder_args["append_input"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        input_dim = input_dim0 + input_dim1
 | 
			
		||||
 | 
			
		||||
        cls = registry.get(DecoderFunctionBase, type)
 | 
			
		||||
        need_input_dim = any(field.name == "input_dim" for field in fields(cls))
 | 
			
		||||
        if need_input_dim:
 | 
			
		||||
            self.decoder_color = cls(input_dim=input_dim, **args)
 | 
			
		||||
        else:
 | 
			
		||||
            self.decoder_color = cls(**args)
 | 
			
		||||
 | 
			
		||||
    def _create_voxel_grid_scaffold(self) -> VoxelGridModule:
 | 
			
		||||
        """
 | 
			
		||||
        Creates object to become self.voxel_grid_scaffold:
 | 
			
		||||
            -  makes `self.voxel_grid_scaffold` have same world to local mapping as
 | 
			
		||||
                    `self.voxel_grid_density`
 | 
			
		||||
        """
 | 
			
		||||
        return VoxelGridModule(
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            extents=self.voxel_grid_density_args["extents"],
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            translation=self.voxel_grid_density_args["translation"],
 | 
			
		||||
            voxel_grid_class_type="FullResolutionVoxelGrid",
 | 
			
		||||
            hold_voxel_grid_as_parameters=False,
 | 
			
		||||
            voxel_grid_FullResolutionVoxelGrid_args={
 | 
			
		||||
                "resolution_changes": {0: self.scaffold_resolution},
 | 
			
		||||
                "padding": "zeros",
 | 
			
		||||
                "align_corners": True,
 | 
			
		||||
                "mode": "trilinear",
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										227
									
								
								tests/implicitron/test_voxel_grid_implicit_function.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								tests/implicitron/test_voxel_grid_implicit_function.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,227 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig, OmegaConf
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
 | 
			
		||||
    VoxelGridImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 | 
			
		||||
from pytorch3d.renderer import ray_bundle_to_ray_points
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
        expand_args_fields(VoxelGridImplicitFunction)
 | 
			
		||||
 | 
			
		||||
    def _get_simple_implicit_function(self, scaffold_res=16):
 | 
			
		||||
        default_cfg = get_default_args(VoxelGridImplicitFunction)
 | 
			
		||||
        custom_cfg = DictConfig(
 | 
			
		||||
            {
 | 
			
		||||
                "voxel_grid_density_args": {
 | 
			
		||||
                    "voxel_grid_FullResolutionVoxelGrid_args": {"n_features": 7}
 | 
			
		||||
                },
 | 
			
		||||
                "decoder_density_class_type": "ElementwiseDecoder",
 | 
			
		||||
                "decoder_color_class_type": "MLPDecoder",
 | 
			
		||||
                "decoder_color_MLPDecoder_args": {
 | 
			
		||||
                    "network_args": {
 | 
			
		||||
                        "n_layers": 2,
 | 
			
		||||
                        "output_dim": 3,
 | 
			
		||||
                        "hidden_dim": 128,
 | 
			
		||||
                    }
 | 
			
		||||
                },
 | 
			
		||||
                "scaffold_resolution": (scaffold_res, scaffold_res, scaffold_res),
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        cfg = OmegaConf.merge(default_cfg, custom_cfg)
 | 
			
		||||
        return VoxelGridImplicitFunction(**cfg)
 | 
			
		||||
 | 
			
		||||
    def test_forward(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Test one forward of VoxelGridImplicitFunction.
 | 
			
		||||
        """
 | 
			
		||||
        func = self._get_simple_implicit_function()
 | 
			
		||||
 | 
			
		||||
        n_grids, n_points = 10, 9
 | 
			
		||||
        raybundle = ImplicitronRayBundle(
 | 
			
		||||
            origins=torch.randn(n_grids, 2, 3, 3),
 | 
			
		||||
            directions=torch.randn(n_grids, 2, 3, 3),
 | 
			
		||||
            lengths=torch.randn(n_grids, 2, 3, n_points),
 | 
			
		||||
            xys=0,
 | 
			
		||||
        )
 | 
			
		||||
        func(raybundle)
 | 
			
		||||
 | 
			
		||||
    def test_scaffold_formation(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test calculating the scaffold.
 | 
			
		||||
 | 
			
		||||
        We define a custom density function and make the implicit function use it
 | 
			
		||||
        After calculating the scaffold we compare the density of our custom
 | 
			
		||||
        density function with densities from the scaffold.
 | 
			
		||||
        """
 | 
			
		||||
        device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
			
		||||
        func = self._get_simple_implicit_function().to(device)
 | 
			
		||||
        func.scaffold_max_pool_kernel_size = 1
 | 
			
		||||
 | 
			
		||||
        def new_density(points):
 | 
			
		||||
            """
 | 
			
		||||
            Density function which returns 1 if p>(0.5, 0.5, 0.5) or
 | 
			
		||||
            p < (-0.5, -0.5, -0.5) else 0
 | 
			
		||||
            """
 | 
			
		||||
            inshape = points.shape
 | 
			
		||||
            points = points.view(-1, 3)
 | 
			
		||||
            out = []
 | 
			
		||||
            for p in points:
 | 
			
		||||
                if torch.all(p > 0.5) or torch.all(p < -0.5):
 | 
			
		||||
                    out.append(torch.tensor([[1.0]]))
 | 
			
		||||
                else:
 | 
			
		||||
                    out.append(torch.tensor([[0.0]]))
 | 
			
		||||
            return torch.cat(out).view(*inshape[:-1], 1).to(device)
 | 
			
		||||
 | 
			
		||||
        func.get_density = new_density
 | 
			
		||||
        func._get_scaffold(0)
 | 
			
		||||
 | 
			
		||||
        points = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0, 0, 0],
 | 
			
		||||
                [1, 1, 1],
 | 
			
		||||
                [1, 0, 0],
 | 
			
		||||
                [0.1, 0, 0],
 | 
			
		||||
                [10, 1, -1],
 | 
			
		||||
                [-0.8, -0.7, -0.9],
 | 
			
		||||
            ]
 | 
			
		||||
        ).to(device)
 | 
			
		||||
        expected = new_density(points).float().to(device)
 | 
			
		||||
        assert torch.allclose(func.voxel_grid_scaffold(points), expected), (
 | 
			
		||||
            func.voxel_grid_scaffold(points),
 | 
			
		||||
            expected,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_scaffold_filtering(self, n_test_points=100):
 | 
			
		||||
        """
 | 
			
		||||
        Test that filtering points with scaffold works.
 | 
			
		||||
 | 
			
		||||
        We define a scaffold and make the implicit function use it. We also
 | 
			
		||||
        define new density and color functions which check that all passed
 | 
			
		||||
        points are not in empty space (with scaffold function). In the end
 | 
			
		||||
        we compare the result from the implicit function with one calculated
 | 
			
		||||
        simple python, this checks that the points were merged correectly.
 | 
			
		||||
        """
 | 
			
		||||
        device = "cuda"
 | 
			
		||||
        func = self._get_simple_implicit_function().to(device)
 | 
			
		||||
 | 
			
		||||
        def scaffold(points):
 | 
			
		||||
            """'
 | 
			
		||||
            Function to deterministically and randomly enough assign a point
 | 
			
		||||
            to empty or occupied space.
 | 
			
		||||
            Return 1 if second digit of sum after 0 is odd else 0
 | 
			
		||||
            """
 | 
			
		||||
            return (
 | 
			
		||||
                ((points.sum(dim=-1, keepdim=True) * 10**2 % 10).long() % 2) == 1
 | 
			
		||||
            ).float()
 | 
			
		||||
 | 
			
		||||
        def new_density(points):
 | 
			
		||||
            # check if all passed points should be passed here
 | 
			
		||||
            assert torch.all(scaffold(points)), (scaffold(points), points.shape)
 | 
			
		||||
            return points.sum(dim=-1, keepdim=True)
 | 
			
		||||
 | 
			
		||||
        def new_color(points, camera, directions):
 | 
			
		||||
            # check if all passed points should be passed here
 | 
			
		||||
            assert torch.all(scaffold(points))  # , (scaffold(points), points)
 | 
			
		||||
            return points * 2
 | 
			
		||||
 | 
			
		||||
        # check both computation paths that they contain only points
 | 
			
		||||
        # which are not in empty space
 | 
			
		||||
        func.get_density = new_density
 | 
			
		||||
        func.get_color = new_color
 | 
			
		||||
        func.voxel_grid_scaffold.forward = scaffold
 | 
			
		||||
        func._scaffold_ready = True
 | 
			
		||||
 | 
			
		||||
        bundle = ImplicitronRayBundle(
 | 
			
		||||
            origins=torch.rand((n_test_points, 2, 1, 3), device=device),
 | 
			
		||||
            directions=torch.rand((n_test_points, 2, 1, 3), device=device),
 | 
			
		||||
            lengths=torch.rand((n_test_points, 2, 1, 4), device=device),
 | 
			
		||||
            xys=None,
 | 
			
		||||
        )
 | 
			
		||||
        points = ray_bundle_to_ray_points(bundle)
 | 
			
		||||
        result_density, result_color, _ = func(bundle)
 | 
			
		||||
 | 
			
		||||
        # construct the wanted result 'by hand'
 | 
			
		||||
        flat_points = points.view(-1, 3)
 | 
			
		||||
        expected_result_density, expected_result_color = [], []
 | 
			
		||||
        for point in flat_points:
 | 
			
		||||
            if scaffold(point) == 1:
 | 
			
		||||
                expected_result_density.append(point.sum(dim=-1, keepdim=True))
 | 
			
		||||
                expected_result_color.append(point * 2)
 | 
			
		||||
            else:
 | 
			
		||||
                expected_result_density.append(point.new_zeros((1,)))
 | 
			
		||||
                expected_result_color.append(point.new_zeros((3,)))
 | 
			
		||||
        expected_result_density = torch.stack(expected_result_density, dim=0).view(
 | 
			
		||||
            *points.shape[:-1], 1
 | 
			
		||||
        )
 | 
			
		||||
        expected_result_color = torch.stack(expected_result_color, dim=0).view(
 | 
			
		||||
            *points.shape[:-1], 3
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # check that thre result is expected
 | 
			
		||||
        assert torch.allclose(result_density, expected_result_density), (
 | 
			
		||||
            result_density,
 | 
			
		||||
            expected_result_density,
 | 
			
		||||
        )
 | 
			
		||||
        assert torch.allclose(result_color, expected_result_color), (
 | 
			
		||||
            result_color,
 | 
			
		||||
            expected_result_color,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_cropping(self, scaffold_res=9):
 | 
			
		||||
        """
 | 
			
		||||
        Tests whether implicit function finds the bounding box of the object and sends
 | 
			
		||||
        correct min and max points to voxel grids for rescaling.
 | 
			
		||||
        """
 | 
			
		||||
        device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
			
		||||
        func = self._get_simple_implicit_function(scaffold_res=scaffold_res).to(device)
 | 
			
		||||
 | 
			
		||||
        assert scaffold_res >= 8
 | 
			
		||||
        div = (scaffold_res - 1) / 2
 | 
			
		||||
        true_min_point = torch.tensor(
 | 
			
		||||
            [-3 / div, 0 / div, -3 / div],
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        true_max_point = torch.tensor(
 | 
			
		||||
            [1 / div, 2 / div, 3 / div],
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def new_scaffold(points):
 | 
			
		||||
            # 1 if between true_min and true_max point else 0
 | 
			
		||||
            # return points.new_ones((*points.shape[:-1], 1))
 | 
			
		||||
            return (
 | 
			
		||||
                torch.logical_and(true_min_point <= points, points <= true_max_point)
 | 
			
		||||
                .all(dim=-1)
 | 
			
		||||
                .float()[..., None]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        called_crop = []
 | 
			
		||||
 | 
			
		||||
        def assert_min_max_points(min_point, max_point):
 | 
			
		||||
            called_crop.append(1)
 | 
			
		||||
            self.assertClose(min_point, true_min_point)
 | 
			
		||||
            self.assertClose(max_point, true_max_point)
 | 
			
		||||
 | 
			
		||||
        func.voxel_grid_density.crop_self = assert_min_max_points
 | 
			
		||||
        func.voxel_grid_color.crop_self = assert_min_max_points
 | 
			
		||||
        func.voxel_grid_scaffold.forward = new_scaffold
 | 
			
		||||
        func._scaffold_ready = True
 | 
			
		||||
        func._crop(epoch=0)
 | 
			
		||||
        assert len(called_crop) == 2
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user