mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
MLP last layer config
Summary: Added initialization configuration for the last layer of the MLP decoding function. You can now set:
- the last activation function (TensoRF uses sigmoid)
- the last-layer bias init (TensoRF uses 0, because of the sigmoid)
- an option to use Xavier initialization (we use ReLU, so this should not be set)

Reviewed By: davnov134
Differential Revision: D40304981
fbshipit-source-id: ec398eb2235164ae85cb7c09b9660e843490ea04
This commit is contained in:
parent a2659e1730
commit a819ecb00b
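For orientation, a hedged usage sketch (not part of the commit) of the three new options, configured the way the summary describes for TensoRF. Only `last_activation`, `last_layer_bias_init`, `use_xavier_init` and `DecoderActivation` come from this diff; the import path, the `expand_args_fields` step and the remaining constructor arguments are assumptions based on the usual Implicitron `Configurable` pattern.

```python
import torch

# Assumed import paths; the diff below does not name the patched file.
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    DecoderActivation,
    MLPWithInputSkips,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Configurable classes are normally expanded before direct instantiation.
expand_args_fields(MLPWithInputSkips)

# TensoRF-style color head, per the summary: sigmoid last activation,
# zero bias init in the last layer, no xavier init.
color_mlp = MLPWithInputSkips(
    n_layers=2,       # assumed value
    input_dim=128,    # assumed value
    output_dim=3,     # assumed value (RGB)
    last_activation=DecoderActivation.SIGMOID,
    last_layer_bias_init=0.0,
    use_xavier_init=False,
)

features = torch.randn(2, 1024, 128)  # (batch, points, input_dim)
colors = color_mlp(features)          # (2, 1024, 3)
```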
@@ -14,6 +14,7 @@ This file contains
 
 import logging
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -30,6 +31,13 @@ from pytorch3d.implicitron.tools.config import (
 logger = logging.getLogger(__name__)
 
 
+class DecoderActivation(Enum):
+    RELU = "relu"
+    SOFTPLUS = "softplus"
+    SIGMOID = "sigmoid"
+    IDENTITY = "identity"
+
+
 class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
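The enum itself is plain Python; a minimal standalone illustration (not repo code) of why string values are used: a config can say "sigmoid" while the code compares enum members.

```python
from enum import Enum


class DecoderActivation(Enum):  # same members as the enum added above
    RELU = "relu"
    SOFTPLUS = "softplus"
    SIGMOID = "sigmoid"
    IDENTITY = "identity"


# A string coming from a config round-trips to the corresponding member.
assert DecoderActivation("sigmoid") is DecoderActivation.SIGMOID
assert DecoderActivation.SIGMOID.value == "sigmoid"
```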
@@ -71,11 +79,16 @@ class ElementwiseDecoder(DecoderFunctionBase):
 
     scale: float = 1
     shift: float = 0
-    operation: str = "identity"
+    operation: DecoderActivation = DecoderActivation.IDENTITY
 
     def __post_init__(self):
         super().__post_init__()
-        if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+        if self.operation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
             raise ValueError(
                 "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
             )
@@ -84,11 +97,11 @@ class ElementwiseDecoder(DecoderFunctionBase):
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         transfomed_input = features * self.scale + self.shift
-        if self.operation == "softplus":
+        if self.operation == DecoderActivation.SOFTPLUS:
             return torch.nn.functional.softplus(transfomed_input)
-        if self.operation == "relu":
+        if self.operation == DecoderActivation.RELU:
             return torch.nn.functional.relu(transfomed_input)
-        if self.operation == "sigmoid":
+        if self.operation == DecoderActivation.SIGMOID:
             return torch.nn.functional.sigmoid(transfomed_input)
         return transfomed_input
 
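A hedged sketch of the refactored `ElementwiseDecoder` in use; the fields and the enum comparison come from the hunks above, while the import path and the `expand_args_fields` step are assumptions about the surrounding Implicitron config machinery.

```python
import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (  # assumed path
    DecoderActivation,
    ElementwiseDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(ElementwiseDecoder)

# Scale and shift the features, then apply the configured activation elementwise.
decoder = ElementwiseDecoder(
    scale=2.0, shift=-1.0, operation=DecoderActivation.SOFTPLUS
)
out = decoder(torch.randn(4, 8))  # softplus(2 * x - 1)

# Passing the old-style plain string would now fail the membership check
# in __post_init__ shown above, e.g. ElementwiseDecoder(operation="softplus").
```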
@@ -104,7 +117,15 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     appends a skip tensor `z` to the output of the preceding layer.
 
     Note that this follows the architecture described in the Supplementary
-    Material (Fig. 7) of [1].
+    Material (Fig. 7) of [1], for which keep the defaults for:
+        - `last_layer_bias_init` to None
+        - `last_activation` to "relu"
+        - `use_xavier_init` to `true`
+
+    If you want to use this as a part of the color prediction in TensoRF model set:
+        - `last_layer_bias_init` to 0
+        - `last_activation` to "sigmoid"
+        - `use_xavier_init` to `False`
 
     References:
         [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
@@ -121,6 +142,12 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
         hidden_dim: The number of hidden units of the MLP.
         input_skips: The list of layer indices at which we append the skip
             tensor `z`.
+        last_layer_bias_init: If set then all the biases in the last layer
+            are initialized to that value.
+        last_activation: Which activation to use in the last layer. Options are:
+            "relu", "softplus", "sigmoid" and "identity". Default is "relu".
+        use_xavier_init: If True uses xavier init for all linear layer weights.
+            Otherwise the default PyTorch initialization is used. Default True.
     """
 
     n_layers: int = 8
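A one-line sanity check of the summary's "bias 0, because of sigmoid" remark: with the last-layer biases initialized to zero (and typically small initial weights), the pre-activation starts near zero, so the sigmoid output starts near the middle of its range instead of saturated.

```python
import torch

print(torch.sigmoid(torch.tensor(0.0)))  # tensor(0.5000) -- mid-range initial output
```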
@@ -130,10 +157,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     hidden_dim: int = 256
     input_skips: Tuple[int, ...] = (5,)
     skip_affine_trans: bool = False
-    no_last_relu = False
+    last_layer_bias_init: Optional[float] = None
+    last_activation: DecoderActivation = DecoderActivation.RELU
+    use_xavier_init: bool = True
 
     def __post_init__(self):
         super().__init__()
+
+        if self.last_activation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
+            raise ValueError(
+                "`last_activation` can only be `relu`,"
+                " `softplus`, `sigmoid` or identity."
+            )
+        last_activation = {
+            DecoderActivation.RELU: torch.nn.ReLU(True),
+            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+            DecoderActivation.IDENTITY: torch.nn.Identity(),
+        }[self.last_activation]
+
         layers = []
         skip_affine_layers = []
         for layeri in range(self.n_layers):
@@ -149,11 +196,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
                     dimin = self.hidden_dim + self.skip_dim
 
             linear = torch.nn.Linear(dimin, dimout)
-            _xavier_init(linear)
+            if self.use_xavier_init:
+                _xavier_init(linear)
+            if layeri == self.n_layers - 1 and self.last_layer_bias_init is not None:
+                torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)
             layers.append(
                 torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not self.no_last_relu or layeri + 1 < self.n_layers
-                else linear
+                if not layeri + 1 < self.n_layers
+                else torch.nn.Sequential(linear, last_activation)
             )
         self.mlp = torch.nn.ModuleList(layers)
         if self.skip_affine_trans:
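The two new branches in the layer loop reduce to standard `torch.nn.init` calls; a small sketch of their effect (the body of the existing `_xavier_init` helper is not shown in this diff, so the xavier-uniform call below is an assumption about what it wraps).

```python
import torch

linear = torch.nn.Linear(256, 3)

# Applied to every layer, but only when use_xavier_init is True.
torch.nn.init.xavier_uniform_(linear.weight)

# Applied to the last layer only, and only when last_layer_bias_init is set.
torch.nn.init.constant_(linear.bias, 0.0)  # e.g. last_layer_bias_init = 0
```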
@@ -164,8 +214,9 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     def _make_affine_layer(self, input_dim, hidden_dim):
         l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
         l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-        _xavier_init(l1)
-        _xavier_init(l2)
+        if self.use_xavier_init:
+            _xavier_init(l1)
+            _xavier_init(l2)
         return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
 
     def _apply_affine_layer(self, layer, x, z):