Hard population of registry system with pre_expand

Summary: Provide an extension point pre_expand to let a configurable class A make sure another class B is registered before A is expanded. This reduces top level imports.

Reviewed By: bottler

Differential Revision: D44504122

fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
This commit is contained in:
Dejan Kovachev 2023-03-31 07:44:38 -07:00 committed by Facebook GitHub Bot
parent 813e941de5
commit c759fc560f
5 changed files with 117 additions and 27 deletions

View File

@ -13,13 +13,8 @@ from pytorch3d.implicitron.tools.config import (
) )
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
class DataSourceBase(ReplaceableBase): class DataSourceBase(ReplaceableBase):
@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
data_loader_map_provider: DataLoaderMapProviderBase data_loader_map_provider: DataLoaderMapProviderBase
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider" data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
@classmethod
def pre_expand(cls) -> None:
# use try/finally to bypass cinder's lazy imports
try:
from .blender_dataset_map_provider import ( # noqa: F401
BlenderDatasetMapProvider,
)
from .json_index_dataset_map_provider import ( # noqa: F401
JsonIndexDatasetMapProvider,
)
from .json_index_dataset_map_provider_v2 import ( # noqa: F401
JsonIndexDatasetMapProviderV2,
)
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa: F401
from .rendered_mesh_dataset_map_provider import ( # noqa: F401
RenderedMeshDatasetMapProvider,
)
finally:
pass
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None

View File

@ -20,23 +20,8 @@ from pytorch3d.implicitron.models.base_model import (
ImplicitronRender, ImplicitronRender,
) )
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa
NeRFormerImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa
VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.metrics import ( from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase, RegularizationMetricsBase,
ViewMetricsBase, ViewMetricsBase,
@ -50,14 +35,7 @@ from pytorch3d.implicitron.models.renderer.base import (
RendererOutput, RendererOutput,
RenderSamplingMode, RenderSamplingMode,
) )
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer # noqa
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa
SignedDistanceFunctionRenderer,
)
from pytorch3d.implicitron.models.utils import ( from pytorch3d.implicitron.models.utils import (
apply_chunked, apply_chunked,
@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
] ]
) )
@classmethod
def pre_expand(cls) -> None:
# use try/finally to bypass cinder's lazy imports
try:
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa: F401, B950
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
NeRFormerImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
SRNHyperNetImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa: F401, B950
VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
LSTMRenderer,
)
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
SignedDistanceFunctionRenderer,
)
finally:
pass
def __post_init__(self): def __post_init__(self):
if self.view_pooler_enabled: if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None: if self.image_feature_extractor_class_type is None:

View File

@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
] ]
) )
@classmethod
def pre_expand(cls) -> None:
# use try/finally to bypass cinder's lazy imports
try:
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
SRNImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
LSTMRenderer,
)
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa: F401
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
SignedDistanceFunctionRenderer,
)
finally:
pass
def __post_init__(self): def __post_init__(self):
# The attribute will be filled by run_auto_creation # The attribute will be filled by run_auto_creation
run_auto_creation(self) run_auto_creation(self)

View File

@ -185,6 +185,7 @@ CREATE_PREFIX: str = "create_"
IMPL_SUFFIX: str = "_impl" IMPL_SUFFIX: str = "_impl"
TWEAK_SUFFIX: str = "_tweak_args" TWEAK_SUFFIX: str = "_tweak_args"
_DATACLASS_INIT: str = "__dataclass_own_init__" _DATACLASS_INIT: str = "__dataclass_own_init__"
PRE_EXPAND_NAME: str = "pre_expand"
class ReplaceableBase: class ReplaceableBase:
@ -838,6 +839,9 @@ def expand_args_fields(
In addition, if the class inherits torch.nn.Module, the generated __init__ will In addition, if the class inherits torch.nn.Module, the generated __init__ will
call torch.nn.Module's __init__ before doing anything else. call torch.nn.Module's __init__ before doing anything else.
Before any transformation of the class, if the class has a classmethod called
`pre_expand`, it will be called with no arguments.
Note that although the *_args members are intended to have type DictConfig, they Note that although the *_args members are intended to have type DictConfig, they
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
in place of a dict, but not vice-versa. Allowing dict lets a class user specify in place of a dict, but not vice-versa. Allowing dict lets a class user specify
@ -858,6 +862,9 @@ def expand_args_fields(
if _is_actually_dataclass(some_class): if _is_actually_dataclass(some_class):
return some_class return some_class
if hasattr(some_class, PRE_EXPAND_NAME):
getattr(some_class, PRE_EXPAND_NAME)()
# The functions this class's run_auto_creation will run. # The functions this class's run_auto_creation will run.
creation_functions: List[str] = [] creation_functions: List[str] = []
# The classes which this type knows about from the registry # The classes which this type knows about from the registry

View File

@ -10,6 +10,7 @@ import unittest
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
@ -805,6 +806,39 @@ class TestConfig(unittest.TestCase):
self.assertEqual(control_args, ["Orange", "Orange", True, True]) self.assertEqual(control_args, ["Orange", "Orange", True, True])
def test_pre_expand(self):
# Check that the precreate method of a class is called once before
# when expand_args_fields is called on the class.
class A(Configurable):
n: int = 9
@classmethod
def pre_expand(cls):
pass
A.pre_expand = Mock()
expand_args_fields(A)
A.pre_expand.assert_called()
def test_pre_expand_replaceable(self):
# Check that the precreate method of a class is called once before
# when expand_args_fields is called on the class.
class A(ReplaceableBase):
pass
@classmethod
def pre_expand(cls):
pass
class A1(A):
n: 9
A.pre_expand = Mock()
expand_args_fields(A1)
A.pre_expand.assert_called()
@dataclass(eq=False) @dataclass(eq=False)
class MockDataclass: class MockDataclass: