mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Hard population of registry system with pre_expand
Summary: Provide an extension point pre_expand to let a configurable class A make sure another class B is registered before A is expanded. This reduces top level imports. Reviewed By: bottler Differential Revision: D44504122 fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
This commit is contained in:
parent
813e941de5
commit
c759fc560f
@ -13,13 +13,8 @@ from pytorch3d.implicitron.tools.config import (
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
|
||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
|
||||
|
||||
|
||||
class DataSourceBase(ReplaceableBase):
|
||||
@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
data_loader_map_provider: DataLoaderMapProviderBase
|
||||
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls) -> None:
|
||||
# use try/finally to bypass cinder's lazy imports
|
||||
try:
|
||||
from .blender_dataset_map_provider import ( # noqa: F401
|
||||
BlenderDatasetMapProvider,
|
||||
)
|
||||
from .json_index_dataset_map_provider import ( # noqa: F401
|
||||
JsonIndexDatasetMapProvider,
|
||||
)
|
||||
from .json_index_dataset_map_provider_v2 import ( # noqa: F401
|
||||
JsonIndexDatasetMapProviderV2,
|
||||
)
|
||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa: F401
|
||||
from .rendered_mesh_dataset_map_provider import ( # noqa: F401
|
||||
RenderedMeshDatasetMapProvider,
|
||||
)
|
||||
finally:
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
|
||||
|
@ -20,23 +20,8 @@ from pytorch3d.implicitron.models.base_model import (
|
||||
ImplicitronRender,
|
||||
)
|
||||
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
|
||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa
|
||||
ResNetFeatureExtractor,
|
||||
)
|
||||
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
|
||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa
|
||||
IdrFeatureField,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa
|
||||
NeRFormerImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa
|
||||
SRNHyperNetImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa
|
||||
VoxelGridImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.metrics import (
|
||||
RegularizationMetricsBase,
|
||||
ViewMetricsBase,
|
||||
@ -50,14 +35,7 @@ from pytorch3d.implicitron.models.renderer.base import (
|
||||
RendererOutput,
|
||||
RenderSamplingMode,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer # noqa
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
|
||||
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa
|
||||
SignedDistanceFunctionRenderer,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.models.utils import (
|
||||
apply_chunked,
|
||||
@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls) -> None:
|
||||
# use try/finally to bypass cinder's lazy imports
|
||||
try:
|
||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa: F401, B950
|
||||
ResNetFeatureExtractor,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
|
||||
IdrFeatureField,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
|
||||
NeRFormerImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
|
||||
SRNHyperNetImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa: F401, B950
|
||||
VoxelGridImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
|
||||
LSTMRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
|
||||
SignedDistanceFunctionRenderer,
|
||||
)
|
||||
finally:
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
if self.view_pooler_enabled:
|
||||
if self.image_feature_extractor_class_type is None:
|
||||
|
@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls) -> None:
|
||||
# use try/finally to bypass cinder's lazy imports
|
||||
try:
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
|
||||
IdrFeatureField,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
|
||||
NeuralRadianceFieldImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
|
||||
SRNImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
|
||||
LSTMRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa: F401
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
|
||||
SignedDistanceFunctionRenderer,
|
||||
)
|
||||
finally:
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
# The attribute will be filled by run_auto_creation
|
||||
run_auto_creation(self)
|
||||
|
@ -185,6 +185,7 @@ CREATE_PREFIX: str = "create_"
|
||||
IMPL_SUFFIX: str = "_impl"
|
||||
TWEAK_SUFFIX: str = "_tweak_args"
|
||||
_DATACLASS_INIT: str = "__dataclass_own_init__"
|
||||
PRE_EXPAND_NAME: str = "pre_expand"
|
||||
|
||||
|
||||
class ReplaceableBase:
|
||||
@ -838,6 +839,9 @@ def expand_args_fields(
|
||||
In addition, if the class inherits torch.nn.Module, the generated __init__ will
|
||||
call torch.nn.Module's __init__ before doing anything else.
|
||||
|
||||
Before any transformation of the class, if the class has a classmethod called
|
||||
`pre_expand`, it will be called with no arguments.
|
||||
|
||||
Note that although the *_args members are intended to have type DictConfig, they
|
||||
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
|
||||
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
||||
@ -858,6 +862,9 @@ def expand_args_fields(
|
||||
if _is_actually_dataclass(some_class):
|
||||
return some_class
|
||||
|
||||
if hasattr(some_class, PRE_EXPAND_NAME):
|
||||
getattr(some_class, PRE_EXPAND_NAME)()
|
||||
|
||||
# The functions this class's run_auto_creation will run.
|
||||
creation_functions: List[str] = []
|
||||
# The classes which this type knows about from the registry
|
||||
|
@ -10,6 +10,7 @@ import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from unittest.mock import Mock
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@ -805,6 +806,39 @@ class TestConfig(unittest.TestCase):
|
||||
|
||||
self.assertEqual(control_args, ["Orange", "Orange", True, True])
|
||||
|
||||
def test_pre_expand(self):
|
||||
# Check that the precreate method of a class is called once before
|
||||
# when expand_args_fields is called on the class.
|
||||
|
||||
class A(Configurable):
|
||||
n: int = 9
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls):
|
||||
pass
|
||||
|
||||
A.pre_expand = Mock()
|
||||
expand_args_fields(A)
|
||||
A.pre_expand.assert_called()
|
||||
|
||||
def test_pre_expand_replaceable(self):
|
||||
# Check that the precreate method of a class is called once before
|
||||
# when expand_args_fields is called on the class.
|
||||
|
||||
class A(ReplaceableBase):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls):
|
||||
pass
|
||||
|
||||
class A1(A):
|
||||
n: 9
|
||||
|
||||
A.pre_expand = Mock()
|
||||
expand_args_fields(A1)
|
||||
A.pre_expand.assert_called()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MockDataclass:
|
||||
|
Loading…
x
Reference in New Issue
Block a user