ViewPooler class

Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator.

Reviewed By: shapovalov

Differential Revision: D35852367

fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
This commit is contained in:
David Novotny
2022-05-12 12:50:03 -07:00
committed by Facebook GitHub Bot
parent bef959c755
commit 47d06c8924
26 changed files with 304 additions and 110 deletions

View File

@@ -20,9 +20,8 @@ from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
@@ -32,7 +31,10 @@ from pytorch3d.implicitron.tools.config import (
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
from .common_resources import provide_lpips_vgg
else:
from common_resources import provide_lpips_vgg # noqa
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
@@ -46,28 +48,33 @@ class TestGenericModel(unittest.TestCase):
self.maxDiff = None
def test_create_gm(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
)
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertFalse(hasattr(gm, "image_feature_extractor"))
self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor)
def test_create_gm_overrides(self):
def _test_create_gm_overrides(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
args.view_pooler_enabled = True
args.image_feature_extractor_enabled = True
args.view_pooler_args.feature_aggregator_class_type = (
"AngleWeightedIdentityFeatureAggregator"
)
args.implicit_function_class_type = "IdrFeatureField"
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
gm.view_pooler.feature_aggregator,
AngleWeightedIdentityFeatureAggregator,
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)