mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
voxel_grid_implicit_function
Reviewed By: shapovalov Differential Revision: D40622304 fbshipit-source-id: 277515a55c46d9b8300058b439526539a7fe00a0
This commit is contained in:
committed by
Facebook GitHub Bot
parent
611aba9a20
commit
74754bbf17
@@ -52,6 +52,9 @@ from .implicit_function.scene_representation_networks import ( # noqa
|
||||
SRNHyperNetImplicitFunction,
|
||||
SRNImplicitFunction,
|
||||
)
|
||||
from .implicit_function.voxel_grid_implicit_function import ( # noqa
|
||||
VoxelGridImplicitFunction,
|
||||
)
|
||||
|
||||
from .renderer.base import (
|
||||
BaseRenderer,
|
||||
|
||||
@@ -0,0 +1,616 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import fields
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
|
||||
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
|
||||
DecoderFunctionBase,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
enable_get_default_args,
|
||||
get_default_args_field,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
|
||||
enable_get_default_args(HarmonicEmbedding)
|
||||
|
||||
|
||||
@registry.register
|
||||
# pyre-ignore[13]
|
||||
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
"""
|
||||
This implicit function consists of two streams, one for the density calculation and one
|
||||
for the color calculation. Each of these streams has three main parts:
|
||||
1) Voxel grids:
|
||||
They take the (x, y, z) position and return the embedding of that point.
|
||||
These components are replaceable, you can make your own or choose one of
|
||||
several options.
|
||||
2) Harmonic embeddings:
|
||||
Convert each feature into series of 'harmonic features', feature is passed through
|
||||
sine and cosine functions. Input is of shape [minibatch, ..., D] output
|
||||
[minibatch, ..., (n_harmonic_functions * 2 + int(append_input)) * D]. Appends
|
||||
input by default. If you want it to behave like identity, put n_harmonic_functions=0
|
||||
and append_input=True.
|
||||
3) Decoding functions:
|
||||
The decoder is an instance of the DecoderFunctionBase and converts the embedding
|
||||
of a spatial location to density/color. Examples are Identity which returns its
|
||||
input and the MLP which uses fully connected nerual network to transform the input.
|
||||
These components are replaceable, you can make your own or choose from
|
||||
several options.
|
||||
|
||||
Calculating density is done in three steps:
|
||||
1) Evaluating the voxel grid on points
|
||||
2) Embedding the outputs with harmonic embedding
|
||||
3) Passing through the Density decoder
|
||||
|
||||
To calculate the color we need the embedding and the viewing direction, it has five steps:
|
||||
1) Transforming the viewing direction with camera
|
||||
2) Evaluating the voxel grid on points
|
||||
3) Embedding the outputs with harmonic embedding
|
||||
4) Embedding the normalized direction with harmonic embedding
|
||||
5) Passing everything through the Color decoder
|
||||
|
||||
If using the Implicitron configuration system the input_dim to the decoding functions will
|
||||
be set to the output_dim of the Harmonic embeddings.
|
||||
|
||||
A speed up comes from using the scaffold, a low resolution voxel grid.
|
||||
The scaffold is referenced as "binary occupancy grid mask" in TensoRF paper and "AlphaMask"
|
||||
in official TensoRF implementation.
|
||||
The scaffold is used in:
|
||||
1) filtering points in empty space
|
||||
- controlled by `scaffold_filter_points` boolean. If set to True, points for which
|
||||
scaffold predicts that are in empty space will return 0 density and
|
||||
(0, 0, 0) color.
|
||||
2) calculating the bounding box of an object and cropping the voxel grids
|
||||
- controlled by `volume_cropping_epochs`.
|
||||
- at those epochs the implicit function will find the bounding box of an object
|
||||
inside it and crop density and color grids. Cropping of the voxel grids means
|
||||
preserving only voxel values that are inside the bounding box and changing the
|
||||
resolution to match the original, while preserving the new cropped location in
|
||||
world coordinates.
|
||||
|
||||
The scaffold has to exist before attempting filtering and cropping, and is created on
|
||||
`scaffold_calculating_epochs`. Each voxel in the scaffold is labeled as having density 1 if
|
||||
the point in the center of it evaluates to greater than `scaffold_empty_space_threshold`.
|
||||
3D max pooling is performed on the densities of the points in 3D.
|
||||
Scaffold features are off by default.
|
||||
|
||||
Members:
|
||||
voxel_grid_density (VoxelGridBase): voxel grid to use for density estimation
|
||||
voxel_grid_color (VoxelGridBase): voxel grid to use for color estimation
|
||||
|
||||
harmonic_embedder_xyz_density (HarmonicEmbedder): Function to transform the outputs of
|
||||
the voxel_grid_density
|
||||
harmonic_embedder_xyz_color (HarmonicEmbedder): Function to transform the outputs of
|
||||
the voxel_grid_color for density
|
||||
harmonic_embedder_dir_color (HarmonicEmbedder): Function to transform the outputs of
|
||||
the voxel_grid_color for color
|
||||
|
||||
decoder_density (DecoderFunctionBase): decoder function to use for density estimation
|
||||
color_density (DecoderFunctionBase): decoder function to use for color estimation
|
||||
|
||||
use_multiple_streams (bool): if you want the density and color calculations to run on
|
||||
different cuda streams set this to True. Default True.
|
||||
xyz_ray_dir_in_camera_coords (bool): This is true if the directions are given in
|
||||
camera coordinates. Default False.
|
||||
|
||||
voxel_grid_scaffold (VoxelGridModule): which holds the scaffold. Extents and
|
||||
translation of it are set to those of voxel_grid_density.
|
||||
scaffold_calculating_epochs (Tuple[int, ...]): at which epochs to recalculate the
|
||||
scaffold. (The scaffold will be created automatically at the beginning of
|
||||
the calculation.)
|
||||
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
|
||||
voxel grid which stores scaffold
|
||||
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
|
||||
this it will be considered as empty space and the scaffold at that point would
|
||||
evaluate as empty space.
|
||||
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
|
||||
at the same time. To calculate the scaffold we need to query `get_density()` at
|
||||
every voxel, this calculation can be split into scaffold depth number of xy plane
|
||||
calculations if you want the lowest memory usage, one calculation to calculate the
|
||||
whole scaffold, but with higher memory footprint or any other number of planes.
|
||||
Setting to 'inf' calculates all planes at the same time. Defaults to 'inf'.
|
||||
scaffold_max_pool_kernel_size (int): Size of the pooling region to use when
|
||||
calculating the scaffold. Defaults to 3.
|
||||
scaffold_filter_points (bool): If set to True the points will be filtered using
|
||||
`self.voxel_grid_scaffold`. Filtered points will be predicted as having 0 density
|
||||
and (0, 0, 0) color. The points which were not evaluated as empty space will be
|
||||
passed through the steps outlined above.
|
||||
volume_cropping_epochs: on which epochs to crop the voxel grids to fit the object's
|
||||
bounding box. Scaffold has to be calculated before cropping.
|
||||
"""
|
||||
|
||||
# ---- voxel grid for density
|
||||
voxel_grid_density: VoxelGridModule
|
||||
|
||||
# ---- voxel grid for color
|
||||
voxel_grid_color: VoxelGridModule
|
||||
|
||||
# ---- harmonic embeddings density
|
||||
harmonic_embedder_xyz_density_args: DictConfig = get_default_args_field(
|
||||
HarmonicEmbedding
|
||||
)
|
||||
harmonic_embedder_xyz_color_args: DictConfig = get_default_args_field(
|
||||
HarmonicEmbedding
|
||||
)
|
||||
harmonic_embedder_dir_color_args: DictConfig = get_default_args_field(
|
||||
HarmonicEmbedding
|
||||
)
|
||||
|
||||
# ---- decoder function for density
|
||||
decoder_density_class_type: str = "MLPDecoder"
|
||||
decoder_density: DecoderFunctionBase
|
||||
|
||||
# ---- decoder function for color
|
||||
decoder_color_class_type: str = "MLPDecoder"
|
||||
decoder_color: DecoderFunctionBase
|
||||
|
||||
# ---- cuda streams
|
||||
use_multiple_streams: bool = True
|
||||
|
||||
# ---- camera
|
||||
xyz_ray_dir_in_camera_coords: bool = False
|
||||
|
||||
# --- scaffold
|
||||
# voxel_grid_scaffold: VoxelGridModule
|
||||
scaffold_calculating_epochs: Tuple[int, ...] = ()
|
||||
scaffold_resolution: Tuple[int, int, int] = (128, 128, 128)
|
||||
scaffold_empty_space_threshold: float = 0.001
|
||||
scaffold_occupancy_chunk_size: Union[str, int] = "inf"
|
||||
scaffold_max_pool_kernel_size: int = 3
|
||||
scaffold_filter_points: bool = True
|
||||
|
||||
# --- cropping
|
||||
volume_cropping_epochs: Tuple[int, ...] = ()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
# pyre-ignore[16]
|
||||
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
||||
# pyre-ignore[16]
|
||||
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_xyz_density_args
|
||||
)
|
||||
# pyre-ignore[16]
|
||||
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_xyz_color_args
|
||||
)
|
||||
# pyre-ignore[16]
|
||||
self.harmonic_embedder_dir_color = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_dir_color_args
|
||||
)
|
||||
# pyre-ignore[16]
|
||||
self._scaffold_ready = False
|
||||
if type(self.scaffold_occupancy_chunk_size) != int:
|
||||
if self.scaffold_occupancy_chunk_size != "inf":
|
||||
raise ValueError(
|
||||
"`scaffold_occupancy_chunk_size` has to be int or 'inf'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ray_bundle: ImplicitronRayBundle,
|
||||
fun_viewpool=None,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
global_code=None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
|
||||
"""
|
||||
The forward function accepts the parametrizations of 3D points sampled along
|
||||
projection rays. The forward pass is responsible for attaching a 3D vector
|
||||
and a 1D scalar representing the point's RGB color and opacity respectively.
|
||||
|
||||
Args:
|
||||
ray_bundle: An ImplicitronRayBundle object containing the following variables:
|
||||
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||
origins of the sampling rays in world coords.
|
||||
directions: A tensor of shape `(minibatch, ..., 3)`
|
||||
containing the direction vectors of sampling rays in world coords.
|
||||
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
|
||||
containing the lengths at which the rays are sampled.
|
||||
fun_viewpool: an optional callback with the signature
|
||||
fun_fiewpool(points) -> pooled_features
|
||||
where points is a [N_TGT x N x 3] tensor of world coords,
|
||||
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
|
||||
of the features pooled from the context images.
|
||||
camera: A camera model which will be used to transform the viewing
|
||||
directions
|
||||
|
||||
Returns:
|
||||
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
|
||||
denoting the opacitiy of each ray point.
|
||||
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
|
||||
denoting the color of each ray point.
|
||||
"""
|
||||
# ########## convert the ray parametrizations to world coordinates ########## #
|
||||
# points.shape = [minibatch x n_rays_width x n_rays_height x pts_per_ray x 3]
|
||||
# pyre-ignore[6]
|
||||
points = ray_bundle_to_ray_points(ray_bundle)
|
||||
directions = ray_bundle.directions.reshape(-1, 3)
|
||||
input_shape = points.shape
|
||||
points = points.view(-1, 3)
|
||||
|
||||
# ########## filter the points using the scaffold ########## #
|
||||
if self._scaffold_ready and self.scaffold_filter_points:
|
||||
# pyre-ignore[29]
|
||||
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
points = points[non_empty_points]
|
||||
directions = directions[non_empty_points]
|
||||
if len(points) == 0:
|
||||
warnings.warn(
|
||||
"The scaffold has filtered all the points."
|
||||
"The voxel grids and decoding functions will not be run."
|
||||
)
|
||||
return (
|
||||
points.new_zeros((*input_shape[:-1], 1)),
|
||||
points.new_zeros((*input_shape[:-1], 3)),
|
||||
{},
|
||||
)
|
||||
|
||||
# ########## calculate color and density ########## #
|
||||
rays_densities, rays_colors = self.calculate_density_and_color(
|
||||
points, directions, camera
|
||||
)
|
||||
|
||||
if not (self._scaffold_ready and self.scaffold_filter_points):
|
||||
return (
|
||||
rays_densities.view((*input_shape[:-1], rays_densities.shape[-1])),
|
||||
rays_colors.view((*input_shape[:-1], rays_colors.shape[-1])),
|
||||
{},
|
||||
)
|
||||
|
||||
# ########## merge scaffold calculated points ########## #
|
||||
# Create a zeroed tensor corresponding to a point with density=0 and fill it
|
||||
# with calculated density for points which are not in empty space. Do the
|
||||
# same for color
|
||||
rays_densities_combined = rays_densities.new_zeros(
|
||||
(math.prod(input_shape[:-1]), rays_densities.shape[-1])
|
||||
)
|
||||
rays_colors_combined = rays_colors.new_zeros(
|
||||
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
|
||||
)
|
||||
# pyre-ignore[61]
|
||||
rays_densities_combined[non_empty_points] = rays_densities
|
||||
# pyre-ignore[61]
|
||||
rays_colors_combined[non_empty_points] = rays_colors
|
||||
|
||||
return (
|
||||
rays_densities_combined.view((*input_shape[:-1], rays_densities.shape[-1])),
|
||||
rays_colors_combined.view((*input_shape[:-1], rays_colors.shape[-1])),
|
||||
{},
|
||||
)
|
||||
|
||||
def calculate_density_and_color(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
directions: torch.Tensor,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Calculates density and color at `points`.
|
||||
If enabled use cuda streams.
|
||||
|
||||
Args:
|
||||
points: points at which to calculate density and color.
|
||||
Tensor of shape [..., 3].
|
||||
directions: from which directions are the points viewed
|
||||
Tensor of shape [..., 3].
|
||||
camera: A camera model which will be used to transform the viewing
|
||||
directions
|
||||
Returns:
|
||||
Tuple of color (tensor of shape [..., 3]) and density
|
||||
(tensor of shape [..., 1])
|
||||
"""
|
||||
if self.use_multiple_streams and points.is_cuda:
|
||||
current_stream = torch.cuda.current_stream(points.device)
|
||||
other_stream = torch.cuda.Stream(points.device)
|
||||
other_stream.wait_stream(current_stream)
|
||||
|
||||
with torch.cuda.stream(other_stream):
|
||||
# rays_densities.shape =
|
||||
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
|
||||
rays_densities = self.get_density(points)
|
||||
|
||||
# rays_colors.shape =
|
||||
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
|
||||
rays_colors = self.get_color(points, camera, directions)
|
||||
|
||||
current_stream.wait_stream(other_stream)
|
||||
else:
|
||||
# Same calculation as above, just serial.
|
||||
rays_densities = self.get_density(points)
|
||||
rays_colors = self.get_color(points, camera, directions)
|
||||
return rays_densities, rays_colors
|
||||
|
||||
def get_density(self, points: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates density at points:
|
||||
1) Evaluates the voxel grid on points
|
||||
2) Embeds the outputs with harmonic embedding
|
||||
3) Passes everything through the Density decoder
|
||||
|
||||
Args:
|
||||
points: tensor of shape [..., 3]
|
||||
where the last dimension is the points in the (x, y, z)
|
||||
Returns:
|
||||
calculated densities of shape [..., density_dim], `density_dim` is the
|
||||
feature dimensionality which `decoder_density` returns
|
||||
"""
|
||||
embeds_density = self.voxel_grid_density(points)
|
||||
# pyre-ignore[29]
|
||||
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
|
||||
# shape = [..., density_dim]
|
||||
return self.decoder_density(harmonic_embedding_density)
|
||||
|
||||
def get_color(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
camera: Optional[CamerasBase],
|
||||
directions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates color at points using the viewing direction:
|
||||
1) Transforms the viewing direction with camera
|
||||
2) Evaluates the voxel grid on points
|
||||
3) Embeds the outputs with harmonic embedding
|
||||
4) Embeds the normalized direction with harmonic embedding
|
||||
5) Passes everything through the Color decoder
|
||||
Args:
|
||||
points: tensor of shape (..., 3)
|
||||
where the last dimension is the points in the (x, y, z)
|
||||
camera: A camera model which will be used to transform the viewing
|
||||
directions
|
||||
directions: A tensor of shape `(..., 3)`
|
||||
containing the direction vectors of sampling rays in world coords.
|
||||
"""
|
||||
# ########## transform direction ########## #
|
||||
if self.xyz_ray_dir_in_camera_coords:
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
directions = directions @ camera.R
|
||||
|
||||
# ########## get voxel grid output ########## #
|
||||
# embeds_color.shape = [..., pts_per_ray, n_features]
|
||||
embeds_color = self.voxel_grid_color(points)
|
||||
|
||||
# ########## embed with the harmonic function ########## #
|
||||
# Obtain the harmonic embedding of the voxel grid output.
|
||||
# pyre-ignore[29]
|
||||
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
|
||||
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-ignore[29]
|
||||
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
|
||||
rays_directions_normed
|
||||
)
|
||||
|
||||
n_rays = directions.shape[0]
|
||||
points_per_ray: int = points.shape[0] // n_rays
|
||||
|
||||
harmonic_embedding_dir = torch.repeat_interleave(
|
||||
harmonic_embedding_dir, points_per_ray, dim=0
|
||||
)
|
||||
|
||||
# total color embedding is concatenation of the harmonic embedding of voxel grid
|
||||
# output and harmonic embedding of the normalized direction
|
||||
total_color_embedding = torch.cat(
|
||||
(harmonic_embedding_color, harmonic_embedding_dir), dim=-1
|
||||
)
|
||||
|
||||
# ########## evaluate color with the decoding function ########## #
|
||||
# rays_colors.shape = [..., pts_per_ray, 3] in [0-1]
|
||||
return self.decoder_color(total_color_embedding)
|
||||
|
||||
@staticmethod
|
||||
def allows_multiple_passes() -> bool:
|
||||
"""
|
||||
Returns True as this implicit function allows
|
||||
multiple passes. Overridden from ImplicitFunctionBase.
|
||||
"""
|
||||
return True
|
||||
|
||||
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
|
||||
"""
|
||||
Method which expresses interest in subscribing to optimization epoch updates.
|
||||
This implicit function subscribes to epochs to calculate the scaffold and to
|
||||
crop voxel grids, so this method combines wanted epochs and wraps their callbacks.
|
||||
|
||||
Returns:
|
||||
list of epochs on which to call a callable and callable to be called on
|
||||
particular epoch. The callable returns True if parameter change has
|
||||
happened else False and it must be supplied with one argument, epoch.
|
||||
"""
|
||||
|
||||
def callback(epoch) -> bool:
|
||||
change = False
|
||||
if epoch in self.scaffold_calculating_epochs:
|
||||
change = self._get_scaffold(epoch)
|
||||
if epoch in self.volume_cropping_epochs:
|
||||
change = self._crop(epoch) or change
|
||||
return change
|
||||
|
||||
# remove duplicates
|
||||
call_epochs = list(
|
||||
set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
|
||||
)
|
||||
return call_epochs, callback
|
||||
|
||||
def _crop(self, epoch: int) -> bool:
|
||||
"""
|
||||
Finds the bounding box of an object represented in the scaffold and crops
|
||||
density and color voxel grids to match that bounding box. If density of the
|
||||
scaffold is 0 everywhere (there is no object in it) no change will
|
||||
happen.
|
||||
|
||||
Args:
|
||||
epoch: ignored
|
||||
Returns:
|
||||
True (indicating that parameter change has happened) if there is
|
||||
an object inside, else False.
|
||||
"""
|
||||
# find bounding box
|
||||
# pyre-ignore[16]
|
||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
|
||||
# pyre-ignore[29]
|
||||
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
non_zero_idxs = torch.nonzero(occupancy)
|
||||
if len(non_zero_idxs) == 0:
|
||||
return False
|
||||
min_indices = tuple(torch.min(non_zero_idxs, dim=0)[0])
|
||||
max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
|
||||
min_point, max_point = points[min_indices], points[max_indices]
|
||||
|
||||
# crop the voxel grids
|
||||
self.voxel_grid_density.crop_self(min_point, max_point)
|
||||
self.voxel_grid_color.crop_self(min_point, max_point)
|
||||
return True
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_scaffold(self, epoch: int) -> bool:
|
||||
"""
|
||||
Creates a low resolution grid which is used to filter points that are in empty
|
||||
space.
|
||||
|
||||
Args:
|
||||
epoch: epoch on which it is called, ignored inside method
|
||||
Returns:
|
||||
Always False: Modifies `self.voxel_grid_scaffold` member.
|
||||
"""
|
||||
|
||||
planes = []
|
||||
# pyre-ignore[16]
|
||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||
|
||||
chunk_size = (
|
||||
self.scaffold_occupancy_chunk_size
|
||||
if type(self.scaffold_occupancy_chunk_size) == int
|
||||
else points.shape[-1]
|
||||
)
|
||||
for k in range(0, points.shape[-1], chunk_size):
|
||||
points_in_planes = points[..., k : k + chunk_size]
|
||||
planes.append(self.get_density(points_in_planes)[..., 0])
|
||||
|
||||
density_cube = torch.cat(planes, dim=-1)
|
||||
density_cube = torch.nn.functional.max_pool3d(
|
||||
density_cube[None, None],
|
||||
kernel_size=self.scaffold_max_pool_kernel_size,
|
||||
padding=self.scaffold_max_pool_kernel_size // 2,
|
||||
stride=1,
|
||||
)
|
||||
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
|
||||
# pyre-ignore[16]
|
||||
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
|
||||
# pyre-ignore[16]
|
||||
self._scaffold_ready = True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
args.pop("input_dim", None)
|
||||
|
||||
def create_decoder_density_impl(self, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Decoding functions come after harmonic embedding and voxel grid. In order to not
|
||||
calculate the input dimension of the decoder in the config file this function
|
||||
calculates the required input dimension and sets the input dimension of the
|
||||
decoding function to this value.
|
||||
"""
|
||||
grid_args = self.voxel_grid_density_args
|
||||
# pyre-ignore[6]
|
||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||
|
||||
embedder_args = self.harmonic_embedder_xyz_density_args
|
||||
input_dim = HarmonicEmbedding.get_output_dim_static(
|
||||
grid_output_dim,
|
||||
embedder_args["n_harmonic_functions"],
|
||||
embedder_args["append_input"],
|
||||
)
|
||||
|
||||
cls = registry.get(DecoderFunctionBase, type)
|
||||
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
||||
if need_input_dim:
|
||||
self.decoder_density = cls(input_dim=input_dim, **args)
|
||||
else:
|
||||
self.decoder_density = cls(**args)
|
||||
|
||||
@classmethod
|
||||
def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
args.pop("input_dim", None)
|
||||
|
||||
def create_decoder_color_impl(self, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Decoding functions come after harmonic embedding and voxel grid. In order to not
|
||||
calculate the input dimension of the decoder in the config file this function
|
||||
calculates the required input dimension and sets the input dimension of the
|
||||
decoding function to this value.
|
||||
"""
|
||||
grid_args = self.voxel_grid_color_args
|
||||
# pyre-ignore[6]
|
||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||
|
||||
embedder_args = self.harmonic_embedder_xyz_color_args
|
||||
input_dim0 = HarmonicEmbedding.get_output_dim_static(
|
||||
grid_output_dim,
|
||||
embedder_args["n_harmonic_functions"],
|
||||
embedder_args["append_input"],
|
||||
)
|
||||
|
||||
dir_dim = 3
|
||||
embedder_args = self.harmonic_embedder_dir_color_args
|
||||
input_dim1 = HarmonicEmbedding.get_output_dim_static(
|
||||
dir_dim,
|
||||
embedder_args["n_harmonic_functions"],
|
||||
embedder_args["append_input"],
|
||||
)
|
||||
|
||||
input_dim = input_dim0 + input_dim1
|
||||
|
||||
cls = registry.get(DecoderFunctionBase, type)
|
||||
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
||||
if need_input_dim:
|
||||
self.decoder_color = cls(input_dim=input_dim, **args)
|
||||
else:
|
||||
self.decoder_color = cls(**args)
|
||||
|
||||
def _create_voxel_grid_scaffold(self) -> VoxelGridModule:
|
||||
"""
|
||||
Creates object to become self.voxel_grid_scaffold:
|
||||
- makes `self.voxel_grid_scaffold` have same world to local mapping as
|
||||
`self.voxel_grid_density`
|
||||
"""
|
||||
return VoxelGridModule(
|
||||
# pyre-ignore[29]
|
||||
extents=self.voxel_grid_density_args["extents"],
|
||||
# pyre-ignore[29]
|
||||
translation=self.voxel_grid_density_args["translation"],
|
||||
voxel_grid_class_type="FullResolutionVoxelGrid",
|
||||
hold_voxel_grid_as_parameters=False,
|
||||
voxel_grid_FullResolutionVoxelGrid_args={
|
||||
"resolution_changes": {0: self.scaffold_resolution},
|
||||
"padding": "zeros",
|
||||
"align_corners": True,
|
||||
"mode": "trilinear",
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user