mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Elementwise decoder
Summary: Tensorf does relu or softmax after the density grid. This diff adds the ability to replicate that. Reviewed By: bottler Differential Revision: D40023228 fbshipit-source-id: 9f19868cd68460af98ab6e61c7f708158c26dc08
This commit is contained in:
parent
a607dd063e
commit
76cddd90be
@ -54,15 +54,43 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class IdentityDecoder(DecoderFunctionBase):
|
class ElementwiseDecoder(DecoderFunctionBase):
|
||||||
"""
|
"""
|
||||||
Decoding function which returns its input.
|
Decoding function which scales the input, adds shift and then applies
|
||||||
|
`relu`, `softplus`, `sigmoid` or nothing on its input:
|
||||||
|
`result = operation(input * scale + shift)`
|
||||||
|
|
||||||
|
Members:
|
||||||
|
scale: a scalar with which input is multiplied before being shifted.
|
||||||
|
Defaults to 1.
|
||||||
|
shift: a scalar which is added to the scaled input before performing
|
||||||
|
the operation. Defaults to 0.
|
||||||
|
operation: which operation to perform on the transformed input. Options are:
|
||||||
|
`relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
scale: float = 1
|
||||||
|
shift: float = 0
|
||||||
|
operation: str = "identity"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
|
||||||
|
raise ValueError(
|
||||||
|
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
|
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return features
|
transfomed_input = features * self.scale + self.shift
|
||||||
|
if self.operation == "softplus":
|
||||||
|
return torch.nn.functional.softplus(transfomed_input)
|
||||||
|
if self.operation == "relu":
|
||||||
|
return torch.nn.functional.relu(transfomed_input)
|
||||||
|
if self.operation == "sigmoid":
|
||||||
|
return torch.nn.functional.sigmoid(transfomed_input)
|
||||||
|
return transfomed_input
|
||||||
|
|
||||||
|
|
||||||
class MLPWithInputSkips(Configurable, torch.nn.Module):
|
class MLPWithInputSkips(Configurable, torch.nn.Module):
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
|
|
||||||
IdentityDecoder,
|
|
||||||
MLPDecoder,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
|
||||||
|
|
||||||
|
|
||||||
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
expand_args_fields(IdentityDecoder)
|
|
||||||
expand_args_fields(MLPDecoder)
|
|
||||||
|
|
||||||
def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
|
|
||||||
"""
|
|
||||||
Test that identity function returns its input
|
|
||||||
"""
|
|
||||||
func = IdentityDecoder()
|
|
||||||
for _ in range(n_tests):
|
|
||||||
_in = torch.randn(in_shape)
|
|
||||||
assert torch.allclose(func(_in), _in)
|
|
Loading…
x
Reference in New Issue
Block a user