Make Module.__init__ automatic
Summary:
If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
parent 97f8f9bf47
commit 9540c29023
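The user-visible effect: a configurable class that inherits torch.nn.Module no longer needs to call `super().__init__()` in `__post_init__` before registering parameters or submodules. A minimal sketch of the new behavior (`MyHead` is a hypothetical example, not part of this diff):

    import torch
    from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields

    class MyHead(Configurable, torch.nn.Module):
        hidden_dim: int = 32

        def __post_init__(self):
            # Before this commit, a `super().__init__()` call was required here;
            # without it, the assignment below raised
            # "cannot assign module before Module.__init__() call".
            self.layer = torch.nn.Linear(self.hidden_dim, self.hidden_dim)

    expand_args_fields(MyHead)  # the generated __init__ now runs torch.nn.Module.__init__ first
    head = MyHead(hidden_dim=16)
    print(head.layer)  # registered as a submodule as usual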
@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
 class XRayRenderer(BaseRenderer, torch.nn.Module):
     n_pts_per_ray: int = 64
 
+    # if there are other base classes, make sure to call `super().__init__()` explicitly
     def __post_init__(self):
-        super().__init__()
         # custom initialization
 
     def forward(
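Note the comment added to the README example: only `torch.nn.Module.__init__` is called automatically. A class that mixes in another stateful base still needs the explicit chain, as in this sketch (`StatefulBase` and `CountingHead` are hypothetical):

    import torch
    from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields

    class StatefulBase:
        def __init__(self):
            self.num_calls = 0  # state that only an explicit __init__ sets up

    class CountingHead(Configurable, StatefulBase, torch.nn.Module):
        def __post_init__(self):
            super().__init__()  # still needed: runs StatefulBase.__init__
            self.layer = torch.nn.Linear(4, 4)  # Module.__init__ already ran automatically

    expand_args_fields(CountingHead)
    head = CountingHead()
    assert head.num_calls == 0 and isinstance(head.layer, torch.nn.Linear)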
@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")
 
     # init the simple DBIR model
-    model = ModelDBIR(  # pyre-ignore[28]: c'tor implicitly overridden
+    model = ModelDBIR(
         render_image_width=image_size,
         render_image_height=image_size,
         bg_color=bg_color,
@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
     # the training loop.
     log_vars: List[str] = field(default_factory=lambda: ["objective"])
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments
@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
     Base class for an extractor of a set of features from images.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def get_feat_dims(self) -> int:
         """
         Returns:
@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
     feature_rescale: float = 1.0
 
     def __post_init__(self):
-        super().__init__()
         if self.normalize_image:
             # register buffers needed to normalize the image
             for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
     )
 
     def __post_init__(self):
-        super().__init__()
-
         if self.view_pooler_enabled:
             if self.image_feature_extractor_class_type is None:
                 raise ValueError(
@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
     ignore_input: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         if self.n_instances <= 0:
             raise ValueError(f"Invalid n_instances {self.n_instances}")
 
@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
     (`SequenceAutodecoder`).
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def get_encoding_dim(self):
         """
         Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
     autodecoder: Autodecoder
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
     time_divisor: float = 1.0
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             n_harmonic_functions=self.n_harmonic_functions,
             append_input=self.append_input,
@@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase
 
 
 class ImplicitFunctionBase(ABC, ReplaceableBase):
-    def __init__(self):
-        super().__init__()
-
     @abstractmethod
     def forward(
         self,
@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     space and transforms it into the required quantity (for example density and color).
     """
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
     operation: DecoderActivation = DecoderActivation.IDENTITY
 
    def __post_init__(self):
-        super().__post_init__()
         if self.operation not in [
             DecoderActivation.RELU,
             DecoderActivation.SOFTPLUS,
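`ElementwiseDecoder` chained to the base `__post_init__` that the previous hunk deleted, so the call goes away here too. Where a base class still defines `__post_init__`, chaining remains the right pattern; a sketch, under the assumption that directly subclassing a Configurable works as usual (`Base` and `Child` are illustrative):

    import torch
    from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields

    class Base(Configurable, torch.nn.Module):
        def __post_init__(self):
            self.base_ready = True

    class Child(Base):
        def __post_init__(self):
            super().__post_init__()  # run the base class setup first
            self.layer = torch.nn.Linear(4, 4)

    expand_args_fields(Child)
    assert Child().base_ready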
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     use_xavier_init: bool = True
 
     def __post_init__(self):
-        super().__init__()
-
         try:
             last_activation = {
                 DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
     network: MLPWithInputSkips
 
     def __post_init__(self):
-        super().__post_init__()
         run_auto_creation(self)
 
     def forward(
@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
     encoding_dim: int = 0
 
     def __post_init__(self):
-        super().__init__()
-
         dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
 
         self.embed_fn = None
@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
     """
 
     def __post_init__(self):
-        super().__init__()
         # The harmonic embedding layer converts input 3D coordinates
         # to a representation that is more suitable for
         # processing with a deep neural network.
@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
     raymarch_function: Any = None
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
     ray_dir_in_camera_coords: bool = False
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
     xyz_in_camera_coords: bool = False
 
     def __post_init__(self):
-        super().__init__()
         raymarch_input_embedding_dim = (
             HarmonicEmbedding.get_output_dim_static(
                 self.in_features,
@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def create_raymarch_function(self) -> None:
@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def create_hypernet(self) -> None:
@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
     )
 
     def __post_init__(self):
-        super().__init__()
         if 0 not in self.resolution_changes:
             raise ValueError("There has to be key `0` in `resolution_changes`.")
 
@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     param_groups: Dict[str, str] = field(default_factory=lambda: {})
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
         n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
         shapes = self.voxel_grid.get_shapes(epoch=0)
@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     volume_cropping_epochs: Tuple[int, ...] = ()
 
     def __post_init__(self) -> None:
-        super().__init__()
         run_auto_creation(self)
         # pyre-ignore[16]
         self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
     depend on the model's parameters.
     """
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self, model: Any, keys_prefix: str = "loss_", **kwargs
     ) -> Dict[str, Any]:
@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
     `forward()` method produces losses and other metrics.
     """
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         raymarched: RendererOutput,
@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
     bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
     max_points: int = -1
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments
@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
     Base class for all Renderer implementations.
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def requires_object_mask(self) -> bool:
         """
         Whether `forward` needs the object_mask.
@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
     verbose: bool = False
 
     def __post_init__(self):
-        super().__init__()
         self._lstm = torch.nn.LSTMCell(
             input_size=self.n_feature_channels,
             hidden_size=self.hidden_size,
@@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
     return_weights: bool = False
 
     def __post_init__(self):
-        super().__init__()
         self._refiners = {
             EvaluationMode.TRAINING: RayPointRefiner(
                 n_pts_per_ray=self.n_pts_per_ray_fine_training,
@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
     random_sampling: bool
     add_input_samples: bool = True
 
-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         input_ray_bundle: ImplicitronRayBundle,
@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
     Base class for ray samplers.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         cameras: CamerasBase,
@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     stratified_point_sampling_evaluation: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         if (self.n_rays_per_image_sampled_from_mask is not None) and (
             self.n_rays_total_training is not None
         ):
@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
     n_steps: int = 100
     n_secant_steps: int = 8
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         sdf: Callable[[torch.Tensor], torch.Tensor],
@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
     and marching along them in order to generate a feature render.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         rays_densities: torch.Tensor,
@@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
             surface_thickness: Denotes the overlap between the absorption
                 function and the density function.
         """
-        super().__init__()
-
         bg_color = torch.tensor(self.bg_color)
         if bg_color.ndim != 1:
             raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
     def __post_init__(
         self,
     ):
-        super().__init__()
         render_features_dimensions = self.render_features_dimensions
         if len(self.bg_color) not in [1, render_features_dimensions]:
             raise ValueError(
@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
     the outputs.
     """
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
         ReductionFunction.STD,
     )
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1
 
-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
     feature_aggregator: FeatureAggregatorBase
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
     masked_sampling: bool = False
     sampling_mode: str = "bilinear"
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force kw args
@@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled"
 CREATE_PREFIX: str = "create_"
 IMPL_SUFFIX: str = "_impl"
 TWEAK_SUFFIX: str = "_tweak_args"
+_DATACLASS_INIT: str = "__dataclass_own_init__"
 
 
 class ReplaceableBase:
@@ -834,6 +835,9 @@ def expand_args_fields(
     then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
     the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
 
+    In addition, if the class inherits torch.nn.Module, the generated __init__ will
+    call torch.nn.Module's __init__ before doing anything else.
+
     Note that although the *_args members are intended to have type DictConfig, they
     are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
     in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
||||
some_class._known_implementations = known_implementations
|
||||
|
||||
dataclasses.dataclass(eq=False)(some_class)
|
||||
_fixup_class_init(some_class)
|
||||
return some_class
|
||||
|
||||
|
||||
def _fixup_class_init(some_class) -> None:
|
||||
"""
|
||||
In-place modification of the some_class class which happens
|
||||
after dataclass processing.
|
||||
|
||||
If the dataclass some_class inherits torch.nn.Module, then
|
||||
makes torch.nn.Module's __init__ be called before anything else
|
||||
on instantiation of some_class.
|
||||
This is a bit like attr's __pre_init__.
|
||||
"""
|
||||
|
||||
assert _is_actually_dataclass(some_class)
|
||||
try:
|
||||
import torch
|
||||
except ModuleNotFoundError:
|
||||
return
|
||||
|
||||
if not issubclass(some_class, torch.nn.Module):
|
||||
return
|
||||
|
||||
def init(self, *args, **kwargs) -> None:
|
||||
torch.nn.Module.__init__(self)
|
||||
getattr(self, _DATACLASS_INIT)(*args, **kwargs)
|
||||
|
||||
assert not hasattr(some_class, _DATACLASS_INIT)
|
||||
|
||||
setattr(some_class, _DATACLASS_INIT, some_class.__init__)
|
||||
some_class.__init__ = init
|
||||
|
||||
|
||||
def get_default_args_field(
|
||||
C,
|
||||
*,
|
||||
|
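The same stash-and-wrap trick in isolation, outside the config machinery; a self-contained sketch of why replacing `__init__` this way works (all names here are illustrative, not from the diff):

    import dataclasses

    import torch

    @dataclasses.dataclass(eq=False)
    class Net(torch.nn.Module):
        width: int = 8

        def __post_init__(self):
            # Needs Module.__init__ to have run already, so that the
            # _modules registry exists for this submodule assignment.
            self.layer = torch.nn.Linear(self.width, self.width)

    # Emulate _fixup_class_init: stash the dataclass-generated __init__ ...
    Net.__dataclass_own_init__ = Net.__init__

    def _init(self, *args, **kwargs):
        torch.nn.Module.__init__(self)  # always runs first
        self.__dataclass_own_init__(*args, **kwargs)  # sets fields, then __post_init__

    # ... and swap in the wrapper.
    Net.__init__ = _init

    net = Net(width=4)  # no explicit super().__init__() anywhere
    print(net.layer)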