mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 20:02:49 +08:00)
MLP last layer config
Summary: Added initialization configuration for the last layer of the MLP decoding function. You can now set:
- the last activation function (TensoRF uses sigmoid)
- the last bias init (TensoRF uses 0, because of the sigmoid)
- an option to use Xavier initialization (we use relu, so this should not be set)

Reviewed By: davnov134

Differential Revision: D40304981

fbshipit-source-id: ec398eb2235164ae85cb7c09b9660e843490ea04
This commit is contained in:
parent a2659e1730
commit a819ecb00b
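A minimal usage sketch of the new last-layer options described in the summary. The module path, the expand_args_fields call, the input_dim/output_dim field names, and the dimension values are assumptions for illustration and are not part of this commit:

    # Hedged sketch: TensoRF-style colour head, as described in the summary above.
    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        DecoderActivation,
        MLPWithInputSkips,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(MLPWithInputSkips)  # make the Configurable instantiable with kwargs

    mlp = MLPWithInputSkips(
        n_layers=2,
        input_dim=27,
        output_dim=3,
        hidden_dim=128,
        input_skips=(),
        last_activation=DecoderActivation.SIGMOID,  # TensoRF uses sigmoid
        last_layer_bias_init=0.0,                   # zero bias before the sigmoid
        use_xavier_init=False,                      # keep the default PyTorch init
    )

    features = torch.rand(1024, 27)
    out = mlp(features, features)  # (x, z); with input_skips=() the skip tensor is unused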
@@ -14,6 +14,7 @@ This file contains
 import logging
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -30,6 +31,13 @@ from pytorch3d.implicitron.tools.config import (
 logger = logging.getLogger(__name__)
 
 
+class DecoderActivation(Enum):
+    RELU = "relu"
+    SOFTPLUS = "softplus"
+    SIGMOID = "sigmoid"
+    IDENTITY = "identity"
+
+
 class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
@@ -71,11 +79,16 @@ class ElementwiseDecoder(DecoderFunctionBase):
 
     scale: float = 1
     shift: float = 0
-    operation: str = "identity"
+    operation: DecoderActivation = DecoderActivation.IDENTITY
 
     def __post_init__(self):
         super().__post_init__()
-        if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+        if self.operation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
             raise ValueError(
                 "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
             )
@@ -84,11 +97,11 @@ class ElementwiseDecoder(DecoderFunctionBase):
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         transfomed_input = features * self.scale + self.shift
-        if self.operation == "softplus":
+        if self.operation == DecoderActivation.SOFTPLUS:
             return torch.nn.functional.softplus(transfomed_input)
-        if self.operation == "relu":
+        if self.operation == DecoderActivation.RELU:
             return torch.nn.functional.relu(transfomed_input)
-        if self.operation == "sigmoid":
+        if self.operation == DecoderActivation.SIGMOID:
             return torch.nn.functional.sigmoid(transfomed_input)
         return transfomed_input
 
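A short sketch of ElementwiseDecoder with the enum-typed operation. The import path and the expand_args_fields call are assumptions for illustration, not part of this diff:

    # Hedged sketch: `operation` is now a DecoderActivation member, not a string.
    import torch

    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        DecoderActivation,
        ElementwiseDecoder,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(ElementwiseDecoder)
    decoder = ElementwiseDecoder(
        scale=2.0,
        shift=0.5,
        operation=DecoderActivation.SOFTPLUS,
    )

    features = torch.rand(8, 16)
    out = decoder(features)  # softplus(features * 2.0 + 0.5), same shape as the input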
@@ -104,7 +117,15 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     appends a skip tensor `z` to the output of the preceding layer.
 
     Note that this follows the architecture described in the Supplementary
-    Material (Fig. 7) of [1].
+    Material (Fig. 7) of [1], for which keep the defaults for:
+        - `last_layer_bias_init` to None
+        - `last_activation` to "relu"
+        - `use_xavier_init` to `true`
+
+    If you want to use this as a part of the color prediction in TensoRF model set:
+        - `last_layer_bias_init` to 0
+        - `last_activation` to "sigmoid"
+        - `use_xavier_init` to `False`
 
     References:
         [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
@@ -121,6 +142,12 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
         hidden_dim: The number of hidden units of the MLP.
         input_skips: The list of layer indices at which we append the skip
             tensor `z`.
+        last_layer_bias_init: If set then all the biases in the last layer
+            are initialized to that value.
+        last_activation: Which activation to use in the last layer. Options are:
+            "relu", "softplus", "sigmoid" and "identity". Default is "relu".
+        use_xavier_init: If True uses xavier init for all linear layer weights.
+            Otherwise the default PyTorch initialization is used. Default True.
     """
 
     n_layers: int = 8
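The string option names listed in the docstring above correspond one-to-one to the DecoderActivation members added at the top of this diff; a quick check (module path assumed):

    # Hedged sketch: the docstring's string options are the enum values.
    from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
        DecoderActivation,
    )

    assert [a.value for a in DecoderActivation] == ["relu", "softplus", "sigmoid", "identity"]
    assert DecoderActivation("sigmoid") is DecoderActivation.SIGMOID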
@@ -130,10 +157,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     hidden_dim: int = 256
     input_skips: Tuple[int, ...] = (5,)
     skip_affine_trans: bool = False
-    no_last_relu = False
+    last_layer_bias_init: Optional[float] = None
+    last_activation: DecoderActivation = DecoderActivation.RELU
+    use_xavier_init: bool = True
 
     def __post_init__(self):
         super().__init__()
+
+        if self.last_activation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
+            raise ValueError(
+                "`last_activation` can only be `relu`,"
+                " `softplus`, `sigmoid` or identity."
+            )
+        last_activation = {
+            DecoderActivation.RELU: torch.nn.ReLU(True),
+            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+            DecoderActivation.IDENTITY: torch.nn.Identity(),
+        }[self.last_activation]
+
         layers = []
         skip_affine_layers = []
         for layeri in range(self.n_layers):
@@ -149,11 +196,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
                 dimin = self.hidden_dim + self.skip_dim
 
             linear = torch.nn.Linear(dimin, dimout)
-            _xavier_init(linear)
+            if self.use_xavier_init:
+                _xavier_init(linear)
+            if layeri == self.n_layers - 1 and self.last_layer_bias_init is not None:
+                torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)
             layers.append(
                 torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not self.no_last_relu or layeri + 1 < self.n_layers
-                else linear
+                if not layeri + 1 < self.n_layers
+                else torch.nn.Sequential(linear, last_activation)
             )
         self.mlp = torch.nn.ModuleList(layers)
         if self.skip_affine_trans:
@@ -164,6 +214,7 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     def _make_affine_layer(self, input_dim, hidden_dim):
         l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
         l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-        _xavier_init(l1)
-        _xavier_init(l2)
+        if self.use_xavier_init:
+            _xavier_init(l1)
+            _xavier_init(l2)
         return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)