Decoding functions

Summary: Added replacable decoding functions which will be applied after the voxel grid to get color and density

Reviewed By: bottler

Differential Revision: D38829763

fbshipit-source-id: f21ce206c1c19548206ea2ce97d7ebea3de30a23
This commit is contained in:
Darijan Gudelj
2022-08-26 08:47:30 -07:00
committed by Facebook GitHub Bot
parent 24f5f4a3e7
commit e7c609f198
3 changed files with 153 additions and 52 deletions

View File

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
IdentityDecoder,
MLPDecoder,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from tests.common_testing import TestCaseMixin
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
expand_args_fields(IdentityDecoder)
expand_args_fields(MLPDecoder)
def test_identity_function(self, in_shape=(33, 4, 1), n_tests=2):
"""
Test that identity function returns its input
"""
func = IdentityDecoder()
for _ in range(n_tests):
_in = torch.randn(in_shape)
assert torch.allclose(func(_in), _in)