Decoding functions

Summary: Added replaceable decoding functions which will be applied after the voxel grid to get color and density

Reviewed By: bottler

Differential Revision: D38829763

fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
This commit is contained in:
Darijan Gudelj 2022-08-26 08:47:30 -07:00 committed by Facebook GitHub Bot
parent 24f5f4a3e7
commit e7c609f198
3 changed files with 153 additions and 52 deletions

View File

@ -4,16 +4,66 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
This file contains
- modules which get used by ImplicitFunction objects for decoding an embedding defined in
space, e.g. to color or opacity.
- DecoderFunctionBase and its subclasses, which wrap some of those modules, providing
some such modules as an extension point which an ImplicitFunction object could use.
"""
import logging
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
ReplaceableBase,
run_auto_creation,
)
logger = logging.getLogger(__name__)
class MLPWithInputSkips(torch.nn.Module):
class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
    """
    Base class for decoding functions.

    A decoding function is a torch.nn.Module that maps the embedding of a
    location in space to the required quantity (for example density and
    color). Concrete decoders subclass this and override `forward`.
    """

    def __post_init__(self):
        # Run the parent-class initialisation from the post-init hook.
        super().__init__()

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Decode a tensor of embeddings. Subclasses must override this.

        Args:
            features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
            z: optional tensor to append to parts of the decoding function
        Returns:
            decoded_features (torch.Tensor) : tensor of
                shape (batch, ..., num_out_features)
        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
@registry.register
class IdentityDecoder(DecoderFunctionBase):
    """
    Pass-through decoding function: the output is the input embedding.
    """

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: the embedding to decode.
            z: accepted so the signature matches other decoders; not used.
        Returns:
            `features` itself, unmodified.
        """
        decoded = features
        return decoded
class MLPWithInputSkips(Configurable, torch.nn.Module):
"""
Implements the multi-layer perceptron architecture of the Neural Radiance Field.
@ -31,8 +81,56 @@ class MLPWithInputSkips(torch.nn.Module):
and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
NeRF: Representing Scenes as Neural Radiance Fields for View
Synthesis, ECCV2020
Members:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""
# Configurable members of the MLP. Every member carries a type annotation:
# dataclass-style field collection only picks up annotated names, so an
# un-annotated default would stay a plain class attribute and could not be
# passed to the constructor.
n_layers: int = 8
input_dim: int = 39
output_dim: int = 256
skip_dim: int = 39
hidden_dim: int = 256
input_skips: Tuple[int, ...] = (5,)
# If True, each skip layer gets a layer built by `_make_affine_layer`
# (from `skip_dim` to `hidden_dim`) instead of a widened linear input.
skip_affine_trans: bool = False
# If True, the final layer is a bare Linear with no ReLU after it.
no_last_relu: bool = False
def __post_init__(self):
    """
    Build the layers of the MLP from the configured members.

    Creates `self.mlp`, a ModuleList with `n_layers` entries. Each entry is
    a Linear followed by an in-place ReLU, except that the last entry is a
    bare Linear when `no_last_relu` is set. For every layer index greater
    than 0 listed in `input_skips`, either an affine layer is created
    (when `skip_affine_trans` is set) or the layer's input width is
    enlarged by `skip_dim`.
    """
    super().__init__()
    last_index = self.n_layers - 1
    mlp_layers = []
    affine_layers = []
    for idx in range(self.n_layers):
        in_features = self.input_dim if idx == 0 else self.hidden_dim
        out_features = self.output_dim if idx >= last_index else self.hidden_dim

        # Layer 0 is never treated as a skip layer, even if listed.
        if idx > 0 and idx in self.input_skips:
            if self.skip_affine_trans:
                affine = self._make_affine_layer(self.skip_dim, self.hidden_dim)
                affine_layers.append(affine)
            else:
                in_features = self.hidden_dim + self.skip_dim

        fc = torch.nn.Linear(in_features, out_features)
        _xavier_init(fc)
        if self.no_last_relu and idx >= last_index:
            mlp_layers.append(fc)
        else:
            mlp_layers.append(torch.nn.Sequential(fc, torch.nn.ReLU(True)))

    self.mlp = torch.nn.ModuleList(mlp_layers)
    if self.skip_affine_trans:
        self.skip_affines = torch.nn.ModuleList(affine_layers)
    # Private copies of the skip configuration, stored on the instance.
    self._input_skips = set(self.input_skips)
    self._skip_affine_trans = self.skip_affine_trans
def _make_affine_layer(self, input_dim, hidden_dim):
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
@ -46,56 +144,6 @@ class MLPWithInputSkips(torch.nn.Module):
std = torch.nn.functional.softplus(log_std)
return (x - mu) * std
def __init__(
    self,
    n_layers: int = 8,
    input_dim: int = 39,
    output_dim: int = 256,
    skip_dim: int = 39,
    hidden_dim: int = 256,
    input_skips: Tuple[int, ...] = (5,),
    skip_affine_trans: bool = False,
    no_last_relu=False,
):
    """
    Build the layers of the MLP.

    Args:
        n_layers: The number of linear layers of the MLP.
        input_dim: The number of channels of the input tensor.
        output_dim: The number of channels of the output.
        skip_dim: The number of channels of the tensor `z` appended when
            evaluating the skip layers.
        hidden_dim: The number of hidden units of the MLP.
        input_skips: The list of layer indices at which we append the skip
            tensor `z`.
        skip_affine_trans: If set, an affine layer is created for each skip
            layer instead of enlarging that layer's input by `skip_dim`.
        no_last_relu: If set, the last layer is a bare Linear with no ReLU.
    """
    super().__init__()
    last_index = n_layers - 1
    mlp_layers = []
    affine_layers = []
    for idx in range(n_layers):
        in_features = input_dim if idx == 0 else hidden_dim
        out_features = output_dim if idx >= last_index else hidden_dim

        # Layer 0 is never treated as a skip layer, even if listed.
        if idx > 0 and idx in input_skips:
            if skip_affine_trans:
                affine = self._make_affine_layer(skip_dim, hidden_dim)
                affine_layers.append(affine)
            else:
                in_features = hidden_dim + skip_dim

        fc = torch.nn.Linear(in_features, out_features)
        _xavier_init(fc)
        if no_last_relu and idx >= last_index:
            mlp_layers.append(fc)
        else:
            mlp_layers.append(torch.nn.Sequential(fc, torch.nn.ReLU(True)))

    self.mlp = torch.nn.ModuleList(mlp_layers)
    if skip_affine_trans:
        self.skip_affines = torch.nn.ModuleList(affine_layers)
    self._input_skips = set(input_skips)
    self._skip_affine_trans = skip_affine_trans
def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
"""
Args:
@ -121,6 +169,24 @@ class MLPWithInputSkips(torch.nn.Module):
return y
@registry.register
class MLPDecoder(DecoderFunctionBase):
    """
    Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.

    Members:
        network: the `MLPWithInputSkips` module that `forward` delegates to.
    """

    network: MLPWithInputSkips

    def __post_init__(self):
        super().__post_init__()
        # NOTE(review): presumably this is what instantiates `self.network`
        # from the configuration; confirm against the config tooling.
        run_auto_creation(self)

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: the embedding to decode.
            z: optional tensor forwarded to the network as its second argument.
        Returns:
            The output of `self.network(features, z)`.
        """
        decoded = self.network(features, z)
        return decoded
class TransformerWithInputSkips(torch.nn.Module):
def __init__(
self,

View File

@ -9,7 +9,7 @@ from typing import Optional, Tuple
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
@ -214,6 +214,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
append_xyz: Tuple[int, ...] = (5,)
def _construct_xyz_encoder(self, input_dim: int):
expand_args_fields(MLPWithInputSkips)
return MLPWithInputSkips(
self.n_layers_xyz,
input_dim,

View File

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
IdentityDecoder,
MLPDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from tests.common_testing import TestCaseMixin
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
    """
    Unit tests for the decoding functions.
    """

    def setUp(self):
        # Fixed seed so the random test inputs are reproducible.
        torch.manual_seed(42)
        expand_args_fields(IdentityDecoder)
        expand_args_fields(MLPDecoder)

    def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
        """
        Test that identity function returns its input

        Args:
            in_shape: shape of the random tensors fed to the decoder.
            n_tests: number of random tensors to try.
        """
        func = IdentityDecoder()
        for _ in range(n_tests):
            _in = torch.randn(in_shape)
            # A unittest assertion rather than a bare `assert`, so the check
            # still runs when Python is started with -O.
            self.assertTrue(torch.allclose(func(_in), _in))