mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Refactor autodecoders
Summary: Refactors autodecoders. Tests pass. Reviewed By: bottler Differential Revision: D37592429 fbshipit-source-id: 8f5c9eac254e1fdf0704d5ec5f69eb42f6225113
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ae35824f21
commit
0dce883241
@@ -7,11 +7,13 @@
|
||||
import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
||||
ResNetFeatureExtractor,
|
||||
)
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.models.global_encoder.global_encoder import (
|
||||
SequenceAutodecoder,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||
IdrFeatureField,
|
||||
)
|
||||
@@ -50,7 +52,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
self.assertIsInstance(
|
||||
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
||||
)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsNone(gm.global_encoder)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
self.assertIsNone(gm.view_pooler)
|
||||
self.assertIsNone(gm.image_feature_extractor)
|
||||
@@ -64,6 +66,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
args.global_encoder_class_type = "SequenceAutodecoder"
|
||||
idr_args = args.implicit_function_IdrFeatureField_args
|
||||
idr_args.n_harmonic_functions_xyz = 1729
|
||||
|
||||
@@ -76,7 +79,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsInstance(gm.global_encoder, SequenceAutodecoder)
|
||||
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
|
||||
@@ -87,5 +90,6 @@ class TestGenericModel(unittest.TestCase):
|
||||
remove_unused_components(instance_args)
|
||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
if DEBUG:
|
||||
print(DATA_DIR)
|
||||
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
||||
|
||||
Reference in New Issue
Block a user