diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
index fd33a3fd..2713ea46 100644
--- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
+++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
@@ -14,6 +14,7 @@ This file contains
 
 import logging
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
 
@@ -30,6 +31,13 @@ from pytorch3d.implicitron.tools.config import (
 logger = logging.getLogger(__name__)
 
 
+class DecoderActivation(Enum):
+    RELU = "relu"
+    SOFTPLUS = "softplus"
+    SIGMOID = "sigmoid"
+    IDENTITY = "identity"
+
+
 class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
@@ -71,11 +79,16 @@ class ElementwiseDecoder(DecoderFunctionBase):
 
     scale: float = 1
     shift: float = 0
-    operation: str = "identity"
+    operation: DecoderActivation = DecoderActivation.IDENTITY
 
     def __post_init__(self):
         super().__post_init__()
-        if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+        if self.operation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
             raise ValueError(
                 "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
             )
@@ -84,11 +97,11 @@ class ElementwiseDecoder(DecoderFunctionBase):
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         transfomed_input = features * self.scale + self.shift
-        if self.operation == "softplus":
+        if self.operation == DecoderActivation.SOFTPLUS:
             return torch.nn.functional.softplus(transfomed_input)
-        if self.operation == "relu":
+        if self.operation == DecoderActivation.RELU:
             return torch.nn.functional.relu(transfomed_input)
-        if self.operation == "sigmoid":
+        if self.operation == DecoderActivation.SIGMOID:
             return torch.nn.functional.sigmoid(transfomed_input)
         return transfomed_input
 
@@ -104,7 +117,15 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     appends a skip tensor `z` to the output of the preceding layer.
 
     Note that this follows the architecture described in the Supplementary
-    Material (Fig. 7) of [1].
+    Material (Fig. 7) of [1], for which you should keep the defaults:
+        - `last_layer_bias_init` set to None
+        - `last_activation` set to "relu"
+        - `use_xavier_init` set to `True`
+
+    To use this as part of the color prediction in the TensoRF model, set:
+        - `last_layer_bias_init` to 0
+        - `last_activation` to "sigmoid"
+        - `use_xavier_init` to `False`
 
     References:
         [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
@@ -121,6 +142,12 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
         hidden_dim: The number of hidden units of the MLP.
         input_skips: The list of layer indices at which we append the skip
             tensor `z`.
+        last_layer_bias_init: If set, all the biases in the last layer
+            are initialized to that value.
+        last_activation: Which activation to use in the last layer. Options are:
+            "relu", "softplus", "sigmoid" and "identity". Default is "relu".
+        use_xavier_init: If True, uses Xavier initialization for all linear layer
+            weights; otherwise the default PyTorch initialization is used. Default True.
     """
 
     n_layers: int = 8
@@ -130,10 +157,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     hidden_dim: int = 256
     input_skips: Tuple[int, ...] = (5,)
     skip_affine_trans: bool = False
-    no_last_relu = False
+    last_layer_bias_init: Optional[float] = None
+    last_activation: DecoderActivation = DecoderActivation.RELU
+    use_xavier_init: bool = True
 
     def __post_init__(self):
         super().__init__()
+
+        if self.last_activation not in [
+            DecoderActivation.RELU,
+            DecoderActivation.SOFTPLUS,
+            DecoderActivation.SIGMOID,
+            DecoderActivation.IDENTITY,
+        ]:
+            raise ValueError(
+                "`last_activation` can only be `relu`,"
+                " `softplus`, `sigmoid` or `identity`."
+            )
+        last_activation = {
+            DecoderActivation.RELU: torch.nn.ReLU(True),
+            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+            DecoderActivation.IDENTITY: torch.nn.Identity(),
+        }[self.last_activation]
+
         layers = []
         skip_affine_layers = []
         for layeri in range(self.n_layers):
@@ -149,11 +196,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
                 dimin = self.hidden_dim + self.skip_dim
 
             linear = torch.nn.Linear(dimin, dimout)
-            _xavier_init(linear)
+            if self.use_xavier_init:
+                _xavier_init(linear)
+            if layeri == self.n_layers - 1 and self.last_layer_bias_init is not None:
+                torch.nn.init.constant_(linear.bias, self.last_layer_bias_init)
             layers.append(
                 torch.nn.Sequential(linear, torch.nn.ReLU(True))
-                if not self.no_last_relu or layeri + 1 < self.n_layers
-                else linear
+                if layeri + 1 < self.n_layers
+                else torch.nn.Sequential(linear, last_activation)
             )
         self.mlp = torch.nn.ModuleList(layers)
         if self.skip_affine_trans:
@@ -164,8 +214,9 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     def _make_affine_layer(self, input_dim, hidden_dim):
         l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
         l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
-        _xavier_init(l1)
-        _xavier_init(l2)
+        if self.use_xavier_init:
+            _xavier_init(l1)
+            _xavier_init(l2)
         return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
 
     def _apply_affine_layer(self, layer, x, z):
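
For anyone trying the change locally, here is a minimal usage sketch (not part of the patch). It assumes the classes are materialized via `expand_args_fields` from `pytorch3d.implicitron.tools.config`, following Implicitron's conventions for `Configurable`/`ReplaceableBase` dataclasses; the layer dimensions below are illustrative, not values prescribed by TensoRF.

```python
# Sketch exercising the new DecoderActivation options (assumes this patch
# is applied; dimensions are illustrative).
import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    DecoderActivation,
    ElementwiseDecoder,
    MLPWithInputSkips,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Configurable dataclasses must be expanded before direct instantiation.
expand_args_fields(ElementwiseDecoder)
expand_args_fields(MLPWithInputSkips)

# Elementwise decoder computing softplus(features * scale + shift).
decoder = ElementwiseDecoder(
    scale=2.0, shift=-1.0, operation=DecoderActivation.SOFTPLUS
)
out = decoder(torch.randn(4, 3))

# TensoRF-style color head per the updated docstring: zero-initialized
# last-layer bias, sigmoid output, default (non-Xavier) PyTorch init.
color_mlp = MLPWithInputSkips(
    n_layers=2,
    input_dim=27,
    output_dim=3,
    skip_dim=27,
    hidden_dim=128,
    input_skips=(),
    last_layer_bias_init=0.0,
    last_activation=DecoderActivation.SIGMOID,
    use_xavier_init=False,
)
rgb = color_mlp(torch.randn(10, 27))  # sigmoid keeps outputs in (0, 1)
```

With the defaults (`last_activation=DecoderActivation.RELU`, `use_xavier_init=True`, `last_layer_bias_init=None`), the module reproduces the NeRF behavior that the removed `no_last_relu = False` default used to provide.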