From 74754bbf172719ee337126ad0758a6efba9538df Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Sun, 23 Oct 2022 05:36:34 -0700 Subject: [PATCH] voxel_grid_implicit_function Reviewed By: shapovalov Differential Revision: D40622304 fbshipit-source-id: 277515a55c46d9b8300058b439526539a7fe00a0 --- .../implicitron_trainer/tests/experiment.yaml | 162 +++++ pytorch3d/implicitron/models/generic_model.py | 3 + .../voxel_grid_implicit_function.py | 616 ++++++++++++++++++ .../test_voxel_grid_implicit_function.py | 227 +++++++ 4 files changed, 1008 insertions(+) create mode 100644 pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py create mode 100644 tests/implicitron/test_voxel_grid_implicit_function.py diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 87236c93..3833a7bc 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -394,6 +394,168 @@ model_factory_ImplicitronModelFactory_args: in_features: 256 out_features: 3 ray_dir_in_camera_coords: false + implicit_function_VoxelGridImplicitFunction_args: + harmonic_embedder_xyz_density_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + harmonic_embedder_xyz_color_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + harmonic_embedder_dir_color_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + decoder_density_class_type: MLPDecoder + decoder_color_class_type: MLPDecoder + use_multiple_streams: true + xyz_ray_dir_in_camera_coords: false + scaffold_calculating_epochs: [] + scaffold_resolution: + - 128 + - 128 + - 128 + scaffold_empty_space_threshold: 0.001 + scaffold_occupancy_chunk_size: 'inf' + scaffold_max_pool_kernel_size: 3 + scaffold_filter_points: true + volume_cropping_epochs: [] + voxel_grid_density_args: + voxel_grid_class_type: FullResolutionVoxelGrid + extents: + - 2.0 + - 2.0 + - 2.0 + translation: + - 0.0 + - 0.0 + - 0.0 + init_std: 0.1 + init_mean: 0.0 + hold_voxel_grid_as_parameters: true + param_groups: {} + voxel_grid_CPFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: 24 + basis_matrix: true + voxel_grid_FullResolutionVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + voxel_grid_VMFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: null + distribution_of_components: null + basis_matrix: true + voxel_grid_color_args: + voxel_grid_class_type: FullResolutionVoxelGrid + extents: + - 2.0 + - 2.0 + - 2.0 + translation: + - 0.0 + - 0.0 + - 0.0 + init_std: 0.1 + init_mean: 0.0 + hold_voxel_grid_as_parameters: true + param_groups: {} + voxel_grid_CPFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: 24 + basis_matrix: true + voxel_grid_FullResolutionVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + voxel_grid_VMFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 
+ n_components: null + distribution_of_components: null + basis_matrix: true + decoder_density_ElementwiseDecoder_args: + scale: 1.0 + shift: 0.0 + operation: IDENTITY + decoder_density_MLPDecoder_args: + param_groups: {} + network_args: + n_layers: 8 + output_dim: 256 + skip_dim: 39 + hidden_dim: 256 + input_skips: + - 5 + skip_affine_trans: false + last_layer_bias_init: null + last_activation: RELU + use_xavier_init: true + decoder_color_ElementwiseDecoder_args: + scale: 1.0 + shift: 0.0 + operation: IDENTITY + decoder_color_MLPDecoder_args: + param_groups: {} + network_args: + n_layers: 8 + output_dim: 256 + skip_dim: 39 + hidden_dim: 256 + input_skips: + - 5 + skip_affine_trans: false + last_layer_bias_init: null + last_activation: RELU + use_xavier_init: true view_metrics_ViewMetrics_args: {} regularization_metrics_RegularizationMetrics_args: {} optimizer_factory_ImplicitronOptimizerFactory_args: diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 3c1715cf..ba3c4a18 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -52,6 +52,9 @@ from .implicit_function.scene_representation_networks import ( # noqa SRNHyperNetImplicitFunction, SRNImplicitFunction, ) +from .implicit_function.voxel_grid_implicit_function import ( # noqa + VoxelGridImplicitFunction, +) from .renderer.base import ( BaseRenderer, diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py new file mode 100644 index 00000000..c42ade91 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py @@ -0,0 +1,616 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings +from dataclasses import fields +from typing import Callable, Dict, Optional, Tuple, Union + +import torch + +from omegaconf import DictConfig + +from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase +from pytorch3d.implicitron.models.implicit_function.decoding_functions import ( + DecoderFunctionBase, +) +from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle +from pytorch3d.implicitron.tools.config import ( + enable_get_default_args, + get_default_args_field, + registry, + run_auto_creation, +) +from pytorch3d.renderer import ray_bundle_to_ray_points +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit import HarmonicEmbedding + +enable_get_default_args(HarmonicEmbedding) + + +@registry.register +# pyre-ignore[13] +class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): + """ + This implicit function consists of two streams, one for the density calculation and one + for the color calculation. Each of these streams has three main parts: + 1) Voxel grids: + They take the (x, y, z) position and return the embedding of that point. + These components are replaceable, you can make your own or choose one of + several options. + 2) Harmonic embeddings: + Convert each feature into series of 'harmonic features', feature is passed through + sine and cosine functions. 
Input is of shape [minibatch, ..., D]; the output is of shape
+ [minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D]. The input is
+ appended by default. If you want it to behave like the identity, set
+ n_harmonic_functions=0 and append_input=True.
+ 3) Decoding functions:
+ The decoder is an instance of the DecoderFunctionBase and converts the embedding
+ of a spatial location to density/color. Examples are Identity, which returns its
+ input, and the MLP, which uses a fully connected neural network to transform the input.
+ These components are replaceable: you can make your own or choose from
+ several options.
+
+ Calculating density is done in three steps:
+ 1) Evaluating the voxel grid on points
+ 2) Embedding the outputs with harmonic embedding
+ 3) Passing through the Density decoder
+
+ Calculating the color needs the embedding and the viewing direction; it has five steps:
+ 1) Transforming the viewing direction with camera
+ 2) Evaluating the voxel grid on points
+ 3) Embedding the outputs with harmonic embedding
+ 4) Embedding the normalized direction with harmonic embedding
+ 5) Passing everything through the Color decoder
+
+ If using the Implicitron configuration system the input_dim to the decoding functions will
+ be set to the output_dim of the Harmonic embeddings.
+
+ A speed-up comes from using the scaffold, a low-resolution voxel grid.
+ The scaffold is referenced as "binary occupancy grid mask" in the TensoRF paper and
+ "AlphaMask" in the official TensoRF implementation.
+ The scaffold is used in:
+ 1) filtering points in empty space
+ - controlled by the `scaffold_filter_points` boolean. If set to True, points which
+ the scaffold predicts to be in empty space will return 0 density and
+ (0, 0, 0) color.
+ 2) calculating the bounding box of an object and cropping the voxel grids
+ - controlled by `volume_cropping_epochs`.
+ - at those epochs the implicit function will find the bounding box of the object
+ inside it and crop the density and color grids. Cropping of the voxel grids means
+ preserving only voxel values that are inside the bounding box and changing the
+ resolution to match the original, while preserving the new cropped location in
+ world coordinates.
+
+ The scaffold has to exist before attempting filtering and cropping, and is created on
+ `scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
+ the point at its center evaluates to greater than `scaffold_empty_space_threshold`.
+ 3D max pooling is performed on the densities of the points.
+ Scaffold features are off by default.
+
+ Members:
+ voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation
+ voxel_grid_color (VoxelGridBase): voxel grid to use for color estimation
+
+ harmonic_embedder_xyz_density (HarmonicEmbedder): Function to transform the outputs of
+ the voxel_grid_density
+ harmonic_embedder_xyz_color (HarmonicEmbedder): Function to transform the outputs of
+ the voxel_grid_color
+ harmonic_embedder_dir_color (HarmonicEmbedder): Function to transform the normalized
+ viewing directions for the color stream
+
+ decoder_density (DecoderFunctionBase): decoder function to use for density estimation
+ decoder_color (DecoderFunctionBase): decoder function to use for color estimation
+
+ use_multiple_streams (bool): if you want the density and color calculations to run on
+ different CUDA streams, set this to True. Default True.
+ xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in
+ camera coordinates. Default False. 
+ + voxel_grid_scaffold (VoxelGridModule): which holds the scaffold. Extents and + translation of it are set to those of voxel_grid_density. + scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the + scaffold. (The scaffold will be created automatically at the beginning of + the calculation.) + scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying + voxel grid which stores scaffold + scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than + this it will be considered as empty space and the scaffold at that point would + evaluate as empty space. + scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate + at the same time. To calculate the scaffold we need to query `get_density()` at + every voxel, this calculation can be split into scaffold depth number of xy plane + calculations if you want the lowest memory usage, one calculation to calculate the + whole scaffold, but with higher memory footprint or any other number of planes. + Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'. + scaffold_max_pool_kernel_size (int): Size of the pooling region to use when + calculating the scaffold. Defaults to 3. + scaffold_filter_points (bool): If set to True the points will be filtered using + `self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density + and (0, 0, 0) color. The points which were not evaluated as empty space will be + passed through the steps outlined above. + volume_cropping_epochs: on which epochs to crop the voxel grids to fit the object's + bounding box. Scaffold has to be calculated before cropping. + """ + + # ---- voxel grid for density + voxel_grid_density: VoxelGridModule + + # ---- voxel grid for color + voxel_grid_color: VoxelGridModule + + # ---- harmonic embeddings density + harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field( + HarmonicEmbedding + ) + harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field( + HarmonicEmbedding + ) + harmonic_embedder_dir_color_args: DictConfig = get_default_args_field( + HarmonicEmbedding + ) + + # ---- decoder function for density + decoder_density_class_type: str = "MLPDecoder" + decoder_density: DecoderFunctionBase + + # ---- decoder function for color + decoder_color_class_type: str = "MLPDecoder" + decoder_color: DecoderFunctionBase + + # ---- cuda streams + use_multiple_streams: bool = True + + # ---- camera + xyz_ray_dir_in_camera_coords: bool = False + + # --- scaffold + # voxel_grid_scaffold: VoxelGridModule + scaffold_calculating_epochs: Tuple[int, ...] = () + scaffold_resolution: Tuple[int, int, int] = (128, 128, 128) + scaffold_empty_space_threshold: float = 0.001 + scaffold_occupancy_chunk_size: Union[str, int] = "inf" + scaffold_max_pool_kernel_size: int = 3 + scaffold_filter_points: bool = True + + # --- cropping + volume_cropping_epochs: Tuple[int, ...] 
= () + + def __post_init__(self) -> None: + super().__init__() + run_auto_creation(self) + # pyre-ignore[16] + self.voxel_grid_scaffold = self._create_voxel_grid_scaffold() + # pyre-ignore[16] + self.harmonic_embedder_xyz_density = HarmonicEmbedding( + **self.harmonic_embedder_xyz_density_args + ) + # pyre-ignore[16] + self.harmonic_embedder_xyz_color = HarmonicEmbedding( + **self.harmonic_embedder_xyz_color_args + ) + # pyre-ignore[16] + self.harmonic_embedder_dir_color = HarmonicEmbedding( + **self.harmonic_embedder_dir_color_args + ) + # pyre-ignore[16] + self._scaffold_ready = False + if type(self.scaffold_occupancy_chunk_size) != int: + if self.scaffold_occupancy_chunk_size != "inf": + raise ValueError( + "`scaffold_occupancy_chunk_size` has to be int or 'inf'." + ) + + def forward( + self, + ray_bundle: ImplicitronRayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: + """ + The forward function accepts the parametrizations of 3D points sampled along + projection rays. The forward pass is responsible for attaching a 3D vector + and a 1D scalar representing the point's RGB color and opacity respectively. + + Args: + ray_bundle: An ImplicitronRayBundle object containing the following variables: + origins: A tensor of shape `(minibatch, ..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(minibatch, ..., 3)` + containing the direction vectors of sampling rays in world coords. + lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + fun_viewpool: an optional callback with the signature + fun_fiewpool(points) -> pooled_features + where points is a [N_TGT x N x 3] tensor of world coords, + and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor + of the features pooled from the context images. + camera: A camera model which will be used to transform the viewing + directions + + Returns: + rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` + denoting the opacitiy of each ray point. + rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` + denoting the color of each ray point. + """ + # ########## convert the ray parametrizations to world coordinates ########## # + # points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3] + # pyre-ignore[6] + points = ray_bundle_to_ray_points(ray_bundle) + directions = ray_bundle.directions.reshape(-1, 3) + input_shape = points.shape + points = points.view(-1, 3) + + # ########## filter the points using the scaffold ########## # + if self._scaffold_ready and self.scaffold_filter_points: + # pyre-ignore[29] + non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0 + points = points[non_empty_points] + directions = directions[non_empty_points] + if len(points) == 0: + warnings.warn( + "The scaffold has filtered all the points." + "The voxel grids and decoding functions will not be run." 
+ ) + return ( + points.new_zeros((*input_shape[:-1], 1)), + points.new_zeros((*input_shape[:-1], 3)), + {}, + ) + + # ########## calculate color and density ########## # + rays_densities, rays_colors = self.calculate_density_and_color( + points, directions, camera + ) + + if not (self._scaffold_ready and self.scaffold_filter_points): + return ( + rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])), + rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])), + {}, + ) + + # ########## merge scaffold calculated points ########## # + # Create a zeroed tensor corresponding to a point with density=0 and fill it + # with calculated density for points which are not in empty space. Do the + # same for color + rays_densities_combined = rays_densities.new_zeros( + (math.prod(input_shape[:-1]), rays_densities.shape[-1]) + ) + rays_colors_combined = rays_colors.new_zeros( + (math.prod(input_shape[:-1]), rays_colors.shape[-1]) + ) + # pyre-ignore[61] + rays_densities_combined[non_empty_points] = rays_densities + # pyre-ignore[61] + rays_colors_combined[non_empty_points] = rays_colors + + return ( + rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])), + rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])), + {}, + ) + + def calculate_density_and_color( + self, + points: torch.Tensor, + directions: torch.Tensor, + camera: Optional[CamerasBase] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculates density and color at `points`. + If enabled use cuda streams. + + Args: + points: points at which to calculate density and color. + Tensor of shape [..., 3]. + directions: from which directions are the points viewed + Tensor of shape [..., 3]. + camera: A camera model which will be used to transform the viewing + directions + Returns: + Tuple of color (tensor of shape [..., 3]) and density + (tensor of shape [..., 1]) + """ + if self.use_multiple_streams and points.is_cuda: + current_stream = torch.cuda.current_stream(points.device) + other_stream = torch.cuda.Stream(points.device) + other_stream.wait_stream(current_stream) + + with torch.cuda.stream(other_stream): + # rays_densities.shape = + # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim] + rays_densities = self.get_density(points) + + # rays_colors.shape = + # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim] + rays_colors = self.get_color(points, camera, directions) + + current_stream.wait_stream(other_stream) + else: + # Same calculation as above, just serial. 
+ rays_densities = self.get_density(points) + rays_colors = self.get_color(points, camera, directions) + return rays_densities, rays_colors + + def get_density(self, points: torch.Tensor) -> torch.Tensor: + """ + Calculates density at points: + 1) Evaluates the voxel grid on points + 2) Embeds the outputs with harmonic embedding + 3) Passes everything through the Density decoder + + Args: + points: tensor of shape [..., 3] + where the last dimension is the points in the (x, y, z) + Returns: + calculated densities of shape [..., density_dim], `density_dim` is the + feature dimensionality which `decoder_density` returns + """ + embeds_density = self.voxel_grid_density(points) + # pyre-ignore[29] + harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density) + # shape = [..., density_dim] + return self.decoder_density(harmonic_embedding_density) + + def get_color( + self, + points: torch.Tensor, + camera: Optional[CamerasBase], + directions: torch.Tensor, + ) -> torch.Tensor: + """ + Calculates color at points using the viewing direction: + 1) Transforms the viewing direction with camera + 2) Evaluates the voxel grid on points + 3) Embeds the outputs with harmonic embedding + 4) Embeds the normalized direction with harmonic embedding + 5) Passes everything through the Color decoder + Args: + points: tensor of shape (..., 3) + where the last dimension is the points in the (x, y, z) + camera: A camera model which will be used to transform the viewing + directions + directions: A tensor of shape `(..., 3)` + containing the direction vectors of sampling rays in world coords. + """ + # ########## transform direction ########## # + if self.xyz_ray_dir_in_camera_coords: + if camera is None: + raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords") + directions = directions @ camera.R + + # ########## get voxel grid output ########## # + # embeds_color.shape = [..., pts_per_ray, n_features] + embeds_color = self.voxel_grid_color(points) + + # ########## embed with the harmonic function ########## # + # Obtain the harmonic embedding of the voxel grid output. + # pyre-ignore[29] + harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color) + + # Normalize the ray_directions to unit l2 norm. + rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1) + # Obtain the harmonic embedding of the normalized ray directions. + # pyre-ignore[29] + harmonic_embedding_dir = self.harmonic_embedder_dir_color( + rays_directions_normed + ) + + n_rays = directions.shape[0] + points_per_ray: int = points.shape[0] // n_rays + + harmonic_embedding_dir = torch.repeat_interleave( + harmonic_embedding_dir, points_per_ray, dim=0 + ) + + # total color embedding is concatenation of the harmonic embedding of voxel grid + # output and harmonic embedding of the normalized direction + total_color_embedding = torch.cat( + (harmonic_embedding_color, harmonic_embedding_dir), dim=-1 + ) + + # ########## evaluate color with the decoding function ########## # + # rays_colors.shape = [..., pts_per_ray, 3] in [0-1] + return self.decoder_color(total_color_embedding) + + @staticmethod + def allows_multiple_passes() -> bool: + """ + Returns True as this implicit function allows + multiple passes. Overridden from ImplicitFunctionBase. + """ + return True + + def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]: + """ + Method which expresses interest in subscribing to optimization epoch updates. 
+ This implicit function subscribes to epochs to calculate the scaffold and to + crop voxel grids, so this method combines wanted epochs and wraps their callbacks. + + Returns: + list of epochs on which to call a callable and callable to be called on + particular epoch. The callable returns True if parameter change has + happened else False and it must be supplied with one argument, epoch. + """ + + def callback(epoch) -> bool: + change = False + if epoch in self.scaffold_calculating_epochs: + change = self._get_scaffold(epoch) + if epoch in self.volume_cropping_epochs: + change = self._crop(epoch) or change + return change + + # remove duplicates + call_epochs = list( + set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs) + ) + return call_epochs, callback + + def _crop(self, epoch: int) -> bool: + """ + Finds the bounding box of an object represented in the scaffold and crops + density and color voxel grids to match that bounding box. If density of the + scaffold is 0 everywhere (there is no object in it) no change will + happen. + + Args: + epoch: ignored + Returns: + True (indicating that parameter change has happened) if there is + an object inside, else False. + """ + # find bounding box + # pyre-ignore[16] + points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch) + assert self._scaffold_ready, "Scaffold has to be calculated before cropping." + # pyre-ignore[29] + occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0 + non_zero_idxs = torch.nonzero(occupancy) + if len(non_zero_idxs) == 0: + return False + min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0]) + max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0]) + min_point, max_point = points[min_indices], points[max_indices] + + # crop the voxel grids + self.voxel_grid_density.crop_self(min_point, max_point) + self.voxel_grid_color.crop_self(min_point, max_point) + return True + + @torch.no_grad() + def _get_scaffold(self, epoch: int) -> bool: + """ + Creates a low resolution grid which is used to filter points that are in empty + space. + + Args: + epoch: epoch on which it is called, ignored inside method + Returns: + Always False: Modifies `self.voxel_grid_scaffold` member. + """ + + planes = [] + # pyre-ignore[16] + points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch) + + chunk_size = ( + self.scaffold_occupancy_chunk_size + if type(self.scaffold_occupancy_chunk_size) == int + else points.shape[-1] + ) + for k in range(0, points.shape[-1], chunk_size): + points_in_planes = points[..., k : k + chunk_size] + planes.append(self.get_density(points_in_planes)[..., 0]) + + density_cube = torch.cat(planes, dim=-1) + density_cube = torch.nn.functional.max_pool3d( + density_cube[None, None], + kernel_size=self.scaffold_max_pool_kernel_size, + padding=self.scaffold_max_pool_kernel_size // 2, + stride=1, + ) + occupancy_cube = density_cube > self.scaffold_empty_space_threshold + # pyre-ignore[16] + self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float() + # pyre-ignore[16] + self._scaffold_ready = True + + return False + + @classmethod + def decoder_density_tweak_args(cls, type, args: DictConfig) -> None: + args.pop("input_dim", None) + + def create_decoder_density_impl(self, type, args: DictConfig) -> None: + """ + Decoding functions come after harmonic embedding and voxel grid. 
In order to not + calculate the input dimension of the decoder in the config file this function + calculates the required input dimension and sets the input dimension of the + decoding function to this value. + """ + grid_args = self.voxel_grid_density_args + # pyre-ignore[6] + grid_output_dim = VoxelGridModule.get_output_dim(grid_args) + + embedder_args = self.harmonic_embedder_xyz_density_args + input_dim = HarmonicEmbedding.get_output_dim_static( + grid_output_dim, + embedder_args["n_harmonic_functions"], + embedder_args["append_input"], + ) + + cls = registry.get(DecoderFunctionBase, type) + need_input_dim = any(field.name == "input_dim" for field in fields(cls)) + if need_input_dim: + self.decoder_density = cls(input_dim=input_dim, **args) + else: + self.decoder_density = cls(**args) + + @classmethod + def decoder_color_tweak_args(cls, type, args: DictConfig) -> None: + args.pop("input_dim", None) + + def create_decoder_color_impl(self, type, args: DictConfig) -> None: + """ + Decoding functions come after harmonic embedding and voxel grid. In order to not + calculate the input dimension of the decoder in the config file this function + calculates the required input dimension and sets the input dimension of the + decoding function to this value. + """ + grid_args = self.voxel_grid_color_args + # pyre-ignore[6] + grid_output_dim = VoxelGridModule.get_output_dim(grid_args) + + embedder_args = self.harmonic_embedder_xyz_color_args + input_dim0 = HarmonicEmbedding.get_output_dim_static( + grid_output_dim, + embedder_args["n_harmonic_functions"], + embedder_args["append_input"], + ) + + dir_dim = 3 + embedder_args = self.harmonic_embedder_dir_color_args + input_dim1 = HarmonicEmbedding.get_output_dim_static( + dir_dim, + embedder_args["n_harmonic_functions"], + embedder_args["append_input"], + ) + + input_dim = input_dim0 + input_dim1 + + cls = registry.get(DecoderFunctionBase, type) + need_input_dim = any(field.name == "input_dim" for field in fields(cls)) + if need_input_dim: + self.decoder_color = cls(input_dim=input_dim, **args) + else: + self.decoder_color = cls(**args) + + def _create_voxel_grid_scaffold(self) -> VoxelGridModule: + """ + Creates object to become self.voxel_grid_scaffold: + - makes `self.voxel_grid_scaffold` have same world to local mapping as + `self.voxel_grid_density` + """ + return VoxelGridModule( + # pyre-ignore[29] + extents=self.voxel_grid_density_args["extents"], + # pyre-ignore[29] + translation=self.voxel_grid_density_args["translation"], + voxel_grid_class_type="FullResolutionVoxelGrid", + hold_voxel_grid_as_parameters=False, + voxel_grid_FullResolutionVoxelGrid_args={ + "resolution_changes": {0: self.scaffold_resolution}, + "padding": "zeros", + "align_corners": True, + "mode": "trilinear", + }, + ) diff --git a/tests/implicitron/test_voxel_grid_implicit_function.py b/tests/implicitron/test_voxel_grid_implicit_function.py new file mode 100644 index 00000000..b5d482c8 --- /dev/null +++ b/tests/implicitron/test_voxel_grid_implicit_function.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import unittest + +import torch + +from omegaconf import DictConfig, OmegaConf +from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( + VoxelGridImplicitFunction, +) +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle + +from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args +from pytorch3d.renderer import ray_bundle_to_ray_points +from tests.common_testing import TestCaseMixin + + +class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(42) + expand_args_fields(VoxelGridImplicitFunction) + + def _get_simple_implicit_function(self, scaffold_res=16): + default_cfg = get_default_args(VoxelGridImplicitFunction) + custom_cfg = DictConfig( + { + "voxel_grid_density_args": { + "voxel_grid_FullResolutionVoxelGrid_args": {"n_features": 7} + }, + "decoder_density_class_type": "ElementwiseDecoder", + "decoder_color_class_type": "MLPDecoder", + "decoder_color_MLPDecoder_args": { + "network_args": { + "n_layers": 2, + "output_dim": 3, + "hidden_dim": 128, + } + }, + "scaffold_resolution": (scaffold_res, scaffold_res, scaffold_res), + } + ) + cfg = OmegaConf.merge(default_cfg, custom_cfg) + return VoxelGridImplicitFunction(**cfg) + + def test_forward(self) -> None: + """ + Test one forward of VoxelGridImplicitFunction. + """ + func = self._get_simple_implicit_function() + + n_grids, n_points = 10, 9 + raybundle = ImplicitronRayBundle( + origins=torch.randn(n_grids, 2, 3, 3), + directions=torch.randn(n_grids, 2, 3, 3), + lengths=torch.randn(n_grids, 2, 3, n_points), + xys=0, + ) + func(raybundle) + + def test_scaffold_formation(self): + """ + Test calculating the scaffold. + + We define a custom density function and make the implicit function use it + After calculating the scaffold we compare the density of our custom + density function with densities from the scaffold. + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + func = self._get_simple_implicit_function().to(device) + func.scaffold_max_pool_kernel_size = 1 + + def new_density(points): + """ + Density function which returns 1 if p>(0.5, 0.5, 0.5) or + p < (-0.5, -0.5, -0.5) else 0 + """ + inshape = points.shape + points = points.view(-1, 3) + out = [] + for p in points: + if torch.all(p > 0.5) or torch.all(p < -0.5): + out.append(torch.tensor([[1.0]])) + else: + out.append(torch.tensor([[0.0]])) + return torch.cat(out).view(*inshape[:-1], 1).to(device) + + func.get_density = new_density + func._get_scaffold(0) + + points = torch.tensor( + [ + [0, 0, 0], + [1, 1, 1], + [1, 0, 0], + [0.1, 0, 0], + [10, 1, -1], + [-0.8, -0.7, -0.9], + ] + ).to(device) + expected = new_density(points).float().to(device) + assert torch.allclose(func.voxel_grid_scaffold(points), expected), ( + func.voxel_grid_scaffold(points), + expected, + ) + + def test_scaffold_filtering(self, n_test_points=100): + """ + Test that filtering points with scaffold works. + + We define a scaffold and make the implicit function use it. We also + define new density and color functions which check that all passed + points are not in empty space (with scaffold function). In the end + we compare the result from the implicit function with one calculated + simple python, this checks that the points were merged correectly. 
+ """ + device = "cuda" + func = self._get_simple_implicit_function().to(device) + + def scaffold(points): + """' + Function to deterministically and randomly enough assign a point + to empty or occupied space. + Return 1 if second digit of sum after 0 is odd else 0 + """ + return ( + ((points.sum(dim=-1, keepdim=True) * 10**2 % 10).long() % 2) == 1 + ).float() + + def new_density(points): + # check if all passed points should be passed here + assert torch.all(scaffold(points)), (scaffold(points), points.shape) + return points.sum(dim=-1, keepdim=True) + + def new_color(points, camera, directions): + # check if all passed points should be passed here + assert torch.all(scaffold(points)) # , (scaffold(points), points) + return points * 2 + + # check both computation paths that they contain only points + # which are not in empty space + func.get_density = new_density + func.get_color = new_color + func.voxel_grid_scaffold.forward = scaffold + func._scaffold_ready = True + + bundle = ImplicitronRayBundle( + origins=torch.rand((n_test_points, 2, 1, 3), device=device), + directions=torch.rand((n_test_points, 2, 1, 3), device=device), + lengths=torch.rand((n_test_points, 2, 1, 4), device=device), + xys=None, + ) + points = ray_bundle_to_ray_points(bundle) + result_density, result_color, _ = func(bundle) + + # construct the wanted result 'by hand' + flat_points = points.view(-1, 3) + expected_result_density, expected_result_color = [], [] + for point in flat_points: + if scaffold(point) == 1: + expected_result_density.append(point.sum(dim=-1, keepdim=True)) + expected_result_color.append(point * 2) + else: + expected_result_density.append(point.new_zeros((1,))) + expected_result_color.append(point.new_zeros((3,))) + expected_result_density = torch.stack(expected_result_density, dim=0).view( + *points.shape[:-1], 1 + ) + expected_result_color = torch.stack(expected_result_color, dim=0).view( + *points.shape[:-1], 3 + ) + + # check that thre result is expected + assert torch.allclose(result_density, expected_result_density), ( + result_density, + expected_result_density, + ) + assert torch.allclose(result_color, expected_result_color), ( + result_color, + expected_result_color, + ) + + def test_cropping(self, scaffold_res=9): + """ + Tests whether implicit function finds the bounding box of the object and sends + correct min and max points to voxel grids for rescaling. + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + func = self._get_simple_implicit_function(scaffold_res=scaffold_res).to(device) + + assert scaffold_res >= 8 + div = (scaffold_res - 1) / 2 + true_min_point = torch.tensor( + [-3 / div, 0 / div, -3 / div], + device=device, + ) + true_max_point = torch.tensor( + [1 / div, 2 / div, 3 / div], + device=device, + ) + + def new_scaffold(points): + # 1 if between true_min and true_max point else 0 + # return points.new_ones((*points.shape[:-1], 1)) + return ( + torch.logical_and(true_min_point <= points, points <= true_max_point) + .all(dim=-1) + .float()[..., None] + ) + + called_crop = [] + + def assert_min_max_points(min_point, max_point): + called_crop.append(1) + self.assertClose(min_point, true_min_point) + self.assertClose(max_point, true_max_point) + + func.voxel_grid_density.crop_self = assert_min_max_points + func.voxel_grid_color.crop_self = assert_min_max_points + func.voxel_grid_scaffold.forward = new_scaffold + func._scaffold_ready = True + func._crop(epoch=0) + assert len(called_crop) == 2