mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Decoding functions
Summary: Added replaceable decoding functions which are applied after the
voxel grid to obtain color and density.

Reviewed By: bottler

Differential Revision: D38829763

fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
This commit is contained in:
parent 24f5f4a3e7
commit e7c609f198
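
For orientation, a minimal sketch of how the new extension point is used,
based directly on the test added at the end of this diff (the decoder classes
go through implicitron's config system, so they are expanded before direct
instantiation):

    import torch
    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        IdentityDecoder,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(IdentityDecoder)
    decoder = IdentityDecoder()
    # hypothetical per-point embedding, e.g. sampled from a voxel grid
    features = torch.randn(2, 1024, 39)
    assert torch.allclose(decoder(features), features)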
pytorch3d/implicitron/models/implicit_function/decoding_functions.py

@@ -4,16 +4,66 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+"""
+This file contains
+    - modules which get used by ImplicitFunction objects for decoding an embedding
+      defined in space, e.g. into color or opacity.
+    - DecoderFunctionBase and its subclasses, which wrap some of those modules,
+      providing them as an extension point which an ImplicitFunction object can use.
+"""
+
 import logging
 from typing import Optional, Tuple
 
 import torch
 
+from pytorch3d.implicitron.tools.config import (
+    Configurable,
+    registry,
+    ReplaceableBase,
+    run_auto_creation,
+)
+
 logger = logging.getLogger(__name__)
 
 
-class MLPWithInputSkips(torch.nn.Module):
+class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
+    """
+    Decoding function is a torch.nn.Module which takes the embedding of a location
+    in space and transforms it into the required quantity (for example density
+    and color).
+    """
+
+    def __post_init__(self):
+        super().__init__()
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            features (torch.Tensor): tensor of shape (batch, ..., num_in_features)
+            z: optional tensor to append to parts of the decoding function
+        Returns:
+            decoded_features (torch.Tensor): tensor of
+                shape (batch, ..., num_out_features)
+        """
+        raise NotImplementedError()
+
+
+@registry.register
+class IdentityDecoder(DecoderFunctionBase):
+    """
+    Decoding function which returns its input.
+    """
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return features
+
+
+class MLPWithInputSkips(Configurable, torch.nn.Module):
     """
     Implements the multi-layer perceptron architecture of the Neural Radiance Field.
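
Any subclass of DecoderFunctionBase decorated with @registry.register becomes
selectable as a replaceable implementation. A hypothetical further example
following the same pattern as IdentityDecoder (the class name and the scale
field are illustrative only, not part of the commit):

    @registry.register
    class ScaleDecoder(DecoderFunctionBase):
        """
        Hypothetical decoding function which scales its input by a constant.
        """

        scale: float = 2.0

        def forward(
            self, features: torch.Tensor, z: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            return features * self.scale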
@@ -31,8 +81,56 @@ class MLPWithInputSkips(torch.nn.Module):
     and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
     NeRF: Representing Scenes as Neural Radiance Fields for View
     Synthesis, ECCV2020
+
+    Members:
+        n_layers: The number of linear layers of the MLP.
+        input_dim: The number of channels of the input tensor.
+        output_dim: The number of channels of the output.
+        skip_dim: The number of channels of the tensor `z` appended when
+            evaluating the skip layers.
+        hidden_dim: The number of hidden units of the MLP.
+        input_skips: The list of layer indices at which we append the skip
+            tensor `z`.
     """
 
+    n_layers: int = 8
+    input_dim: int = 39
+    output_dim: int = 256
+    skip_dim: int = 39
+    hidden_dim: int = 256
+    input_skips: Tuple[int, ...] = (5,)
+    skip_affine_trans: bool = False
+    no_last_relu = False
+
+    def __post_init__(self):
+        super().__init__()
+        layers = []
+        skip_affine_layers = []
+        for layeri in range(self.n_layers):
+            dimin = self.hidden_dim if layeri > 0 else self.input_dim
+            dimout = self.hidden_dim if layeri + 1 < self.n_layers else self.output_dim
+
+            if layeri > 0 and layeri in self.input_skips:
+                if self.skip_affine_trans:
+                    skip_affine_layers.append(
+                        self._make_affine_layer(self.skip_dim, self.hidden_dim)
+                    )
+                else:
+                    dimin = self.hidden_dim + self.skip_dim
+
+            linear = torch.nn.Linear(dimin, dimout)
+            _xavier_init(linear)
+            layers.append(
+                torch.nn.Sequential(linear, torch.nn.ReLU(True))
+                if not self.no_last_relu or layeri + 1 < self.n_layers
+                else linear
+            )
+        self.mlp = torch.nn.ModuleList(layers)
+        if self.skip_affine_trans:
+            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
+        self._input_skips = set(self.input_skips)
+        self._skip_affine_trans = self.skip_affine_trans
+
     def _make_affine_layer(self, input_dim, hidden_dim):
         l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
         l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
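
Since MLPWithInputSkips is now a Configurable with dataclass-style fields
rather than a plain module with an __init__, callers expand it before
instantiating it directly (the change to the NeRF implicit function further
below does exactly this). A minimal sketch, assuming the default field values
listed above:

    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(MLPWithInputSkips)
    mlp = MLPWithInputSkips(n_layers=2, input_dim=39, output_dim=64)
    y = mlp(torch.randn(4, 39))  # -> shape (4, 64)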
@@ -46,56 +144,6 @@ class MLPWithInputSkips(torch.nn.Module):
         std = torch.nn.functional.softplus(log_std)
         return (x - mu) * std
 
-    def __init__(
-        self,
-        n_layers: int = 8,
-        input_dim: int = 39,
-        output_dim: int = 256,
-        skip_dim: int = 39,
-        hidden_dim: int = 256,
-        input_skips: Tuple[int, ...] = (5,),
-        skip_affine_trans: bool = False,
-        no_last_relu=False,
-    ):
-        """
-        Args:
-            n_layers: The number of linear layers of the MLP.
-            input_dim: The number of channels of the input tensor.
-            output_dim: The number of channels of the output.
-            skip_dim: The number of channels of the tensor `z` appended when
-                evaluating the skip layers.
-            hidden_dim: The number of hidden units of the MLP.
-            input_skips: The list of layer indices at which we append the skip
-                tensor `z`.
-        """
-        super().__init__()
-        layers = []
-        skip_affine_layers = []
-        for layeri in range(n_layers):
-            dimin = hidden_dim if layeri > 0 else input_dim
-            dimout = hidden_dim if layeri + 1 < n_layers else output_dim
-
-            if layeri > 0 and layeri in input_skips:
-                if skip_affine_trans:
-                    skip_affine_layers.append(
-                        self._make_affine_layer(skip_dim, hidden_dim)
-                    )
-                else:
-                    dimin = hidden_dim + skip_dim
-
-            linear = torch.nn.Linear(dimin, dimout)
-            _xavier_init(linear)
-            layers.append(
-                torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not no_last_relu or layeri + 1 < n_layers
-                else linear
-            )
-        self.mlp = torch.nn.ModuleList(layers)
-        if skip_affine_trans:
-            self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
-        self._input_skips = set(input_skips)
-        self._skip_affine_trans = skip_affine_trans
-
     def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
         """
         Args:
@@ -121,6 +169,24 @@ class MLPWithInputSkips(torch.nn.Module):
         return y
 
 
+@registry.register
+class MLPDecoder(DecoderFunctionBase):
+    """
+    Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.
+    """
+
+    network: MLPWithInputSkips
+
+    def __post_init__(self):
+        super().__post_init__()
+        run_auto_creation(self)
+
+    def forward(
+        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        return self.network(features, z)
+
+
 class TransformerWithInputSkips(torch.nn.Module):
     def __init__(
         self,
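
The network member of MLPDecoder is built automatically by run_auto_creation
in __post_init__. A minimal usage sketch, assuming the default
MLPWithInputSkips field values (39 input channels, 256 output channels):

    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(MLPDecoder)
    decoder = MLPDecoder()  # `network` is auto-created in __post_init__
    out = decoder(torch.randn(2, 8, 39))  # -> shape (2, 8, 256)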
pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py

@@ -9,7 +9,7 @@ from typing import Optional, Tuple
 
 import torch
 from pytorch3d.common.linear_with_repeat import LinearWithRepeat
-from pytorch3d.implicitron.tools.config import registry
+from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.renderer.implicit import HarmonicEmbedding

@@ -214,6 +214,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
     append_xyz: Tuple[int, ...] = (5,)
 
     def _construct_xyz_encoder(self, input_dim: int):
+        expand_args_fields(MLPWithInputSkips)
         return MLPWithInputSkips(
             self.n_layers_xyz,
             input_dim,
tests/implicitron/test_decoding_functions.py (new file, 34 lines)

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+
+from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
+    IdentityDecoder,
+    MLPDecoder,
+)
+from pytorch3d.implicitron.tools.config import expand_args_fields
+
+from tests.common_testing import TestCaseMixin
+
+
+class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+        expand_args_fields(IdentityDecoder)
+        expand_args_fields(MLPDecoder)
+
+    def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
+        """
+        Test that the identity function returns its input.
+        """
+        func = IdentityDecoder()
+        for _ in range(n_tests):
+            _in = torch.randn(in_shape)
+            assert torch.allclose(func(_in), _in)
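
A companion check for MLPDecoder in the same style might look like the
following; this is a sketch, not part of the commit, and the expected output
width follows the default output_dim of 256 in MLPWithInputSkips:

    def test_mlp_decoder(self, in_shape=(33, 4, 39), n_tests=2):
        """
        Sketch: test that MLPDecoder maps the default 39 input channels
        to the default 256 output channels.
        """
        func = MLPDecoder()
        for _ in range(n_tests):
            _in = torch.randn(in_shape)
            self.assertEqual(func(_in).shape, (*in_shape[:-1], 256))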