MLP last layer config

Summary:
Added initialization configuration for the last layer of the MLP decoding function. You can now set:
- last activation function (tensorf uses sigmoid)
- last bias init (tensorf uses 0, because of sigmoid ofc)
- option to use xavier initialization (we use relu so this should not be set)

Reviewed By: davnov134

Differential Revision: D40304981

fbshipit-source-id: ec398eb2235164ae85cb7c09b9660e843490ea04
This commit is contained in:
Jeremy Reizenstein 2022-10-18 15:58:18 -07:00 committed by Facebook GitHub Bot
parent a2659e1730
commit a819ecb00b

View File

@ -14,6 +14,7 @@ This file contains
import logging
from enum import Enum
from typing import Optional, Tuple
import torch
@ -30,6 +31,13 @@ from pytorch3d.implicitron.tools.config import (
logger = logging.getLogger(__name__)
class DecoderActivation(Enum):
RELU = "relu"
SOFTPLUS = "softplus"
SIGMOID = "sigmoid"
IDENTITY = "identity"
class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
"""
Decoding function is a torch.nn.Module which takes the embedding of a location in
@ -71,11 +79,16 @@ class ElementwiseDecoder(DecoderFunctionBase):
scale: float = 1
shift: float = 0
operation: str = "identity"
operation: DecoderActivation = DecoderActivation.IDENTITY
def __post_init__(self):
super().__post_init__()
if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
if self.operation not in [
DecoderActivation.RELU,
DecoderActivation.SOFTPLUS,
DecoderActivation.SIGMOID,
DecoderActivation.IDENTITY,
]:
raise ValueError(
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
)
@ -84,11 +97,11 @@ class ElementwiseDecoder(DecoderFunctionBase):
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
) -> torch.Tensor:
transfomed_input = features * self.scale + self.shift
if self.operation == "softplus":
if self.operation == DecoderActivation.SOFTPLUS:
return torch.nn.functional.softplus(transfomed_input)
if self.operation == "relu":
if self.operation == DecoderActivation.RELU:
return torch.nn.functional.relu(transfomed_input)
if self.operation == "sigmoid":
if self.operation == DecoderActivation.SIGMOID:
return torch.nn.functional.sigmoid(transfomed_input)
return transfomed_input
@ -104,7 +117,15 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
appends a skip tensor `z` to the output of the preceding layer.
Note that this follows the architecture described in the Supplementary
Material (Fig. 7) of [1].
Material (Fig. 7) of [1], for which keep the defaults for:
- `last_layer_bias_init` to None
- `last_activation` to "relu"
- `use_xavier_init` to `true`
If you want to use this as a part of the color prediction in TensoRF model set:
- `last_layer_bias_init` to 0
- `last_activation` to "sigmoid"
- `use_xavier_init` to `False`
References:
[1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
@ -121,6 +142,12 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
last_layer_bias_init: If set then all the biases in the last layer
are initialized to that value.
last_activation: Which activation to use in the last layer. Options are:
"relu", "softplus", "sigmoid" and "identity". Default is "relu".
use_xavier_init: If True uses xavier init for all linear layer weights.
Otherwise the default PyTorch initialization is used. Default True.
"""
n_layers: int = 8
@ -130,10 +157,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
hidden_dim: int = 256
input_skips: Tuple[int, ...] = (5,)
skip_affine_trans: bool = False
no_last_relu = False
last_layer_bias_init: Optional[float] = None
last_activation: DecoderActivation = DecoderActivation.RELU
use_xavier_init: bool = True
def __post_init__(self):
super().__init__()
if self.last_activation not in [
DecoderActivation.RELU,
DecoderActivation.SOFTPLUS,
DecoderActivation.SIGMOID,
DecoderActivation.IDENTITY,
]:
raise ValueError(
"`last_activation` can only be `relu`,"
" `softplus`, `sigmoid` or identity."
)
last_activation = {
DecoderActivation.RELU: torch.nn.ReLU(True),
DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
DecoderActivation.IDENTITY: torch.nn.Identity(),
}[self.last_activation]
layers = []
skip_affine_layers = []
for layeri in range(self.n_layers):
@ -149,11 +196,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
dimin = self.hidden_dim + self.skip_dim
linear = torch.nn.Linear(dimin, dimout)
_xavier_init(linear)
if self.use_xavier_init:
_xavier_init(linear)
if layeri == self.n_layers - 1 and self.last_layer_bias_init is not None:
torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)
layers.append(
torch.nn.Sequential(linear, torch.nn.ReLU(True))
if not self.no_last_relu or layeri + 1 < self.n_layers
else linear
if not layeri + 1 < self.n_layers
else torch.nn.Sequential(linear, last_activation)
)
self.mlp = torch.nn.ModuleList(layers)
if self.skip_affine_trans:
@ -164,8 +214,9 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
def _make_affine_layer(self, input_dim, hidden_dim):
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
_xavier_init(l1)
_xavier_init(l2)
if self.use_xavier_init:
_xavier_init(l1)
_xavier_init(l2)
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
def _apply_affine_layer(self, layer, x, z):