Decoding functions

Summary: Added replaceable decoding functions which will be applied after the voxel grid to get color and density

Reviewed By: bottler

Differential Revision: D38829763

fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
parent 24f5f4a3e7
commit e7c609f198
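For orientation: the decoders added here sit after a per-point feature lookup, turning embeddings into density and color. Below is a minimal sketch of that two-stage split in plain torch, with a hypothetical single-linear-layer decoder standing in for the classes this commit adds; all names in it are illustrative, not from the diff.

import torch
import torch.nn.functional as F

# Hypothetical stand-in for a decoding function such as MLPDecoder:
# maps a 16-dim embedding to 1 density + 3 color channels.
decoder = torch.nn.Linear(16, 4)

# A dense grid of feature embeddings, shape (N, C, D, H, W).
voxel_grid = torch.randn(1, 16, 8, 8, 8)

# Query points in the grid's normalized [-1, 1]^3 coordinates, reshaped
# to the (N, 1, 1, P, 3) layout that grid_sample expects.
points = torch.rand(1024, 3) * 2 - 1
features = F.grid_sample(
    voxel_grid, points.view(1, 1, 1, -1, 3), align_corners=True
)  # -> (1, C, 1, 1, P), trilinearly interpolated
features = features.view(16, -1).t()  # -> (P, C)

out = decoder(features)            # -> (P, 4)
density = torch.relu(out[:, :1])   # keep density non-negative
color = torch.sigmoid(out[:, 1:])  # RGB in [0, 1]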
@@ -4,16 +4,66 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This file contains
    - modules which get used by ImplicitFunction objects for decoding an embedding defined in
        space, e.g. to color or opacity.
    - DecoderFunctionBase and its subclasses, which wrap some of those modules, providing
        them as an extension point which an ImplicitFunction object could use.
"""

import logging

from typing import Optional, Tuple

import torch

from pytorch3d.implicitron.tools.config import (
    Configurable,
    registry,
    ReplaceableBase,
    run_auto_creation,
)

logger = logging.getLogger(__name__)


class MLPWithInputSkips(torch.nn.Module):
class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
    """
    Decoding function is a torch.nn.Module which takes the embedding of a location in
    space and transforms it into the required quantity (for example density and color).
    """

    def __post_init__(self):
        super().__init__()

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
            z: optional tensor to append to parts of the decoding function
        Returns:
            decoded_features (torch.Tensor): tensor of
                shape (batch, ..., num_out_features)
        """
        raise NotImplementedError()


@registry.register
class IdentityDecoder(DecoderFunctionBase):
    """
    Decoding function which returns its input.
    """

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return features


class MLPWithInputSkips(Configurable, torch.nn.Module):
    """
    Implements the multi-layer perceptron architecture of the Neural Radiance Field.
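Because DecoderFunctionBase is a ReplaceableBase, new decoders plug in through the registry exactly like IdentityDecoder above. A hedged sketch of what a user-defined subclass could look like; ScaleDecoder and its scale field are invented for illustration, and the module path is the one the new test file imports from:

from typing import Optional

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    DecoderFunctionBase,
)
from pytorch3d.implicitron.tools.config import registry


@registry.register
class ScaleDecoder(DecoderFunctionBase):
    # Hypothetical decoder: rescales the embedding by a configurable factor.
    scale: float = 2.0

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return features * self.scale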
@@ -31,8 +81,56 @@ class MLPWithInputSkips(torch.nn.Module):
    and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
    NeRF: Representing Scenes as Neural Radiance Fields for View
    Synthesis, ECCV2020

    Members:
        n_layers: The number of linear layers of the MLP.
        input_dim: The number of channels of the input tensor.
        output_dim: The number of channels of the output.
        skip_dim: The number of channels of the tensor `z` appended when
            evaluating the skip layers.
        hidden_dim: The number of hidden units of the MLP.
        input_skips: The list of layer indices at which we append the skip
            tensor `z`.
    """

    n_layers: int = 8
    input_dim: int = 39
    output_dim: int = 256
    skip_dim: int = 39
    hidden_dim: int = 256
    input_skips: Tuple[int, ...] = (5,)
    skip_affine_trans: bool = False
    no_last_relu: bool = False

    def __post_init__(self):
        super().__init__()
        layers = []
        skip_affine_layers = []
        for layeri in range(self.n_layers):
            dimin = self.hidden_dim if layeri > 0 else self.input_dim
            dimout = self.hidden_dim if layeri + 1 < self.n_layers else self.output_dim

            if layeri > 0 and layeri in self.input_skips:
                if self.skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(self.skip_dim, self.hidden_dim)
                    )
                else:
                    dimin = self.hidden_dim + self.skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            _xavier_init(linear)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
                if not self.no_last_relu or layeri + 1 < self.n_layers
                else linear
            )
        self.mlp = torch.nn.ModuleList(layers)
        if self.skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        self._input_skips = set(self.input_skips)
        self._skip_affine_trans = self.skip_affine_trans

    def _make_affine_layer(self, input_dim, hidden_dim):
        l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
        l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
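Since MLPWithInputSkips is now a Configurable with dataclass-style fields, it is expanded before being constructed directly instead of taking __init__ arguments. A minimal sketch, assuming the class lives in the same decoding_functions module as the decoders above:

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    MLPWithInputSkips,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Expansion turns the annotated fields into constructor arguments,
# mirroring what _construct_xyz_encoder does later in this diff.
expand_args_fields(MLPWithInputSkips)
mlp = MLPWithInputSkips(n_layers=8, input_dim=39, output_dim=256)

x = torch.randn(2, 1024, 39)  # embedded 3D points; 39 matches the default input_dim
y = mlp(x, z=x)  # z is re-appended at the skip layer (index 5 by default)
print(y.shape)   # torch.Size([2, 1024, 256])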
@@ -46,56 +144,6 @@ class MLPWithInputSkips(torch.nn.Module):
        std = torch.nn.functional.softplus(log_std)
        return (x - mu) * std

    def __init__(
        self,
        n_layers: int = 8,
        input_dim: int = 39,
        output_dim: int = 256,
        skip_dim: int = 39,
        hidden_dim: int = 256,
        input_skips: Tuple[int, ...] = (5,),
        skip_affine_trans: bool = False,
        no_last_relu=False,
    ):
        """
        Args:
            n_layers: The number of linear layers of the MLP.
            input_dim: The number of channels of the input tensor.
            output_dim: The number of channels of the output.
            skip_dim: The number of channels of the tensor `z` appended when
                evaluating the skip layers.
            hidden_dim: The number of hidden units of the MLP.
            input_skips: The list of layer indices at which we append the skip
                tensor `z`.
        """
        super().__init__()
        layers = []
        skip_affine_layers = []
        for layeri in range(n_layers):
            dimin = hidden_dim if layeri > 0 else input_dim
            dimout = hidden_dim if layeri + 1 < n_layers else output_dim

            if layeri > 0 and layeri in input_skips:
                if skip_affine_trans:
                    skip_affine_layers.append(
                        self._make_affine_layer(skip_dim, hidden_dim)
                    )
                else:
                    dimin = hidden_dim + skip_dim

            linear = torch.nn.Linear(dimin, dimout)
            _xavier_init(linear)
            layers.append(
                torch.nn.Sequential(linear, torch.nn.ReLU(True))
                if not no_last_relu or layeri + 1 < n_layers
                else linear
            )
        self.mlp = torch.nn.ModuleList(layers)
        if skip_affine_trans:
            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
        self._input_skips = set(input_skips)
        self._skip_affine_trans = skip_affine_trans

    def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
        """
        Args:
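The (x - mu) * std lines above are the tail of the affine skip path: with skip_affine_trans enabled, the skip tensor z is not concatenated but mapped to a shift mu and a log-scale log_std that modulate the hidden activation. A standalone sketch of the transform follows; the exact layering inside _make_affine_layer (e.g. the intermediate ReLU) is an assumption:

import torch

hidden_dim, skip_dim = 256, 39

# Two linear layers producing 2 * hidden_dim channels, as in
# _make_affine_layer above; the ReLU in between is assumed.
affine = torch.nn.Sequential(
    torch.nn.Linear(skip_dim, hidden_dim * 2),
    torch.nn.ReLU(True),
    torch.nn.Linear(hidden_dim * 2, hidden_dim * 2),
)

x = torch.randn(1024, hidden_dim)  # hidden activation at a skip layer
z = torch.randn(1024, skip_dim)    # skip tensor

mu, log_std = affine(z).split(hidden_dim, dim=-1)
std = torch.nn.functional.softplus(log_std)
y = (x - mu) * std  # the transform shown in the hunk above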
@@ -121,6 +169,24 @@ class MLPWithInputSkips(torch.nn.Module):
        return y


@registry.register
class MLPDecoder(DecoderFunctionBase):
    """
    Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.
    """

    network: MLPWithInputSkips

    def __post_init__(self):
        super().__post_init__()
        run_auto_creation(self)

    def forward(
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.network(features, z)


class TransformerWithInputSkips(torch.nn.Module):
    def __init__(
        self,
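MLPDecoder never receives its network as a constructor argument; run_auto_creation builds it from configuration. A hedged construction sketch; the network_args keyword follows implicitron's <field>_args convention, so treat the exact keys as assumptions:

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    MLPDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(MLPDecoder)
decoder = MLPDecoder(network_args={"input_dim": 39, "output_dim": 4})

features = torch.randn(2, 1024, 39)
out = decoder(features, z=features)  # z feeds the network's skip layers
print(out.shape)  # torch.Size([2, 1024, 4])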
@@ -9,7 +9,7 @@ from typing import Optional, Tuple

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

@@ -214,6 +214,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
    append_xyz: Tuple[int, ...] = (5,)

    def _construct_xyz_encoder(self, input_dim: int):
        expand_args_fields(MLPWithInputSkips)
        return MLPWithInputSkips(
            self.n_layers_xyz,
            input_dim,
tests/implicitron/test_decoding_functions.py (new file, 34 lines)

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    IdentityDecoder,
    MLPDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

from tests.common_testing import TestCaseMixin


class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)
        expand_args_fields(IdentityDecoder)
        expand_args_fields(MLPDecoder)

    def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
        """
        Test that the identity function returns its input.
        """
        func = IdentityDecoder()
        for _ in range(n_tests):
            _in = torch.randn(in_shape)
            assert torch.allclose(func(_in), _in)