pytorch3d/tests/implicitron/test_config_use.py
Jeremy Reizenstein cdd2142dd5
implicitron v0 (#1133)
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2022-03-21 13:20:10 -07:00

82 lines
3.0 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
remove_unused_components,
)
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
else:
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False
# Tests the use of the config system in implicitron
class TestGenericModel(unittest.TestCase):
def setUp(self):
self.maxDiff = None
def test_create_gm(self):
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
)
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertFalse(hasattr(gm, "image_feature_extractor"))
def test_create_gm_overrides(self):
args = get_default_args(GenericModel)
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
args.implicit_function_class_type = "IdrFeatureField"
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
instance_args = OmegaConf.structured(gm)
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
(DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())