diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
index 6e99b5ab..fd33a3fd 100644
--- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
+++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
@@ -54,15 +54,43 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
 
 
 @registry.register
-class IdentityDecoder(DecoderFunctionBase):
+class ElementwiseDecoder(DecoderFunctionBase):
     """
-    Decoding function which returns its input.
+    Decoding function which scales the input, adds a shift and then applies
+    `relu`, `softplus`, `sigmoid` or nothing to it:
+    `result = operation(input * scale + shift)`
+
+    Members:
+        scale: a scalar with which the input is multiplied before being shifted.
+            Defaults to 1.
+        shift: a scalar which is added to the scaled input before performing
+            the operation. Defaults to 0.
+        operation: which operation to perform on the transformed input. Options are:
+            `relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
     """
 
+    scale: float = 1
+    shift: float = 0
+    operation: str = "identity"
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+            raise ValueError(
+                "`operation` can only be `relu`, `softplus`, `sigmoid` or `identity`."
+            )
+
     def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-        return features
+        transformed_input = features * self.scale + self.shift
+        if self.operation == "softplus":
+            return torch.nn.functional.softplus(transformed_input)
+        if self.operation == "relu":
+            return torch.nn.functional.relu(transformed_input)
+        if self.operation == "sigmoid":
+            return torch.nn.functional.sigmoid(transformed_input)
+        return transformed_input
 
 
 class MLPWithInputSkips(Configurable, torch.nn.Module):
diff --git a/tests/implicitron/test_decoding_functions.py b/tests/implicitron/test_decoding_functions.py
deleted file mode 100644
index 0a8db59a..00000000
--- a/tests/implicitron/test_decoding_functions.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import unittest
-
-import torch
-
-from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
-    IdentityDecoder,
-    MLPDecoder,
-)
-from pytorch3d.implicitron.tools.config import expand_args_fields
-
-from tests.common_testing import TestCaseMixin
-
-
-class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
-    def setUp(self):
-        torch.manual_seed(42)
-        expand_args_fields(IdentityDecoder)
-        expand_args_fields(MLPDecoder)
-
-    def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
-        """
-        Test that identity function returns its input
-        """
-        func = IdentityDecoder()
-        for _ in range(n_tests):
-            _in = torch.randn(in_shape)
-            assert torch.allclose(func(_in), _in)
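
Note: the snippet below is a minimal usage sketch, not part of the diff. It shows how the new ElementwiseDecoder could be exercised once the change is applied, mirroring the setUp pattern of the deleted test. It assumes that, as for the old IdentityDecoder, expand_args_fields makes the dataclass fields (scale, shift, operation) available as constructor keyword arguments.

import torch

from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
    ElementwiseDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Expand the configurable dataclass, mirroring the setUp of the removed test.
expand_args_fields(ElementwiseDecoder)

# Per the new docstring, result = softplus(input * 0.5 + 1.0).
decoder = ElementwiseDecoder(scale=0.5, shift=1.0, operation="softplus")
features = torch.randn(33, 4, 1)
expected = torch.nn.functional.softplus(features * 0.5 + 1.0)
assert torch.allclose(decoder(features), expected)

# An unsupported operation is rejected at construction time by __post_init__:
# ElementwiseDecoder(operation="tanh")  # raises ValueError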