Make Module.__init__ automatic

Summary: If a configurable class inherits torch.nn.Module, automatically call `torch.nn.Module.__init__` on instantiation before doing anything else, so subclasses no longer need an explicit `super().__init__()` in `__post_init__`.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
Authored by Jeremy Reizenstein on 2023-01-27 07:07:46 -08:00; committed by Facebook GitHub Bot
parent 97f8f9bf47
commit 9540c29023

29 changed files with 36 additions and 87 deletions
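
To illustrate the effect (a minimal sketch, not part of this diff; the class `MyHead` and its field are hypothetical, but `Configurable` and `expand_args_fields` are the real config API touched below):

    import torch
    from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields

    class MyHead(Configurable, torch.nn.Module):
        hidden_dim: int = 64

        def __post_init__(self):
            # No explicit super().__init__() needed any more:
            # torch.nn.Module.__init__ has already been called, so
            # assigning a submodule here registers it correctly.
            self.linear = torch.nn.Linear(self.hidden_dim, 1)

    expand_args_fields(MyHead)
    head = MyHead()
    assert isinstance(head.linear, torch.nn.Linear)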


@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
 class XRayRenderer(BaseRenderer, torch.nn.Module):
     n_pts_per_ray: int = 64
 
-    # if there are other base classes, make sure to call `super().__init__()` explicitly
     def __post_init__(self):
-        super().__init__()
         # custom initialization
 
     def forward(


@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")
 
     # init the simple DBIR model
-    model = ModelDBIR(
+    model = ModelDBIR(  # pyre-ignore[28]: ctor implicitly overridden
        render_image_width=image_size,
        render_image_height=image_size,
        bg_color=bg_color,


@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
     # the training loop.
     log_vars: List[str] = field(default_factory=lambda: ["objective"])
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments


@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
     Base class for an extractor of a set of features from images.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def get_feat_dims(self) -> int:
         """
         Returns:


@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
     feature_rescale: float = 1.0
 
     def __post_init__(self):
-        super().__init__()
         if self.normalize_image:
             # register buffers needed to normalize the image
             for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):


@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
     )
 
     def __post_init__(self):
-        super().__init__()
-
         if self.view_pooler_enabled:
             if self.image_feature_extractor_class_type is None:
                 raise ValueError(


@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
     ignore_input: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         if self.n_instances <= 0:
             raise ValueError(f"Invalid n_instances {self.n_instances}")
 


@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
     (`SequenceAutodecoder`).
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def get_encoding_dim(self):
         """
         Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 13
     autodecoder: Autodecoder
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
     time_divisor: float = 1.0
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             n_harmonic_functions=self.n_harmonic_functions,
             append_input=self.append_input,


@@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase
 
 
 class ImplicitFunctionBase(ABC, ReplaceableBase):
-    def __init__(self):
-        super().__init__()
-
     @abstractmethod
     def forward(
         self,


@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     space and transforms it into the required quantity (for example density and color).
     """
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
     operation: DecoderActivation = DecoderActivation.IDENTITY
 
     def __post_init__(self):
-        super().__post_init__()
         if self.operation not in [
             DecoderActivation.RELU,
             DecoderActivation.SOFTPLUS,
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     use_xavier_init: bool = True
 
     def __post_init__(self):
-        super().__init__()
-
         try:
             last_activation = {
                 DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
     network: MLPWithInputSkips
 
     def __post_init__(self):
-        super().__post_init__()
         run_auto_creation(self)
 
     def forward(


@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
     encoding_dim: int = 0
 
     def __post_init__(self):
-        super().__init__()
-
         dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
 
         self.embed_fn = None


@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
     """
 
     def __post_init__(self):
-        super().__init__()
         # The harmonic embedding layer converts input 3D coordinates
         # to a representation that is more suitable for
         # processing with a deep neural network.


@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
     raymarch_function: Any = None
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
     ray_dir_in_camera_coords: bool = False
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
     xyz_in_camera_coords: bool = False
 
     def __post_init__(self):
-        super().__init__()
         raymarch_input_embedding_dim = (
             HarmonicEmbedding.get_output_dim_static(
                 self.in_features,
@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def create_raymarch_function(self) -> None:
@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator
 
    def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def create_hypernet(self) -> None:


@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
     )
 
     def __post_init__(self):
-        super().__init__()
         if 0 not in self.resolution_changes:
             raise ValueError("There has to be key `0` in `resolution_changes`.")
 
@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     param_groups: Dict[str, str] = field(default_factory=lambda: {})
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
         n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
         shapes = self.voxel_grid.get_shapes(epoch=0)


@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     volume_cropping_epochs: Tuple[int, ...] = ()
 
     def __post_init__(self) -> None:
-        super().__init__()
         run_auto_creation(self)
         # pyre-ignore[16]
         self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()


@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
     depend on the model's parameters.
     """
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self, model: Any, keys_prefix: str = "loss_", **kwargs
     ) -> Dict[str, Any]:
@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
     `forward()` method produces losses and other metrics.
     """
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         raymarched: RendererOutput,


@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
     bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
     max_points: int = -1
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments


@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
     Base class for all Renderer implementations.
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def requires_object_mask(self) -> bool:
         """
         Whether `forward` needs the object_mask.


@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
     verbose: bool = False
 
     def __post_init__(self):
-        super().__init__()
         self._lstm = torch.nn.LSTMCell(
             input_size=self.n_feature_channels,
             hidden_size=self.hidden_size,


@@ -90,8 +90,6 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
     return_weights: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         self._refiners = {
             EvaluationMode.TRAINING: RayPointRefiner(
                 n_pts_per_ray=self.n_pts_per_ray_fine_training,


@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
     random_sampling: bool
     add_input_samples: bool = True
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         input_ray_bundle: ImplicitronRayBundle,


@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
     Base class for ray samplers.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         cameras: CamerasBase,
@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     stratified_point_sampling_evaluation: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         if (self.n_rays_per_image_sampled_from_mask is not None) and (
             self.n_rays_total_training is not None
         ):


@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
     n_steps: int = 100
     n_secant_steps: int = 8
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         sdf: Callable[[torch.Tensor], torch.Tensor],


@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
     and marching along them in order to generate a feature render.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         rays_densities: torch.Tensor,
@@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
             surface_thickness: Denotes the overlap between the absorption
                 function and the density function.
         """
-        super().__init__()
-
         bg_color = torch.tensor(self.bg_color)
         if bg_color.ndim != 1:
             raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")


@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ignore: 13
     def __post_init__(
         self,
     ):
-        super().__init__()
         render_features_dimensions = self.render_features_dimensions
         if len(self.bg_color) not in [1, render_features_dimensions]:
             raise ValueError(


@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
     the outputs.
     """
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
         ReductionFunction.STD,
     )
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):


@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
     feature_aggregator: FeatureAggregatorBase
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):


@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
     masked_sampling: bool = False
     sampling_mode: str = "bilinear"
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force kw args


@@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled"
 CREATE_PREFIX: str = "create_"
 IMPL_SUFFIX: str = "_impl"
 TWEAK_SUFFIX: str = "_tweak_args"
+_DATACLASS_INIT: str = "__dataclass_own_init__"
 
 
 class ReplaceableBase:
@@ -834,6 +835,9 @@ def expand_args_fields(
     then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
     the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
 
+    In addition, if the class inherits torch.nn.Module, the generated __init__ will
+    call torch.nn.Module's __init__ before doing anything else.
+
     Note that although the *_args members are intended to have type DictConfig, they
     are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
     in place of a dict, but not vice-versa. Allowing dict lets a class user specify
@@ -912,9 +916,40 @@ def expand_args_fields(
     some_class._known_implementations = known_implementations
 
     dataclasses.dataclass(eq=False)(some_class)
+    _fixup_class_init(some_class)
     return some_class
 
 
+def _fixup_class_init(some_class) -> None:
+    """
+    In-place modification of the some_class class which happens
+    after dataclass processing.
+
+    If the dataclass some_class inherits torch.nn.Module, then
+    makes torch.nn.Module's __init__ be called before anything else
+    on instantiation of some_class.
+
+    This is a bit like attr's __pre_init__.
+    """
+    assert _is_actually_dataclass(some_class)
+
+    try:
+        import torch
+    except ModuleNotFoundError:
+        return
+
+    if not issubclass(some_class, torch.nn.Module):
+        return
+
+    def init(self, *args, **kwargs) -> None:
+        torch.nn.Module.__init__(self)
+        getattr(self, _DATACLASS_INIT)(*args, **kwargs)
+
+    assert not hasattr(some_class, _DATACLASS_INIT)
+    setattr(some_class, _DATACLASS_INIT, some_class.__init__)
+    some_class.__init__ = init
+
+
 def get_default_args_field(
     C,
     *,
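
The wrapping done by `_fixup_class_init` can also be seen in isolation; a minimal self-contained sketch of the same pattern (hypothetical names `fixup_class_init`, `Head`, `width`; independent of the config machinery):

    import dataclasses
    import torch

    _OWN_INIT = "__dataclass_own_init__"

    def fixup_class_init(cls):
        # Stash the dataclass-generated __init__ under a private name and
        # install a wrapper that runs torch.nn.Module.__init__ first.
        setattr(cls, _OWN_INIT, cls.__init__)

        def init(self, *args, **kwargs):
            torch.nn.Module.__init__(self)
            getattr(self, _OWN_INIT)(*args, **kwargs)

        cls.__init__ = init
        return cls

    @fixup_class_init
    @dataclasses.dataclass(eq=False)
    class Head(torch.nn.Module):
        width: int = 3

        def __post_init__(self):
            # Without the wrapper this would raise
            # "cannot assign module before Module.__init__() call".
            self.linear = torch.nn.Linear(self.width, 1)

    h = Head(width=5)
    assert isinstance(h.linear, torch.nn.Linear)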