voxel_grid_implicit_function

Reviewed By: shapovalov

Differential Revision: D40622304

fbshipit-source-id: 277515a55c46d9b8300058b439526539a7fe00a0
This commit is contained in:
Jeremy Reizenstein 2022-10-23 05:36:34 -07:00 committed by Facebook GitHub Bot
parent 611aba9a20
commit 74754bbf17
4 changed files with 1008 additions and 0 deletions

View File

@ -394,6 +394,168 @@ model_factory_ImplicitronModelFactory_args:
in_features: 256 in_features: 256
out_features: 3 out_features: 3
ray_dir_in_camera_coords: false ray_dir_in_camera_coords: false
implicit_function_VoxelGridImplicitFunction_args:
harmonic_embedder_xyz_density_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
harmonic_embedder_xyz_color_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
harmonic_embedder_dir_color_args:
n_harmonic_functions: 6
omega_0: 1.0
logspace: true
append_input: true
decoder_density_class_type: MLPDecoder
decoder_color_class_type: MLPDecoder
use_multiple_streams: true
xyz_ray_dir_in_camera_coords: false
scaffold_calculating_epochs: []
scaffold_resolution:
- 128
- 128
- 128
scaffold_empty_space_threshold: 0.001
scaffold_occupancy_chunk_size: 'inf'
scaffold_max_pool_kernel_size: 3
scaffold_filter_points: true
volume_cropping_epochs: []
voxel_grid_density_args:
voxel_grid_class_type: FullResolutionVoxelGrid
extents:
- 2.0
- 2.0
- 2.0
translation:
- 0.0
- 0.0
- 0.0
init_std: 0.1
init_mean: 0.0
hold_voxel_grid_as_parameters: true
param_groups: {}
voxel_grid_CPFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: 24
basis_matrix: true
voxel_grid_FullResolutionVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
voxel_grid_VMFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: null
distribution_of_components: null
basis_matrix: true
voxel_grid_color_args:
voxel_grid_class_type: FullResolutionVoxelGrid
extents:
- 2.0
- 2.0
- 2.0
translation:
- 0.0
- 0.0
- 0.0
init_std: 0.1
init_mean: 0.0
hold_voxel_grid_as_parameters: true
param_groups: {}
voxel_grid_CPFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: 24
basis_matrix: true
voxel_grid_FullResolutionVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
voxel_grid_VMFactorizedVoxelGrid_args:
align_corners: true
padding: zeros
mode: bilinear
n_features: 1
resolution_changes:
0:
- 128
- 128
- 128
n_components: null
distribution_of_components: null
basis_matrix: true
decoder_density_ElementwiseDecoder_args:
scale: 1.0
shift: 0.0
operation: IDENTITY
decoder_density_MLPDecoder_args:
param_groups: {}
network_args:
n_layers: 8
output_dim: 256
skip_dim: 39
hidden_dim: 256
input_skips:
- 5
skip_affine_trans: false
last_layer_bias_init: null
last_activation: RELU
use_xavier_init: true
decoder_color_ElementwiseDecoder_args:
scale: 1.0
shift: 0.0
operation: IDENTITY
decoder_color_MLPDecoder_args:
param_groups: {}
network_args:
n_layers: 8
output_dim: 256
skip_dim: 39
hidden_dim: 256
input_skips:
- 5
skip_affine_trans: false
last_layer_bias_init: null
last_activation: RELU
use_xavier_init: true
view_metrics_ViewMetrics_args: {} view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {} regularization_metrics_RegularizationMetrics_args: {}
optimizer_factory_ImplicitronOptimizerFactory_args: optimizer_factory_ImplicitronOptimizerFactory_args:

View File

@ -52,6 +52,9 @@ from .implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction, SRNHyperNetImplicitFunction,
SRNImplicitFunction, SRNImplicitFunction,
) )
from .implicit_function.voxel_grid_implicit_function import ( # noqa
VoxelGridImplicitFunction,
)
from .renderer.base import ( from .renderer.base import (
BaseRenderer, BaseRenderer,

View File

@ -0,0 +1,616 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from dataclasses import fields
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
DecoderFunctionBase,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import (
enable_get_default_args,
get_default_args_field,
registry,
run_auto_creation,
)
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
enable_get_default_args(HarmonicEmbedding)
@registry.register
# pyre-ignore[13]
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
"""
This implicit function consists of two streams, one for the density calculation and one
for the color calculation. Each of these streams has three main parts:
1) Voxel grids:
They take the (x, y, z) position and return the embedding of that point.
These components are replaceable, you can make your own or choose one of
several options.
2) Harmonic embeddings:
Convert each feature into series of 'harmonic features', feature is passed through
sine and cosine functions. Input is of shape [minibatch, ..., D] output
[minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D]. Appends
input by default. If you want it to behave like identity, put n_harmonic_functions=0
and append_input=True.
3) Decoding functions:
The decoder is an instance of the DecoderFunctionBase and converts the embedding
of a spatial location to density/color. Examples are Identity which returns its
input and the MLP which uses fully connected nerual network to transform the input.
These components are replaceable, you can make your own or choose from
several options.
Calculating density is done in three steps:
1) Evaluating the voxel grid on points
2) Embedding the outputs with harmonic embedding
3) Passing through the Density decoder
To calculate the color we need the embedding and the viewing direction, it has five steps:
1) Transforming the viewing direction with camera
2) Evaluating the voxel grid on points
3) Embedding the outputs with harmonic embedding
4) Embedding the normalized direction with harmonic embedding
5) Passing everything through the Color decoder
If using the Implicitron configuration system the input_dim to the decoding functions will
be set to the output_dim of the Harmonic embeddings.
A speed up comes from using the scaffold, a low resolution voxel grid.
The scaffold is referenced as "binary occupancy grid mask" in TensoRF paper and "AlphaMask"
in official TensoRF implementation.
The scaffold is used in:
1) filtering points in empty space
- controlled by `scaffold_filter_points` boolean. If set to True, points for which
scaffold predicts that are in empty space will return 0 density and
(0, 0, 0) color.
2) calculating the bounding box of an object and cropping the voxel grids
- controlled by `volume_cropping_epochs`.
- at those epochs the implicit function will find the bounding box of an object
inside it and crop density and color grids. Cropping of the voxel grids means
preserving only voxel values that are inside the bounding box and changing the
resolution to match the original, while preserving the new cropped location in
world coordinates.
The scaffold has to exist before attempting filtering and cropping, and is created on
`scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
the point in the center of it evaluates to greater than `scaffold_empty_space_threshold`.
3D max pooling is performed on the densities of the points in 3D.
Scaffold features are off by default.
Members:
voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation
voxel_grid_color (VoxelGridBase): voxel grid to use for color estimation
harmonic_embedder_xyz_density (HarmonicEmbedder): Function to transform the outputs of
the voxel_grid_density
harmonic_embedder_xyz_color (HarmonicEmbedder): Function to transform the outputs of
the voxel_grid_color for density
harmonic_embedder_dir_color (HarmonicEmbedder): Function to transform the outputs of
the voxel_grid_color for color
decoder_density (DecoderFunctionBase): decoder function to use for density estimation
color_density (DecoderFunctionBase): decoder function to use for color estimation
use_multiple_streams (bool): if you want the density and color calculations to run on
different cuda streams set this to True. Default True.
xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in
camera coordinates. Default False.
voxel_grid_scaffold (VoxelGridModule): which holds the scaffold. Extents and
translation of it are set to those of voxel_grid_density.
scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the
scaffold. (The scaffold will be created automatically at the beginning of
the calculation.)
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
voxel grid which stores scaffold
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
this it will be considered as empty space and the scaffold at that point would
evaluate as empty space.
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
at the same time. To calculate the scaffold we need to query `get_density()` at
every voxel, this calculation can be split into scaffold depth number of xy plane
calculations if you want the lowest memory usage, one calculation to calculate the
whole scaffold, but with higher memory footprint or any other number of planes.
Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'.
scaffold_max_pool_kernel_size (int): Size of the pooling region to use when
calculating the scaffold. Defaults to 3.
scaffold_filter_points (bool): If set to True the points will be filtered using
`self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density
and (0, 0, 0) color. The points which were not evaluated as empty space will be
passed through the steps outlined above.
volume_cropping_epochs: on which epochs to crop the voxel grids to fit the object's
bounding box. Scaffold has to be calculated before cropping.
"""
# ---- voxel grid for density
voxel_grid_density: VoxelGridModule
# ---- voxel grid for color
voxel_grid_color: VoxelGridModule
# ---- harmonic embeddings density
harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field(
HarmonicEmbedding
)
harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field(
HarmonicEmbedding
)
harmonic_embedder_dir_color_args: DictConfig = get_default_args_field(
HarmonicEmbedding
)
# ---- decoder function for density
decoder_density_class_type: str = "MLPDecoder"
decoder_density: DecoderFunctionBase
# ---- decoder function for color
decoder_color_class_type: str = "MLPDecoder"
decoder_color: DecoderFunctionBase
# ---- cuda streams
use_multiple_streams: bool = True
# ---- camera
xyz_ray_dir_in_camera_coords: bool = False
# --- scaffold
# voxel_grid_scaffold: VoxelGridModule
scaffold_calculating_epochs: Tuple[int, ...] = ()
scaffold_resolution: Tuple[int, int, int] = (128, 128, 128)
scaffold_empty_space_threshold: float = 0.001
scaffold_occupancy_chunk_size: Union[str, int] = "inf"
scaffold_max_pool_kernel_size: int = 3
scaffold_filter_points: bool = True
# --- cropping
volume_cropping_epochs: Tuple[int, ...] = ()
def __post_init__(self) -> None:
super().__init__()
run_auto_creation(self)
# pyre-ignore[16]
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
# pyre-ignore[16]
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
**self.harmonic_embedder_xyz_density_args
)
# pyre-ignore[16]
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
**self.harmonic_embedder_xyz_color_args
)
# pyre-ignore[16]
self.harmonic_embedder_dir_color = HarmonicEmbedding(
**self.harmonic_embedder_dir_color_args
)
# pyre-ignore[16]
self._scaffold_ready = False
if type(self.scaffold_occupancy_chunk_size) != int:
if self.scaffold_occupancy_chunk_size != "inf":
raise ValueError(
"`scaffold_occupancy_chunk_size` has to be int or 'inf'."
)
def forward(
self,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
"""
The forward function accepts the parametrizations of 3D points sampled along
projection rays. The forward pass is responsible for attaching a 3D vector
and a 1D scalar representing the point's RGB color and opacity respectively.
Args:
ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
fun_viewpool: an optional callback with the signature
fun_fiewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
of the features pooled from the context images.
camera: A camera model which will be used to transform the viewing
directions
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacitiy of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
"""
# ########## convert the ray parametrizations to world coordinates ########## #
# points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3]
# pyre-ignore[6]
points = ray_bundle_to_ray_points(ray_bundle)
directions = ray_bundle.directions.reshape(-1, 3)
input_shape = points.shape
points = points.view(-1, 3)
# ########## filter the points using the scaffold ########## #
if self._scaffold_ready and self.scaffold_filter_points:
# pyre-ignore[29]
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
points = points[non_empty_points]
directions = directions[non_empty_points]
if len(points) == 0:
warnings.warn(
"The scaffold has filtered all the points."
"The voxel grids and decoding functions will not be run."
)
return (
points.new_zeros((*input_shape[:-1], 1)),
points.new_zeros((*input_shape[:-1], 3)),
{},
)
# ########## calculate color and density ########## #
rays_densities, rays_colors = self.calculate_density_and_color(
points, directions, camera
)
if not (self._scaffold_ready and self.scaffold_filter_points):
return (
rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])),
rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])),
{},
)
# ########## merge scaffold calculated points ########## #
# Create a zeroed tensor corresponding to a point with density=0 and fill it
# with calculated density for points which are not in empty space. Do the
# same for color
rays_densities_combined = rays_densities.new_zeros(
(math.prod(input_shape[:-1]), rays_densities.shape[-1])
)
rays_colors_combined = rays_colors.new_zeros(
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
)
# pyre-ignore[61]
rays_densities_combined[non_empty_points] = rays_densities
# pyre-ignore[61]
rays_colors_combined[non_empty_points] = rays_colors
return (
rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])),
rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])),
{},
)
def calculate_density_and_color(
self,
points: torch.Tensor,
directions: torch.Tensor,
camera: Optional[CamerasBase] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculates density and color at `points`.
If enabled use cuda streams.
Args:
points: points at which to calculate density and color.
Tensor of shape [..., 3].
directions: from which directions are the points viewed
Tensor of shape [..., 3].
camera: A camera model which will be used to transform the viewing
directions
Returns:
Tuple of color (tensor of shape [..., 3]) and density
(tensor of shape [..., 1])
"""
if self.use_multiple_streams and points.is_cuda:
current_stream = torch.cuda.current_stream(points.device)
other_stream = torch.cuda.Stream(points.device)
other_stream.wait_stream(current_stream)
with torch.cuda.stream(other_stream):
# rays_densities.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
rays_densities = self.get_density(points)
# rays_colors.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
rays_colors = self.get_color(points, camera, directions)
current_stream.wait_stream(other_stream)
else:
# Same calculation as above, just serial.
rays_densities = self.get_density(points)
rays_colors = self.get_color(points, camera, directions)
return rays_densities, rays_colors
def get_density(self, points: torch.Tensor) -> torch.Tensor:
"""
Calculates density at points:
1) Evaluates the voxel grid on points
2) Embeds the outputs with harmonic embedding
3) Passes everything through the Density decoder
Args:
points: tensor of shape [..., 3]
where the last dimension is the points in the (x, y, z)
Returns:
calculated densities of shape [..., density_dim], `density_dim` is the
feature dimensionality which `decoder_density` returns
"""
embeds_density = self.voxel_grid_density(points)
# pyre-ignore[29]
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
# shape = [..., density_dim]
return self.decoder_density(harmonic_embedding_density)
def get_color(
self,
points: torch.Tensor,
camera: Optional[CamerasBase],
directions: torch.Tensor,
) -> torch.Tensor:
"""
Calculates color at points using the viewing direction:
1) Transforms the viewing direction with camera
2) Evaluates the voxel grid on points
3) Embeds the outputs with harmonic embedding
4) Embeds the normalized direction with harmonic embedding
5) Passes everything through the Color decoder
Args:
points: tensor of shape (..., 3)
where the last dimension is the points in the (x, y, z)
camera: A camera model which will be used to transform the viewing
directions
directions: A tensor of shape `(..., 3)`
containing the direction vectors of sampling rays in world coords.
"""
# ########## transform direction ########## #
if self.xyz_ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
directions = directions @ camera.R
# ########## get voxel grid output ########## #
# embeds_color.shape = [..., pts_per_ray, n_features]
embeds_color = self.voxel_grid_color(points)
# ########## embed with the harmonic function ########## #
# Obtain the harmonic embedding of the voxel grid output.
# pyre-ignore[29]
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-ignore[29]
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
rays_directions_normed
)
n_rays = directions.shape[0]
points_per_ray: int = points.shape[0] // n_rays
harmonic_embedding_dir = torch.repeat_interleave(
harmonic_embedding_dir, points_per_ray, dim=0
)
# total color embedding is concatenation of the harmonic embedding of voxel grid
# output and harmonic embedding of the normalized direction
total_color_embedding = torch.cat(
(harmonic_embedding_color, harmonic_embedding_dir), dim=-1
)
# ########## evaluate color with the decoding function ########## #
# rays_colors.shape = [..., pts_per_ray, 3] in [0-1]
return self.decoder_color(total_color_embedding)
@staticmethod
def allows_multiple_passes() -> bool:
"""
Returns True as this implicit function allows
multiple passes. Overridden from ImplicitFunctionBase.
"""
return True
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
"""
Method which expresses interest in subscribing to optimization epoch updates.
This implicit function subscribes to epochs to calculate the scaffold and to
crop voxel grids, so this method combines wanted epochs and wraps their callbacks.
Returns:
list of epochs on which to call a callable and callable to be called on
particular epoch. The callable returns True if parameter change has
happened else False and it must be supplied with one argument, epoch.
"""
def callback(epoch) -> bool:
change = False
if epoch in self.scaffold_calculating_epochs:
change = self._get_scaffold(epoch)
if epoch in self.volume_cropping_epochs:
change = self._crop(epoch) or change
return change
# remove duplicates
call_epochs = list(
set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
)
return call_epochs, callback
def _crop(self, epoch: int) -> bool:
"""
Finds the bounding box of an object represented in the scaffold and crops
density and color voxel grids to match that bounding box. If density of the
scaffold is 0 everywhere (there is no object in it) no change will
happen.
Args:
epoch: ignored
Returns:
True (indicating that parameter change has happened) if there is
an object inside, else False.
"""
# find bounding box
# pyre-ignore[16]
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
# pyre-ignore[29]
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
non_zero_idxs = torch.nonzero(occupancy)
if len(non_zero_idxs) == 0:
return False
min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0])
max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
min_point, max_point = points[min_indices], points[max_indices]
# crop the voxel grids
self.voxel_grid_density.crop_self(min_point, max_point)
self.voxel_grid_color.crop_self(min_point, max_point)
return True
@torch.no_grad()
def _get_scaffold(self, epoch: int) -> bool:
"""
Creates a low resolution grid which is used to filter points that are in empty
space.
Args:
epoch: epoch on which it is called, ignored inside method
Returns:
Always False: Modifies `self.voxel_grid_scaffold` member.
"""
planes = []
# pyre-ignore[16]
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
chunk_size = (
self.scaffold_occupancy_chunk_size
if type(self.scaffold_occupancy_chunk_size) == int
else points.shape[-1]
)
for k in range(0, points.shape[-1], chunk_size):
points_in_planes = points[..., k : k + chunk_size]
planes.append(self.get_density(points_in_planes)[..., 0])
density_cube = torch.cat(planes, dim=-1)
density_cube = torch.nn.functional.max_pool3d(
density_cube[None, None],
kernel_size=self.scaffold_max_pool_kernel_size,
padding=self.scaffold_max_pool_kernel_size // 2,
stride=1,
)
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
# pyre-ignore[16]
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
# pyre-ignore[16]
self._scaffold_ready = True
return False
@classmethod
def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
args.pop("input_dim", None)
def create_decoder_density_impl(self, type, args: DictConfig) -> None:
"""
Decoding functions come after harmonic embedding and voxel grid. In order to not
calculate the input dimension of the decoder in the config file this function
calculates the required input dimension and sets the input dimension of the
decoding function to this value.
"""
grid_args = self.voxel_grid_density_args
# pyre-ignore[6]
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
embedder_args = self.harmonic_embedder_xyz_density_args
input_dim = HarmonicEmbedding.get_output_dim_static(
grid_output_dim,
embedder_args["n_harmonic_functions"],
embedder_args["append_input"],
)
cls = registry.get(DecoderFunctionBase, type)
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
if need_input_dim:
self.decoder_density = cls(input_dim=input_dim, **args)
else:
self.decoder_density = cls(**args)
@classmethod
def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
args.pop("input_dim", None)
def create_decoder_color_impl(self, type, args: DictConfig) -> None:
"""
Decoding functions come after harmonic embedding and voxel grid. In order to not
calculate the input dimension of the decoder in the config file this function
calculates the required input dimension and sets the input dimension of the
decoding function to this value.
"""
grid_args = self.voxel_grid_color_args
# pyre-ignore[6]
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
embedder_args = self.harmonic_embedder_xyz_color_args
input_dim0 = HarmonicEmbedding.get_output_dim_static(
grid_output_dim,
embedder_args["n_harmonic_functions"],
embedder_args["append_input"],
)
dir_dim = 3
embedder_args = self.harmonic_embedder_dir_color_args
input_dim1 = HarmonicEmbedding.get_output_dim_static(
dir_dim,
embedder_args["n_harmonic_functions"],
embedder_args["append_input"],
)
input_dim = input_dim0 + input_dim1
cls = registry.get(DecoderFunctionBase, type)
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
if need_input_dim:
self.decoder_color = cls(input_dim=input_dim, **args)
else:
self.decoder_color = cls(**args)
def _create_voxel_grid_scaffold(self) -> VoxelGridModule:
"""
Creates object to become self.voxel_grid_scaffold:
- makes `self.voxel_grid_scaffold` have same world to local mapping as
`self.voxel_grid_density`
"""
return VoxelGridModule(
# pyre-ignore[29]
extents=self.voxel_grid_density_args["extents"],
# pyre-ignore[29]
translation=self.voxel_grid_density_args["translation"],
voxel_grid_class_type="FullResolutionVoxelGrid",
hold_voxel_grid_as_parameters=False,
voxel_grid_FullResolutionVoxelGrid_args={
"resolution_changes": {0: self.scaffold_resolution},
"padding": "zeros",
"align_corners": True,
"mode": "trilinear",
},
)

View File

@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from pytorch3d.renderer import ray_bundle_to_ray_points
from tests.common_testing import TestCaseMixin
class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
expand_args_fields(VoxelGridImplicitFunction)
def _get_simple_implicit_function(self, scaffold_res=16):
default_cfg = get_default_args(VoxelGridImplicitFunction)
custom_cfg = DictConfig(
{
"voxel_grid_density_args": {
"voxel_grid_FullResolutionVoxelGrid_args": {"n_features": 7}
},
"decoder_density_class_type": "ElementwiseDecoder",
"decoder_color_class_type": "MLPDecoder",
"decoder_color_MLPDecoder_args": {
"network_args": {
"n_layers": 2,
"output_dim": 3,
"hidden_dim": 128,
}
},
"scaffold_resolution": (scaffold_res, scaffold_res, scaffold_res),
}
)
cfg = OmegaConf.merge(default_cfg, custom_cfg)
return VoxelGridImplicitFunction(**cfg)
def test_forward(self) -> None:
"""
Test one forward of VoxelGridImplicitFunction.
"""
func = self._get_simple_implicit_function()
n_grids, n_points = 10, 9
raybundle = ImplicitronRayBundle(
origins=torch.randn(n_grids, 2, 3, 3),
directions=torch.randn(n_grids, 2, 3, 3),
lengths=torch.randn(n_grids, 2, 3, n_points),
xys=0,
)
func(raybundle)
def test_scaffold_formation(self):
"""
Test calculating the scaffold.
We define a custom density function and make the implicit function use it
After calculating the scaffold we compare the density of our custom
density function with densities from the scaffold.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
func = self._get_simple_implicit_function().to(device)
func.scaffold_max_pool_kernel_size = 1
def new_density(points):
"""
Density function which returns 1 if p>(0.5, 0.5, 0.5) or
p < (-0.5, -0.5, -0.5) else 0
"""
inshape = points.shape
points = points.view(-1, 3)
out = []
for p in points:
if torch.all(p > 0.5) or torch.all(p < -0.5):
out.append(torch.tensor([[1.0]]))
else:
out.append(torch.tensor([[0.0]]))
return torch.cat(out).view(*inshape[:-1], 1).to(device)
func.get_density = new_density
func._get_scaffold(0)
points = torch.tensor(
[
[0, 0, 0],
[1, 1, 1],
[1, 0, 0],
[0.1, 0, 0],
[10, 1, -1],
[-0.8, -0.7, -0.9],
]
).to(device)
expected = new_density(points).float().to(device)
assert torch.allclose(func.voxel_grid_scaffold(points), expected), (
func.voxel_grid_scaffold(points),
expected,
)
def test_scaffold_filtering(self, n_test_points=100):
"""
Test that filtering points with scaffold works.
We define a scaffold and make the implicit function use it. We also
define new density and color functions which check that all passed
points are not in empty space (with scaffold function). In the end
we compare the result from the implicit function with one calculated
simple python, this checks that the points were merged correectly.
"""
device = "cuda"
func = self._get_simple_implicit_function().to(device)
def scaffold(points):
"""'
Function to deterministically and randomly enough assign a point
to empty or occupied space.
Return 1 if second digit of sum after 0 is odd else 0
"""
return (
((points.sum(dim=-1, keepdim=True) * 10**2 % 10).long() % 2) == 1
).float()
def new_density(points):
# check if all passed points should be passed here
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
return points.sum(dim=-1, keepdim=True)
def new_color(points, camera, directions):
# check if all passed points should be passed here
assert torch.all(scaffold(points)) # , (scaffold(points), points)
return points * 2
# check both computation paths that they contain only points
# which are not in empty space
func.get_density = new_density
func.get_color = new_color
func.voxel_grid_scaffold.forward = scaffold
func._scaffold_ready = True
bundle = ImplicitronRayBundle(
origins=torch.rand((n_test_points, 2, 1, 3), device=device),
directions=torch.rand((n_test_points, 2, 1, 3), device=device),
lengths=torch.rand((n_test_points, 2, 1, 4), device=device),
xys=None,
)
points = ray_bundle_to_ray_points(bundle)
result_density, result_color, _ = func(bundle)
# construct the wanted result 'by hand'
flat_points = points.view(-1, 3)
expected_result_density, expected_result_color = [], []
for point in flat_points:
if scaffold(point) == 1:
expected_result_density.append(point.sum(dim=-1, keepdim=True))
expected_result_color.append(point * 2)
else:
expected_result_density.append(point.new_zeros((1,)))
expected_result_color.append(point.new_zeros((3,)))
expected_result_density = torch.stack(expected_result_density, dim=0).view(
*points.shape[:-1], 1
)
expected_result_color = torch.stack(expected_result_color, dim=0).view(
*points.shape[:-1], 3
)
# check that thre result is expected
assert torch.allclose(result_density, expected_result_density), (
result_density,
expected_result_density,
)
assert torch.allclose(result_color, expected_result_color), (
result_color,
expected_result_color,
)
def test_cropping(self, scaffold_res=9):
"""
Tests whether implicit function finds the bounding box of the object and sends
correct min and max points to voxel grids for rescaling.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
func = self._get_simple_implicit_function(scaffold_res=scaffold_res).to(device)
assert scaffold_res >= 8
div = (scaffold_res - 1) / 2
true_min_point = torch.tensor(
[-3 / div, 0 / div, -3 / div],
device=device,
)
true_max_point = torch.tensor(
[1 / div, 2 / div, 3 / div],
device=device,
)
def new_scaffold(points):
# 1 if between true_min and true_max point else 0
# return points.new_ones((*points.shape[:-1], 1))
return (
torch.logical_and(true_min_point <= points, points <= true_max_point)
.all(dim=-1)
.float()[..., None]
)
called_crop = []
def assert_min_max_points(min_point, max_point):
called_crop.append(1)
self.assertClose(min_point, true_min_point)
self.assertClose(max_point, true_max_point)
func.voxel_grid_density.crop_self = assert_min_max_points
func.voxel_grid_color.crop_self = assert_min_max_points
func.voxel_grid_scaffold.forward = new_scaffold
func._scaffold_ready = True
func._crop(epoch=0)
assert len(called_crop) == 2