commit 74754bbf17 (parent 611aba9a20)

    voxel_grid_implicit_function

    Reviewed By: shapovalov

    Differential Revision: D40622304

    fbshipit-source-id: 277515a55c46d9b8300058b439526539a7fe00a0
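For orientation, the sketch below shows how the new VoxelGridImplicitFunction can be instantiated through the Implicitron config system and queried on a ray bundle. It is distilled from tests/implicitron/test_voxel_grid_implicit_function.py at the bottom of this diff and is not itself part of the commit; the decoder overrides mirror the ones the tests use.

import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
    VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args

expand_args_fields(VoxelGridImplicitFunction)
# start from the generated defaults (first hunk below) and shrink the decoders,
# exactly as the new unit tests do
cfg = OmegaConf.merge(
    get_default_args(VoxelGridImplicitFunction),
    DictConfig(
        {
            "decoder_density_class_type": "ElementwiseDecoder",
            "decoder_color_MLPDecoder_args": {
                "network_args": {"n_layers": 2, "output_dim": 3, "hidden_dim": 128}
            },
        }
    ),
)
func = VoxelGridImplicitFunction(**cfg)

# 10 batches of 2 x 3 rays with 9 points per ray, as in test_forward below
bundle = ImplicitronRayBundle(
    origins=torch.randn(10, 2, 3, 3),
    directions=torch.randn(10, 2, 3, 3),
    lengths=torch.randn(10, 2, 3, 9),
    xys=0,
)
densities, colors, _ = func(bundle)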
@@ -394,6 +394,168 @@ model_factory_ImplicitronModelFactory_args:
        in_features: 256
        out_features: 3
        ray_dir_in_camera_coords: false
    implicit_function_VoxelGridImplicitFunction_args:
      harmonic_embedder_xyz_density_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      harmonic_embedder_xyz_color_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      harmonic_embedder_dir_color_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      decoder_density_class_type: MLPDecoder
      decoder_color_class_type: MLPDecoder
      use_multiple_streams: true
      xyz_ray_dir_in_camera_coords: false
      scaffold_calculating_epochs: []
      scaffold_resolution:
      - 128
      - 128
      - 128
      scaffold_empty_space_threshold: 0.001
      scaffold_occupancy_chunk_size: 'inf'
      scaffold_max_pool_kernel_size: 3
      scaffold_filter_points: true
      volume_cropping_epochs: []
      voxel_grid_density_args:
        voxel_grid_class_type: FullResolutionVoxelGrid
        extents:
        - 2.0
        - 2.0
        - 2.0
        translation:
        - 0.0
        - 0.0
        - 0.0
        init_std: 0.1
        init_mean: 0.0
        hold_voxel_grid_as_parameters: true
        param_groups: {}
        voxel_grid_CPFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: 24
          basis_matrix: true
        voxel_grid_FullResolutionVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
        voxel_grid_VMFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: null
          distribution_of_components: null
          basis_matrix: true
      voxel_grid_color_args:
        voxel_grid_class_type: FullResolutionVoxelGrid
        extents:
        - 2.0
        - 2.0
        - 2.0
        translation:
        - 0.0
        - 0.0
        - 0.0
        init_std: 0.1
        init_mean: 0.0
        hold_voxel_grid_as_parameters: true
        param_groups: {}
        voxel_grid_CPFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: 24
          basis_matrix: true
        voxel_grid_FullResolutionVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
        voxel_grid_VMFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: null
          distribution_of_components: null
          basis_matrix: true
      decoder_density_ElementwiseDecoder_args:
        scale: 1.0
        shift: 0.0
        operation: IDENTITY
      decoder_density_MLPDecoder_args:
        param_groups: {}
        network_args:
          n_layers: 8
          output_dim: 256
          skip_dim: 39
          hidden_dim: 256
          input_skips:
          - 5
          skip_affine_trans: false
          last_layer_bias_init: null
          last_activation: RELU
          use_xavier_init: true
      decoder_color_ElementwiseDecoder_args:
        scale: 1.0
        shift: 0.0
        operation: IDENTITY
      decoder_color_MLPDecoder_args:
        param_groups: {}
        network_args:
          n_layers: 8
          output_dim: 256
          skip_dim: 39
          hidden_dim: 256
          input_skips:
          - 5
          skip_affine_trans: false
          last_layer_bias_init: null
          last_activation: RELU
          use_xavier_init: true
    view_metrics_ViewMetrics_args: {}
    regularization_metrics_RegularizationMetrics_args: {}
optimizer_factory_ImplicitronOptimizerFactory_args:
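A note on the defaults above: the decoder input sizes never appear in the config because create_decoder_density_impl / create_decoder_color_impl in the new file below derive them from the harmonic embedding size, (n_harmonic_functions * 2 + int(append_input)) * D. The standalone sketch below reproduces that arithmetic for the values shown here; the helper name is illustrative and not part of the commit.

def harmonic_output_dim(input_dim: int, n_harmonic_functions: int, append_input: bool) -> int:
    # every input channel is mapped to sin/cos at n_harmonic_functions frequencies,
    # optionally with the raw input appended
    return (2 * n_harmonic_functions + int(append_input)) * input_dim

n_features = 1   # voxel_grid_*_args -> n_features
n_harmonics = 6  # harmonic_embedder_*_args -> n_harmonic_functions
density_in = harmonic_output_dim(n_features, n_harmonics, True)    # 13
color_in = density_in + harmonic_output_dim(3, n_harmonics, True)  # 13 + 39 = 52
print(density_in, color_in)  # input_dim handed to the density / color MLPDecoder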
@@ -52,6 +52,9 @@ from .implicit_function.scene_representation_networks import (  # noqa
    SRNHyperNetImplicitFunction,
    SRNImplicitFunction,
)
from .implicit_function.voxel_grid_implicit_function import (  # noqa
    VoxelGridImplicitFunction,
)

from .renderer.base import (
    BaseRenderer,
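The `# noqa` import above exists for its side effect: importing the module runs the @registry.register decorator on VoxelGridImplicitFunction, which makes the class selectable by name from configs. A minimal sketch of that lookup, assuming the registry API used elsewhere in this diff (registry.get(BaseClass, "ClassName")):

from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.tools.config import registry

# importing the module triggers registration
import pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function  # noqa

cls = registry.get(ImplicitFunctionBase, "VoxelGridImplicitFunction")
assert cls.__name__ == "VoxelGridImplicitFunction"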
@@ -0,0 +1,616 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings
from dataclasses import fields
from typing import Callable, Dict, Optional, Tuple, Union

import torch

from omegaconf import DictConfig

from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    DecoderFunctionBase,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import (
    enable_get_default_args,
    get_default_args_field,
    registry,
    run_auto_creation,
)
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

enable_get_default_args(HarmonicEmbedding)


@registry.register
# pyre-ignore[13]
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    """
    This implicit function consists of two streams, one for the density calculation and one
    for the color calculation. Each of these streams has three main parts:
        1) Voxel grids:
            They take the (x, y, z) position and return the embedding of that point.
            These components are replaceable, you can make your own or choose one of
            several options.
        2) Harmonic embeddings:
            Convert each feature into a series of 'harmonic features': the feature is passed
            through sine and cosine functions. Input is of shape [minibatch, ..., D], output
            [minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D]. Appends the
            input by default. If you want it to behave like identity, put n_harmonic_functions=0
            and append_input=True.
        3) Decoding functions:
            The decoder is an instance of DecoderFunctionBase and converts the embedding
            of a spatial location to density/color. Examples are Identity, which returns its
            input, and the MLP, which uses a fully connected neural network to transform the
            input. These components are replaceable, you can make your own or choose from
            several options.

    Calculating density is done in three steps:
        1) Evaluating the voxel grid on points
        2) Embedding the outputs with harmonic embedding
        3) Passing through the Density decoder

    To calculate the color we need the embedding and the viewing direction; it has five steps:
        1) Transforming the viewing direction with camera
        2) Evaluating the voxel grid on points
        3) Embedding the outputs with harmonic embedding
        4) Embedding the normalized direction with harmonic embedding
        5) Passing everything through the Color decoder

    If using the Implicitron configuration system the input_dim to the decoding functions will
    be set to the output_dim of the Harmonic embeddings.

    A speed-up comes from using the scaffold, a low resolution voxel grid.
    The scaffold is referenced as "binary occupancy grid mask" in the TensoRF paper and as
    "AlphaMask" in the official TensoRF implementation.
    The scaffold is used in:
        1) filtering points in empty space
            - controlled by the `scaffold_filter_points` boolean. If set to True, points
                which the scaffold predicts are in empty space will return 0 density and
                (0, 0, 0) color.
        2) calculating the bounding box of an object and cropping the voxel grids
            - controlled by `volume_cropping_epochs`.
            - at those epochs the implicit function will find the bounding box of the object
                inside it and crop the density and color grids. Cropping the voxel grids means
                preserving only voxel values that are inside the bounding box and changing the
                resolution to match the original, while preserving the new cropped location in
                world coordinates.

    The scaffold has to exist before attempting filtering and cropping, and is created on
    `scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
    the point in the center of it evaluates to greater than `scaffold_empty_space_threshold`.
    3D max pooling is performed on the densities of those points.
    Scaffold features are off by default.

    Members:
        voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation
        voxel_grid_color (VoxelGridBase): voxel grid to use for color estimation

        harmonic_embedder_xyz_density (HarmonicEmbedder): Function to transform the outputs of
            the voxel_grid_density
        harmonic_embedder_xyz_color (HarmonicEmbedder): Function to transform the outputs of
            the voxel_grid_color for density
        harmonic_embedder_dir_color (HarmonicEmbedder): Function to transform the outputs of
            the voxel_grid_color for color

        decoder_density (DecoderFunctionBase): decoder function to use for density estimation
        decoder_color (DecoderFunctionBase): decoder function to use for color estimation

        use_multiple_streams (bool): if you want the density and color calculations to run on
            different cuda streams set this to True. Default True.
        xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in
            camera coordinates. Default False.

        voxel_grid_scaffold (VoxelGridModule): the module which holds the scaffold. Extents and
            translation of it are set to those of voxel_grid_density.
        scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the
            scaffold. (The scaffold will be created automatically at the beginning of
            the calculation.)
        scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
            voxel grid which stores the scaffold
        scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
            this it will be considered as empty space and the scaffold at that point would
            evaluate as empty space.
        scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
            at the same time. To calculate the scaffold we need to query `get_density()` at
            every voxel; this calculation can be split into scaffold-depth many xy-plane
            calculations if you want the lowest memory usage, one calculation for the
            whole scaffold (with a higher memory footprint), or any other number of planes.
            Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'.
        scaffold_max_pool_kernel_size (int): Size of the pooling region to use when
            calculating the scaffold. Defaults to 3.
        scaffold_filter_points (bool): If set to True the points will be filtered using
            `self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density
            and (0, 0, 0) color. The points which were not evaluated as empty space will be
            passed through the steps outlined above.
        volume_cropping_epochs: on which epochs to crop the voxel grids to fit the object's
            bounding box. Scaffold has to be calculated before cropping.
    """

    # ---- voxel grid for density
    voxel_grid_density: VoxelGridModule

    # ---- voxel grid for color
    voxel_grid_color: VoxelGridModule

    # ---- harmonic embeddings density
    harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field(
        HarmonicEmbedding
    )
    harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field(
        HarmonicEmbedding
    )
    harmonic_embedder_dir_color_args: DictConfig = get_default_args_field(
        HarmonicEmbedding
    )

    # ---- decoder function for density
    decoder_density_class_type: str = "MLPDecoder"
    decoder_density: DecoderFunctionBase

    # ---- decoder function for color
    decoder_color_class_type: str = "MLPDecoder"
    decoder_color: DecoderFunctionBase

    # ---- cuda streams
    use_multiple_streams: bool = True

    # ---- camera
    xyz_ray_dir_in_camera_coords: bool = False

    # --- scaffold
    # voxel_grid_scaffold: VoxelGridModule
    scaffold_calculating_epochs: Tuple[int, ...] = ()
    scaffold_resolution: Tuple[int, int, int] = (128, 128, 128)
    scaffold_empty_space_threshold: float = 0.001
    scaffold_occupancy_chunk_size: Union[str, int] = "inf"
    scaffold_max_pool_kernel_size: int = 3
    scaffold_filter_points: bool = True

    # --- cropping
    volume_cropping_epochs: Tuple[int, ...] = ()

    def __post_init__(self) -> None:
        super().__init__()
        run_auto_creation(self)
        # pyre-ignore[16]
        self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
        # pyre-ignore[16]
        self.harmonic_embedder_xyz_density = HarmonicEmbedding(
            **self.harmonic_embedder_xyz_density_args
        )
        # pyre-ignore[16]
        self.harmonic_embedder_xyz_color = HarmonicEmbedding(
            **self.harmonic_embedder_xyz_color_args
        )
        # pyre-ignore[16]
        self.harmonic_embedder_dir_color = HarmonicEmbedding(
            **self.harmonic_embedder_dir_color_args
        )
        # pyre-ignore[16]
        self._scaffold_ready = False
        if type(self.scaffold_occupancy_chunk_size) != int:
            if self.scaffold_occupancy_chunk_size != "inf":
                raise ValueError(
                    "`scaffold_occupancy_chunk_size` has to be int or 'inf'."
                )

    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        The forward function accepts the parametrizations of 3D points sampled along
        projection rays. The forward pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's RGB color and opacity respectively.

        Args:
            ray_bundle: An ImplicitronRayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            fun_viewpool: an optional callback with the signature
                    fun_viewpool(points) -> pooled_features
                where points is a [N_TGT x N x 3] tensor of world coords,
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
                of the features pooled from the context images.
            camera: A camera model which will be used to transform the viewing
                directions

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # ########## convert the ray parametrizations to world coordinates ########## #
        # points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3]
        # pyre-ignore[6]
        points = ray_bundle_to_ray_points(ray_bundle)
        directions = ray_bundle.directions.reshape(-1, 3)
        input_shape = points.shape
        points = points.view(-1, 3)

        # ########## filter the points using the scaffold ########## #
        if self._scaffold_ready and self.scaffold_filter_points:
            # pyre-ignore[29]
            non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
            points = points[non_empty_points]
            directions = directions[non_empty_points]
            if len(points) == 0:
                warnings.warn(
                    "The scaffold has filtered all the points."
                    "The voxel grids and decoding functions will not be run."
                )
                return (
                    points.new_zeros((*input_shape[:-1], 1)),
                    points.new_zeros((*input_shape[:-1], 3)),
                    {},
                )

        # ########## calculate color and density ########## #
        rays_densities, rays_colors = self.calculate_density_and_color(
            points, directions, camera
        )

        if not (self._scaffold_ready and self.scaffold_filter_points):
            return (
                rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])),
                rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])),
                {},
            )

        # ########## merge scaffold calculated points ########## #
        # Create a zeroed tensor corresponding to a point with density=0 and fill it
        # with calculated density for points which are not in empty space. Do the
        # same for color
        rays_densities_combined = rays_densities.new_zeros(
            (math.prod(input_shape[:-1]), rays_densities.shape[-1])
        )
        rays_colors_combined = rays_colors.new_zeros(
            (math.prod(input_shape[:-1]), rays_colors.shape[-1])
        )
        # pyre-ignore[61]
        rays_densities_combined[non_empty_points] = rays_densities
        # pyre-ignore[61]
        rays_colors_combined[non_empty_points] = rays_colors

        return (
            rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])),
            rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])),
            {},
        )

    def calculate_density_and_color(
        self,
        points: torch.Tensor,
        directions: torch.Tensor,
        camera: Optional[CamerasBase] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates density and color at `points`.
        If enabled, uses cuda streams.

        Args:
            points: points at which to calculate density and color.
                Tensor of shape [..., 3].
            directions: from which directions are the points viewed.
                Tensor of shape [..., 3].
            camera: A camera model which will be used to transform the viewing
                directions
        Returns:
            Tuple of color (tensor of shape [..., 3]) and density
                (tensor of shape [..., 1])
        """
        if self.use_multiple_streams and points.is_cuda:
            current_stream = torch.cuda.current_stream(points.device)
            other_stream = torch.cuda.Stream(points.device)
            other_stream.wait_stream(current_stream)

            with torch.cuda.stream(other_stream):
                # rays_densities.shape =
                # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
                rays_densities = self.get_density(points)

            # rays_colors.shape =
            # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
            rays_colors = self.get_color(points, camera, directions)

            current_stream.wait_stream(other_stream)
        else:
            # Same calculation as above, just serial.
            rays_densities = self.get_density(points)
            rays_colors = self.get_color(points, camera, directions)
        return rays_densities, rays_colors

    def get_density(self, points: torch.Tensor) -> torch.Tensor:
        """
        Calculates density at points:
            1) Evaluates the voxel grid on points
            2) Embeds the outputs with harmonic embedding
            3) Passes everything through the Density decoder

        Args:
            points: tensor of shape [..., 3]
                where the last dimension is the point in (x, y, z)
        Returns:
            calculated densities of shape [..., density_dim], `density_dim` is the
                feature dimensionality which `decoder_density` returns
        """
        embeds_density = self.voxel_grid_density(points)
        # pyre-ignore[29]
        harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
        # shape = [..., density_dim]
        return self.decoder_density(harmonic_embedding_density)

    def get_color(
        self,
        points: torch.Tensor,
        camera: Optional[CamerasBase],
        directions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculates color at points using the viewing direction:
            1) Transforms the viewing direction with camera
            2) Evaluates the voxel grid on points
            3) Embeds the outputs with harmonic embedding
            4) Embeds the normalized direction with harmonic embedding
            5) Passes everything through the Color decoder
        Args:
            points: tensor of shape (..., 3)
                where the last dimension is the point in (x, y, z)
            camera: A camera model which will be used to transform the viewing
                directions
            directions: A tensor of shape `(..., 3)`
                containing the direction vectors of sampling rays in world coords.
        """
        # ########## transform direction ########## #
        if self.xyz_ray_dir_in_camera_coords:
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
            directions = directions @ camera.R

        # ########## get voxel grid output ########## #
        # embeds_color.shape = [..., pts_per_ray, n_features]
        embeds_color = self.voxel_grid_color(points)

        # ########## embed with the harmonic function ########## #
        # Obtain the harmonic embedding of the voxel grid output.
        # pyre-ignore[29]
        harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)

        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
        # pyre-ignore[29]
        harmonic_embedding_dir = self.harmonic_embedder_dir_color(
            rays_directions_normed
        )

        n_rays = directions.shape[0]
        points_per_ray: int = points.shape[0] // n_rays

        harmonic_embedding_dir = torch.repeat_interleave(
            harmonic_embedding_dir, points_per_ray, dim=0
        )

        # total color embedding is the concatenation of the harmonic embedding of the voxel
        # grid output and the harmonic embedding of the normalized direction
        total_color_embedding = torch.cat(
            (harmonic_embedding_color, harmonic_embedding_dir), dim=-1
        )

        # ########## evaluate color with the decoding function ########## #
        # rays_colors.shape = [..., pts_per_ray, 3] in [0-1]
        return self.decoder_color(total_color_embedding)

    @staticmethod
    def allows_multiple_passes() -> bool:
        """
        Returns True as this implicit function allows
        multiple passes. Overridden from ImplicitFunctionBase.
        """
        return True

    def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
        """
        Method which expresses interest in subscribing to optimization epoch updates.
        This implicit function subscribes to epochs to calculate the scaffold and to
        crop voxel grids, so this method combines the wanted epochs and wraps their callbacks.

        Returns:
            list of epochs on which to call a callable and the callable to be called on a
                particular epoch. The callable returns True if a parameter change has
                happened else False and it must be supplied with one argument, the epoch.
        """

        def callback(epoch) -> bool:
            change = False
            if epoch in self.scaffold_calculating_epochs:
                change = self._get_scaffold(epoch)
            if epoch in self.volume_cropping_epochs:
                change = self._crop(epoch) or change
            return change

        # remove duplicates
        call_epochs = list(
            set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
        )
        return call_epochs, callback

    def _crop(self, epoch: int) -> bool:
        """
        Finds the bounding box of the object represented in the scaffold and crops the
        density and color voxel grids to match that bounding box. If the density of the
        scaffold is 0 everywhere (there is no object in it) no change will
        happen.

        Args:
            epoch: ignored
        Returns:
            True (indicating that a parameter change has happened) if there is
                an object inside, else False.
        """
        # find bounding box
        # pyre-ignore[16]
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
        assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
        # pyre-ignore[29]
        occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
        non_zero_idxs = torch.nonzero(occupancy)
        if len(non_zero_idxs) == 0:
            return False
        min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0])
        max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
        min_point, max_point = points[min_indices], points[max_indices]

        # crop the voxel grids
        self.voxel_grid_density.crop_self(min_point, max_point)
        self.voxel_grid_color.crop_self(min_point, max_point)
        return True

    @torch.no_grad()
    def _get_scaffold(self, epoch: int) -> bool:
        """
        Creates a low resolution grid which is used to filter points that are in empty
        space.

        Args:
            epoch: epoch on which it is called, ignored inside the method
        Returns:
            Always False: Modifies the `self.voxel_grid_scaffold` member.
        """

        planes = []
        # pyre-ignore[16]
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)

        chunk_size = (
            self.scaffold_occupancy_chunk_size
            if type(self.scaffold_occupancy_chunk_size) == int
            else points.shape[-1]
        )
        for k in range(0, points.shape[-1], chunk_size):
            points_in_planes = points[..., k : k + chunk_size]
            planes.append(self.get_density(points_in_planes)[..., 0])

        density_cube = torch.cat(planes, dim=-1)
        density_cube = torch.nn.functional.max_pool3d(
            density_cube[None, None],
            kernel_size=self.scaffold_max_pool_kernel_size,
            padding=self.scaffold_max_pool_kernel_size // 2,
            stride=1,
        )
        occupancy_cube = density_cube > self.scaffold_empty_space_threshold
        # pyre-ignore[16]
        self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
        # pyre-ignore[16]
        self._scaffold_ready = True

        return False

    @classmethod
    def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
        args.pop("input_dim", None)

    def create_decoder_density_impl(self, type, args: DictConfig) -> None:
        """
        Decoding functions come after the harmonic embedding and the voxel grid. In order to
        not have to calculate the input dimension of the decoder in the config file, this
        function calculates the required input dimension and sets the input dimension of the
        decoding function to this value.
        """
        grid_args = self.voxel_grid_density_args
        # pyre-ignore[6]
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)

        embedder_args = self.harmonic_embedder_xyz_density_args
        input_dim = HarmonicEmbedding.get_output_dim_static(
            grid_output_dim,
            embedder_args["n_harmonic_functions"],
            embedder_args["append_input"],
        )

        cls = registry.get(DecoderFunctionBase, type)
        need_input_dim = any(field.name == "input_dim" for field in fields(cls))
        if need_input_dim:
            self.decoder_density = cls(input_dim=input_dim, **args)
        else:
            self.decoder_density = cls(**args)

    @classmethod
    def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
        args.pop("input_dim", None)

    def create_decoder_color_impl(self, type, args: DictConfig) -> None:
        """
        Decoding functions come after the harmonic embedding and the voxel grid. In order to
        not have to calculate the input dimension of the decoder in the config file, this
        function calculates the required input dimension and sets the input dimension of the
        decoding function to this value.
        """
        grid_args = self.voxel_grid_color_args
        # pyre-ignore[6]
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)

        embedder_args = self.harmonic_embedder_xyz_color_args
        input_dim0 = HarmonicEmbedding.get_output_dim_static(
            grid_output_dim,
            embedder_args["n_harmonic_functions"],
            embedder_args["append_input"],
        )

        dir_dim = 3
        embedder_args = self.harmonic_embedder_dir_color_args
        input_dim1 = HarmonicEmbedding.get_output_dim_static(
            dir_dim,
            embedder_args["n_harmonic_functions"],
            embedder_args["append_input"],
        )

        input_dim = input_dim0 + input_dim1

        cls = registry.get(DecoderFunctionBase, type)
        need_input_dim = any(field.name == "input_dim" for field in fields(cls))
        if need_input_dim:
            self.decoder_color = cls(input_dim=input_dim, **args)
        else:
            self.decoder_color = cls(**args)

    def _create_voxel_grid_scaffold(self) -> VoxelGridModule:
        """
        Creates the object to become self.voxel_grid_scaffold:
            - makes `self.voxel_grid_scaffold` have the same world to local mapping as
                `self.voxel_grid_density`
        """
        return VoxelGridModule(
            # pyre-ignore[29]
            extents=self.voxel_grid_density_args["extents"],
            # pyre-ignore[29]
            translation=self.voxel_grid_density_args["translation"],
            voxel_grid_class_type="FullResolutionVoxelGrid",
            hold_voxel_grid_as_parameters=False,
            voxel_grid_FullResolutionVoxelGrid_args={
                "resolution_changes": {0: self.scaffold_resolution},
                "padding": "zeros",
                "align_corners": True,
                "mode": "trilinear",
            },
        )
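To make the scaffold construction in _get_scaffold above concrete: it evaluates get_density on a dense grid of points (optionally one chunk of xy planes at a time), max-pools the resulting density cube, and thresholds it into a binary occupancy grid. The standalone sketch below follows that recipe with a toy density standing in for get_density; the names and the toy density are illustrative only, not part of the commit.

import torch

def toy_density(points: torch.Tensor) -> torch.Tensor:
    # stand-in for VoxelGridImplicitFunction.get_density: a unit ball of density 1
    return (points.norm(dim=-1, keepdim=True) < 1.0).float()

res, threshold, kernel = 32, 0.001, 3  # scaffold resolution / empty-space threshold / pool size
axis = torch.linspace(-1, 1, res)
points = torch.stack(torch.meshgrid(axis, axis, axis, indexing="ij"), dim=-1)

# evaluate one xy plane at a time, mimicking a small scaffold_occupancy_chunk_size
planes = [toy_density(points[..., k, :])[..., 0] for k in range(res)]
density_cube = torch.stack(planes, dim=-1)

# 3D max pool with stride 1 and "same" padding, then threshold into a binary scaffold
density_cube = torch.nn.functional.max_pool3d(
    density_cube[None, None], kernel_size=kernel, padding=kernel // 2, stride=1
)
occupancy = (density_cube > threshold).float()[0, 0]
print(occupancy.mean())  # fraction of voxels the scaffold would keep as occupied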
tests/implicitron/test_voxel_grid_implicit_function.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch

from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
    VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from pytorch3d.renderer import ray_bundle_to_ray_points
from tests.common_testing import TestCaseMixin


class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)
        expand_args_fields(VoxelGridImplicitFunction)

    def _get_simple_implicit_function(self, scaffold_res=16):
        default_cfg = get_default_args(VoxelGridImplicitFunction)
        custom_cfg = DictConfig(
            {
                "voxel_grid_density_args": {
                    "voxel_grid_FullResolutionVoxelGrid_args": {"n_features": 7}
                },
                "decoder_density_class_type": "ElementwiseDecoder",
                "decoder_color_class_type": "MLPDecoder",
                "decoder_color_MLPDecoder_args": {
                    "network_args": {
                        "n_layers": 2,
                        "output_dim": 3,
                        "hidden_dim": 128,
                    }
                },
                "scaffold_resolution": (scaffold_res, scaffold_res, scaffold_res),
            }
        )
        cfg = OmegaConf.merge(default_cfg, custom_cfg)
        return VoxelGridImplicitFunction(**cfg)

    def test_forward(self) -> None:
        """
        Test one forward pass of VoxelGridImplicitFunction.
        """
        func = self._get_simple_implicit_function()

        n_grids, n_points = 10, 9
        raybundle = ImplicitronRayBundle(
            origins=torch.randn(n_grids, 2, 3, 3),
            directions=torch.randn(n_grids, 2, 3, 3),
            lengths=torch.randn(n_grids, 2, 3, n_points),
            xys=0,
        )
        func(raybundle)

    def test_scaffold_formation(self):
        """
        Test calculating the scaffold.

        We define a custom density function and make the implicit function use it.
        After calculating the scaffold we compare the density of our custom
        density function with densities from the scaffold.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        func = self._get_simple_implicit_function().to(device)
        func.scaffold_max_pool_kernel_size = 1

        def new_density(points):
            """
            Density function which returns 1 if p > (0.5, 0.5, 0.5) or
            p < (-0.5, -0.5, -0.5) else 0
            """
            inshape = points.shape
            points = points.view(-1, 3)
            out = []
            for p in points:
                if torch.all(p > 0.5) or torch.all(p < -0.5):
                    out.append(torch.tensor([[1.0]]))
                else:
                    out.append(torch.tensor([[0.0]]))
            return torch.cat(out).view(*inshape[:-1], 1).to(device)

        func.get_density = new_density
        func._get_scaffold(0)

        points = torch.tensor(
            [
                [0, 0, 0],
                [1, 1, 1],
                [1, 0, 0],
                [0.1, 0, 0],
                [10, 1, -1],
                [-0.8, -0.7, -0.9],
            ]
        ).to(device)
        expected = new_density(points).float().to(device)
        assert torch.allclose(func.voxel_grid_scaffold(points), expected), (
            func.voxel_grid_scaffold(points),
            expected,
        )

    def test_scaffold_filtering(self, n_test_points=100):
        """
        Test that filtering points with the scaffold works.

        We define a scaffold and make the implicit function use it. We also
        define new density and color functions which check that all passed
        points are not in empty space (with the scaffold function). In the end
        we compare the result from the implicit function with one calculated
        in simple Python; this checks that the points were merged correctly.
        """
        device = "cuda"
        func = self._get_simple_implicit_function().to(device)

        def scaffold(points):
            """
            Function to deterministically and randomly enough assign a point
            to empty or occupied space.
            Return 1 if the second digit of the sum after 0 is odd, else 0.
            """
            return (
                ((points.sum(dim=-1, keepdim=True) * 10**2 % 10).long() % 2) == 1
            ).float()

        def new_density(points):
            # check if all passed points should be passed here
            assert torch.all(scaffold(points)), (scaffold(points), points.shape)
            return points.sum(dim=-1, keepdim=True)

        def new_color(points, camera, directions):
            # check if all passed points should be passed here
            assert torch.all(scaffold(points))  # , (scaffold(points), points)
            return points * 2

        # check both computation paths that they contain only points
        # which are not in empty space
        func.get_density = new_density
        func.get_color = new_color
        func.voxel_grid_scaffold.forward = scaffold
        func._scaffold_ready = True

        bundle = ImplicitronRayBundle(
            origins=torch.rand((n_test_points, 2, 1, 3), device=device),
            directions=torch.rand((n_test_points, 2, 1, 3), device=device),
            lengths=torch.rand((n_test_points, 2, 1, 4), device=device),
            xys=None,
        )
        points = ray_bundle_to_ray_points(bundle)
        result_density, result_color, _ = func(bundle)

        # construct the wanted result 'by hand'
        flat_points = points.view(-1, 3)
        expected_result_density, expected_result_color = [], []
        for point in flat_points:
            if scaffold(point) == 1:
                expected_result_density.append(point.sum(dim=-1, keepdim=True))
                expected_result_color.append(point * 2)
            else:
                expected_result_density.append(point.new_zeros((1,)))
                expected_result_color.append(point.new_zeros((3,)))
        expected_result_density = torch.stack(expected_result_density, dim=0).view(
            *points.shape[:-1], 1
        )
        expected_result_color = torch.stack(expected_result_color, dim=0).view(
            *points.shape[:-1], 3
        )

        # check that the result is as expected
        assert torch.allclose(result_density, expected_result_density), (
            result_density,
            expected_result_density,
        )
        assert torch.allclose(result_color, expected_result_color), (
            result_color,
            expected_result_color,
        )

    def test_cropping(self, scaffold_res=9):
        """
        Tests whether the implicit function finds the bounding box of the object and sends
        correct min and max points to the voxel grids for rescaling.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        func = self._get_simple_implicit_function(scaffold_res=scaffold_res).to(device)

        assert scaffold_res >= 8
        div = (scaffold_res - 1) / 2
        true_min_point = torch.tensor(
            [-3 / div, 0 / div, -3 / div],
            device=device,
        )
        true_max_point = torch.tensor(
            [1 / div, 2 / div, 3 / div],
            device=device,
        )

        def new_scaffold(points):
            # 1 if between true_min and true_max point else 0
            # return points.new_ones((*points.shape[:-1], 1))
            return (
                torch.logical_and(true_min_point <= points, points <= true_max_point)
                .all(dim=-1)
                .float()[..., None]
            )

        called_crop = []

        def assert_min_max_points(min_point, max_point):
            called_crop.append(1)
            self.assertClose(min_point, true_min_point)
            self.assertClose(max_point, true_max_point)

        func.voxel_grid_density.crop_self = assert_min_max_points
        func.voxel_grid_color.crop_self = assert_min_max_points
        func.voxel_grid_scaffold.forward = new_scaffold
        func._scaffold_ready = True
        func._crop(epoch=0)
        assert len(called_crop) == 2