mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Refactor autodecoders
Summary: Refactors autodecoders. Tests pass. Reviewed By: bottler Differential Revision: D37592429 fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ae35824f21
commit
0dce883241
@@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
|
||||
n_train_target_views: 1
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
global_encoder_class_type: SequenceAutodecoder
|
||||
raysampler_class_type: AdaptiveRaySampler
|
||||
renderer_class_type: LSTMRenderer
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
@@ -48,11 +49,12 @@ log_vars:
|
||||
- objective
|
||||
- epoch
|
||||
- sec/it
|
||||
sequence_autodecoder_args:
|
||||
encoding_dim: 0
|
||||
n_instances: 0
|
||||
init_scale: 1.0
|
||||
ignore_input: false
|
||||
global_encoder_SequenceAutodecoder_args:
|
||||
autodecoder_args:
|
||||
encoding_dim: 0
|
||||
n_instances: 0
|
||||
init_scale: 1.0
|
||||
ignore_input: false
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
|
||||
@@ -7,11 +7,13 @@
|
||||
import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
||||
ResNetFeatureExtractor,
|
||||
)
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
|
||||
SequenceAutodecoder,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||
IdrFeatureField,
|
||||
)
|
||||
@@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
self.assertIsInstance(
|
||||
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
||||
)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsNone(gm.global_encoder)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
self.assertIsNone(gm.view_pooler)
|
||||
self.assertIsNone(gm.image_feature_extractor)
|
||||
@@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
args.global_encoder_class_type = "SequenceAutodecoder"
|
||||
idr_args = args.implicit_function_IdrFeatureField_args
|
||||
idr_args.n_harmonic_functions_xyz = 1729
|
||||
|
||||
@@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
|
||||
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
|
||||
@@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
|
||||
remove_unused_components(instance_args)
|
||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
if DEBUG:
|
||||
print(DATA_DIR)
|
||||
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
||||
|
||||
Reference in New Issue
Block a user