Refactor autodecoders

Summary: Refactors autodecoders. Tests pass.

Reviewed By: bottler

Differential Revision: D37592429

fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
David Novotny
2022-07-04 07:18:03 -07:00
committed by Facebook GitHub Bot
parent ae35824f21
commit 0dce883241
10 changed files with 230 additions and 87 deletions

View File

@@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
global_encoder_class_type: SequenceAutodecoder
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: LSTMRenderer
image_feature_extractor_class_type: ResNetFeatureExtractor
@@ -48,11 +49,12 @@ log_vars:
- objective
- epoch
- sec/it
sequence_autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400

View File

@@ -7,11 +7,13 @@
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
SequenceAutodecoder,
)
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
@@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsNone(gm.global_encoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor)
@@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
)
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
args.implicit_function_class_type = "IdrFeatureField"
args.global_encoder_class_type = "SequenceAutodecoder"
idr_args = args.implicit_function_IdrFeatureField_args
idr_args.n_harmonic_functions_xyz = 1729
@@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function"))
@@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
print(DATA_DIR)
(DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())