mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-04 04:42:49 +08:00
Hard population of registry system with pre_expand
Summary: Provide an extension point pre_expand to let a configurable class A make sure another class B is registered before A is expanded. This reduces top level imports. Reviewed By: bottler Differential Revision: D44504122 fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
This commit is contained in:
parent
813e941de5
commit
c759fc560f
@ -13,13 +13,8 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
|
||||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
|
||||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
|
||||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
|
||||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
|
||||||
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
|
|
||||||
|
|
||||||
|
|
||||||
class DataSourceBase(ReplaceableBase):
|
class DataSourceBase(ReplaceableBase):
|
||||||
@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
data_loader_map_provider: DataLoaderMapProviderBase
|
data_loader_map_provider: DataLoaderMapProviderBase
|
||||||
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_expand(cls) -> None:
|
||||||
|
# use try/finally to bypass cinder's lazy imports
|
||||||
|
try:
|
||||||
|
from .blender_dataset_map_provider import ( # noqa: F401
|
||||||
|
BlenderDatasetMapProvider,
|
||||||
|
)
|
||||||
|
from .json_index_dataset_map_provider import ( # noqa: F401
|
||||||
|
JsonIndexDatasetMapProvider,
|
||||||
|
)
|
||||||
|
from .json_index_dataset_map_provider_v2 import ( # noqa: F401
|
||||||
|
JsonIndexDatasetMapProviderV2,
|
||||||
|
)
|
||||||
|
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa: F401
|
||||||
|
from .rendered_mesh_dataset_map_provider import ( # noqa: F401
|
||||||
|
RenderedMeshDatasetMapProvider,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
|
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
|
||||||
|
@ -20,23 +20,8 @@ from pytorch3d.implicitron.models.base_model import (
|
|||||||
ImplicitronRender,
|
ImplicitronRender,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
|
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
|
||||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa
|
|
||||||
ResNetFeatureExtractor,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
|
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
|
||||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
|
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
|
||||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa
|
|
||||||
IdrFeatureField,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa
|
|
||||||
NeRFormerImplicitFunction,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa
|
|
||||||
SRNHyperNetImplicitFunction,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa
|
|
||||||
VoxelGridImplicitFunction,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.models.metrics import (
|
from pytorch3d.implicitron.models.metrics import (
|
||||||
RegularizationMetricsBase,
|
RegularizationMetricsBase,
|
||||||
ViewMetricsBase,
|
ViewMetricsBase,
|
||||||
@ -50,14 +35,7 @@ from pytorch3d.implicitron.models.renderer.base import (
|
|||||||
RendererOutput,
|
RendererOutput,
|
||||||
RenderSamplingMode,
|
RenderSamplingMode,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer # noqa
|
|
||||||
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
|
|
||||||
MultiPassEmissionAbsorptionRenderer,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
|
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
|
||||||
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa
|
|
||||||
SignedDistanceFunctionRenderer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.utils import (
|
from pytorch3d.implicitron.models.utils import (
|
||||||
apply_chunked,
|
apply_chunked,
|
||||||
@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_expand(cls) -> None:
|
||||||
|
# use try/finally to bypass cinder's lazy imports
|
||||||
|
try:
|
||||||
|
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa: F401, B950
|
||||||
|
ResNetFeatureExtractor,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
|
||||||
|
IdrFeatureField,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
|
||||||
|
NeRFormerImplicitFunction,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
|
||||||
|
SRNHyperNetImplicitFunction,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa: F401, B950
|
||||||
|
VoxelGridImplicitFunction,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
|
||||||
|
LSTMRenderer,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
|
||||||
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
|
||||||
|
SignedDistanceFunctionRenderer,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.view_pooler_enabled:
|
if self.view_pooler_enabled:
|
||||||
if self.image_feature_extractor_class_type is None:
|
if self.image_feature_extractor_class_type is None:
|
||||||
|
@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_expand(cls) -> None:
|
||||||
|
# use try/finally to bypass cinder's lazy imports
|
||||||
|
try:
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa: F401, B950
|
||||||
|
IdrFeatureField,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa: F401, B950
|
||||||
|
NeuralRadianceFieldImplicitFunction,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa: F401, B950
|
||||||
|
SRNImplicitFunction,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.lstm_renderer import ( # noqa: F401
|
||||||
|
LSTMRenderer,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa: F401
|
||||||
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa: F401
|
||||||
|
SignedDistanceFunctionRenderer,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# The attribute will be filled by run_auto_creation
|
# The attribute will be filled by run_auto_creation
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
@ -185,6 +185,7 @@ CREATE_PREFIX: str = "create_"
|
|||||||
IMPL_SUFFIX: str = "_impl"
|
IMPL_SUFFIX: str = "_impl"
|
||||||
TWEAK_SUFFIX: str = "_tweak_args"
|
TWEAK_SUFFIX: str = "_tweak_args"
|
||||||
_DATACLASS_INIT: str = "__dataclass_own_init__"
|
_DATACLASS_INIT: str = "__dataclass_own_init__"
|
||||||
|
PRE_EXPAND_NAME: str = "pre_expand"
|
||||||
|
|
||||||
|
|
||||||
class ReplaceableBase:
|
class ReplaceableBase:
|
||||||
@ -838,6 +839,9 @@ def expand_args_fields(
|
|||||||
In addition, if the class inherits torch.nn.Module, the generated __init__ will
|
In addition, if the class inherits torch.nn.Module, the generated __init__ will
|
||||||
call torch.nn.Module's __init__ before doing anything else.
|
call torch.nn.Module's __init__ before doing anything else.
|
||||||
|
|
||||||
|
Before any transformation of the class, if the class has a classmethod called
|
||||||
|
`pre_expand`, it will be called with no arguments.
|
||||||
|
|
||||||
Note that although the *_args members are intended to have type DictConfig, they
|
Note that although the *_args members are intended to have type DictConfig, they
|
||||||
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
|
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
|
||||||
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
||||||
@ -858,6 +862,9 @@ def expand_args_fields(
|
|||||||
if _is_actually_dataclass(some_class):
|
if _is_actually_dataclass(some_class):
|
||||||
return some_class
|
return some_class
|
||||||
|
|
||||||
|
if hasattr(some_class, PRE_EXPAND_NAME):
|
||||||
|
getattr(some_class, PRE_EXPAND_NAME)()
|
||||||
|
|
||||||
# The functions this class's run_auto_creation will run.
|
# The functions this class's run_auto_creation will run.
|
||||||
creation_functions: List[str] = []
|
creation_functions: List[str] = []
|
||||||
# The classes which this type knows about from the registry
|
# The classes which this type knows about from the registry
|
||||||
|
@ -10,6 +10,7 @@ import unittest
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -805,6 +806,39 @@ class TestConfig(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(control_args, ["Orange", "Orange", True, True])
|
self.assertEqual(control_args, ["Orange", "Orange", True, True])
|
||||||
|
|
||||||
|
def test_pre_expand(self):
|
||||||
|
# Check that the precreate method of a class is called once before
|
||||||
|
# when expand_args_fields is called on the class.
|
||||||
|
|
||||||
|
class A(Configurable):
|
||||||
|
n: int = 9
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_expand(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
A.pre_expand = Mock()
|
||||||
|
expand_args_fields(A)
|
||||||
|
A.pre_expand.assert_called()
|
||||||
|
|
||||||
|
def test_pre_expand_replaceable(self):
|
||||||
|
# Check that the precreate method of a class is called once before
|
||||||
|
# when expand_args_fields is called on the class.
|
||||||
|
|
||||||
|
class A(ReplaceableBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_expand(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class A1(A):
|
||||||
|
n: 9
|
||||||
|
|
||||||
|
A.pre_expand = Mock()
|
||||||
|
expand_args_fields(A1)
|
||||||
|
A.pre_expand.assert_called()
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class MockDataclass:
|
class MockDataclass:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user