mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Hard population of registry system with pre_expand
Summary: Provide an extension point pre_expand to let a configurable class A make sure another class B is registered before A is expanded. This reduces top level imports. Reviewed By: bottler Differential Revision: D44504122 fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
This commit is contained in:
		
							parent
							
								
									813e941de5
								
							
						
					
					
						commit
						c759fc560f
					
				@ -13,13 +13,8 @@ from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 | 
			
		||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 | 
			
		||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
 | 
			
		||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
 | 
			
		||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2  # noqa
 | 
			
		||||
from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
 | 
			
		||||
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider  # noqa
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DataSourceBase(ReplaceableBase):
 | 
			
		||||
@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
 | 
			
		||||
    data_loader_map_provider: DataLoaderMapProviderBase
 | 
			
		||||
    data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def pre_expand(cls) -> None:
 | 
			
		||||
        # use try/finally to bypass cinder's lazy imports
 | 
			
		||||
        try:
 | 
			
		||||
            from .blender_dataset_map_provider import (  # noqa: F401
 | 
			
		||||
                BlenderDatasetMapProvider,
 | 
			
		||||
            )
 | 
			
		||||
            from .json_index_dataset_map_provider import (  # noqa: F401
 | 
			
		||||
                JsonIndexDatasetMapProvider,
 | 
			
		||||
            )
 | 
			
		||||
            from .json_index_dataset_map_provider_v2 import (  # noqa: F401
 | 
			
		||||
                JsonIndexDatasetMapProviderV2,
 | 
			
		||||
            )
 | 
			
		||||
            from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa: F401
 | 
			
		||||
            from .rendered_mesh_dataset_map_provider import (  # noqa: F401
 | 
			
		||||
                RenderedMeshDatasetMapProvider,
 | 
			
		||||
            )
 | 
			
		||||
        finally:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
        self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
 | 
			
		||||
 | 
			
		||||
@ -20,23 +20,8 @@ from pytorch3d.implicitron.models.base_model import (
 | 
			
		||||
    ImplicitronRender,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
 | 
			
		||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (  # noqa
 | 
			
		||||
    ResNetFeatureExtractor,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa
 | 
			
		||||
    IdrFeatureField,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa
 | 
			
		||||
    NeRFormerImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa
 | 
			
		||||
    SRNHyperNetImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa
 | 
			
		||||
    VoxelGridImplicitFunction,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.metrics import (
 | 
			
		||||
    RegularizationMetricsBase,
 | 
			
		||||
    ViewMetricsBase,
 | 
			
		||||
@ -50,14 +35,7 @@ from pytorch3d.implicitron.models.renderer.base import (
 | 
			
		||||
    RendererOutput,
 | 
			
		||||
    RenderSamplingMode,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer  # noqa
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa
 | 
			
		||||
    MultiPassEmissionAbsorptionRenderer,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa
 | 
			
		||||
    SignedDistanceFunctionRenderer,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.models.utils import (
 | 
			
		||||
    apply_chunked,
 | 
			
		||||
@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def pre_expand(cls) -> None:
 | 
			
		||||
        # use try/finally to bypass cinder's lazy imports
 | 
			
		||||
        try:
 | 
			
		||||
            from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (  # noqa: F401, B950
 | 
			
		||||
                ResNetFeatureExtractor,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa: F401, B950
 | 
			
		||||
                IdrFeatureField,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa: F401, B950
 | 
			
		||||
                NeRFormerImplicitFunction,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa: F401, B950
 | 
			
		||||
                SRNHyperNetImplicitFunction,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa: F401, B950
 | 
			
		||||
                VoxelGridImplicitFunction,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.renderer.lstm_renderer import (  # noqa: F401
 | 
			
		||||
                LSTMRenderer,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa
 | 
			
		||||
                MultiPassEmissionAbsorptionRenderer,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa: F401
 | 
			
		||||
                SignedDistanceFunctionRenderer,
 | 
			
		||||
            )
 | 
			
		||||
        finally:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if self.view_pooler_enabled:
 | 
			
		||||
            if self.image_feature_extractor_class_type is None:
 | 
			
		||||
 | 
			
		||||
@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def pre_expand(cls) -> None:
 | 
			
		||||
        # use try/finally to bypass cinder's lazy imports
 | 
			
		||||
        try:
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa: F401, B950
 | 
			
		||||
                IdrFeatureField,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa: F401, B950
 | 
			
		||||
                NeuralRadianceFieldImplicitFunction,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa: F401, B950
 | 
			
		||||
                SRNImplicitFunction,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.renderer.lstm_renderer import (  # noqa: F401
 | 
			
		||||
                LSTMRenderer,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa: F401
 | 
			
		||||
                MultiPassEmissionAbsorptionRenderer,
 | 
			
		||||
            )
 | 
			
		||||
            from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa: F401
 | 
			
		||||
                SignedDistanceFunctionRenderer,
 | 
			
		||||
            )
 | 
			
		||||
        finally:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        # The attribute will be filled by run_auto_creation
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
@ -185,6 +185,7 @@ CREATE_PREFIX: str = "create_"
 | 
			
		||||
IMPL_SUFFIX: str = "_impl"
 | 
			
		||||
TWEAK_SUFFIX: str = "_tweak_args"
 | 
			
		||||
_DATACLASS_INIT: str = "__dataclass_own_init__"
 | 
			
		||||
PRE_EXPAND_NAME: str = "pre_expand"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplaceableBase:
 | 
			
		||||
@ -838,6 +839,9 @@ def expand_args_fields(
 | 
			
		||||
    In addition, if the class inherits torch.nn.Module, the generated __init__ will
 | 
			
		||||
    call torch.nn.Module's __init__ before doing anything else.
 | 
			
		||||
 | 
			
		||||
    Before any transformation of the class, if the class has a classmethod called
 | 
			
		||||
    `pre_expand`, it will be called with no arguments.
 | 
			
		||||
 | 
			
		||||
    Note that although the *_args members are intended to have type DictConfig, they
 | 
			
		||||
    are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
 | 
			
		||||
    in place of a dict, but not vice-versa. Allowing dict lets a class user specify
 | 
			
		||||
@ -858,6 +862,9 @@ def expand_args_fields(
 | 
			
		||||
    if _is_actually_dataclass(some_class):
 | 
			
		||||
        return some_class
 | 
			
		||||
 | 
			
		||||
    if hasattr(some_class, PRE_EXPAND_NAME):
 | 
			
		||||
        getattr(some_class, PRE_EXPAND_NAME)()
 | 
			
		||||
 | 
			
		||||
    # The functions this class's run_auto_creation will run.
 | 
			
		||||
    creation_functions: List[str] = []
 | 
			
		||||
    # The classes which this type knows about from the registry
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,7 @@ import unittest
 | 
			
		||||
from dataclasses import dataclass, field, is_dataclass
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple
 | 
			
		||||
from unittest.mock import Mock
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
@ -805,6 +806,39 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(control_args, ["Orange", "Orange", True, True])
 | 
			
		||||
 | 
			
		||||
    def test_pre_expand(self):
 | 
			
		||||
        # Check that the precreate method of a class is called once before
 | 
			
		||||
        # when expand_args_fields is called on the class.
 | 
			
		||||
 | 
			
		||||
        class A(Configurable):
 | 
			
		||||
            n: int = 9
 | 
			
		||||
 | 
			
		||||
            @classmethod
 | 
			
		||||
            def pre_expand(cls):
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        A.pre_expand = Mock()
 | 
			
		||||
        expand_args_fields(A)
 | 
			
		||||
        A.pre_expand.assert_called()
 | 
			
		||||
 | 
			
		||||
    def test_pre_expand_replaceable(self):
 | 
			
		||||
        # Check that the precreate method of a class is called once before
 | 
			
		||||
        # when expand_args_fields is called on the class.
 | 
			
		||||
 | 
			
		||||
        class A(ReplaceableBase):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
            @classmethod
 | 
			
		||||
            def pre_expand(cls):
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        class A1(A):
 | 
			
		||||
            n: 9
 | 
			
		||||
 | 
			
		||||
        A.pre_expand = Mock()
 | 
			
		||||
        expand_args_fields(A1)
 | 
			
		||||
        A.pre_expand.assert_called()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass(eq=False)
 | 
			
		||||
class MockDataclass:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user