mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Make Module.__init__ automatic
Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else. Reviewed By: shapovalov Differential Revision: D42760349 fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
This commit is contained in:
		
							parent
							
								
									97f8f9bf47
								
							
						
					
					
						commit
						9540c29023
					
				@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
 | 
			
		||||
class XRayRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
    n_pts_per_ray: int = 64
 | 
			
		||||
 | 
			
		||||
    # if there are other base classes, make sure to call `super().__init__()` explicitly
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # custom initialization
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
 | 
			
		||||
@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
 | 
			
		||||
        raise ValueError("Image size should be set in the dataset")
 | 
			
		||||
 | 
			
		||||
    # init the simple DBIR model
 | 
			
		||||
    model = ModelDBIR(  # pyre-ignore[28]: c’tor implicitly overridden
 | 
			
		||||
    model = ModelDBIR(
 | 
			
		||||
        render_image_width=image_size,
 | 
			
		||||
        render_image_height=image_size,
 | 
			
		||||
        bg_color=bg_color,
 | 
			
		||||
 | 
			
		||||
@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    # the training loop.
 | 
			
		||||
    log_vars: List[str] = field(default_factory=lambda: ["objective"])
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        *,  # force keyword-only arguments
 | 
			
		||||
 | 
			
		||||
@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    Base class for an extractor of a set of features from images.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def get_feat_dims(self) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        Returns:
 | 
			
		||||
 | 
			
		||||
@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
 | 
			
		||||
    feature_rescale: float = 1.0
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        if self.normalize_image:
 | 
			
		||||
            # register buffers needed to normalize the image
 | 
			
		||||
            for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
 | 
			
		||||
 | 
			
		||||
@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        if self.view_pooler_enabled:
 | 
			
		||||
            if self.image_feature_extractor_class_type is None:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
 | 
			
		||||
@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
    ignore_input: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        if self.n_instances <= 0:
 | 
			
		||||
            raise ValueError(f"Invalid n_instances {self.n_instances}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
 | 
			
		||||
    (`SequenceAutodecoder`).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def get_encoding_dim(self):
 | 
			
		||||
        """
 | 
			
		||||
        Returns the dimensionality of the returned encoding.
 | 
			
		||||
@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
 | 
			
		||||
    autodecoder: Autodecoder
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def get_encoding_dim(self):
 | 
			
		||||
@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
 | 
			
		||||
    time_divisor: float = 1.0
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._harmonic_embedding = HarmonicEmbedding(
 | 
			
		||||
            n_harmonic_functions=self.n_harmonic_functions,
 | 
			
		||||
            append_input=self.append_input,
 | 
			
		||||
 | 
			
		||||
@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImplicitFunctionBase(ABC, ReplaceableBase):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    space and transforms it into the required quantity (for example density and color).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, features: torch.Tensor, z: Optional[torch.Tensor] = None
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
 | 
			
		||||
    operation: DecoderActivation = DecoderActivation.IDENTITY
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__post_init__()
 | 
			
		||||
        if self.operation not in [
 | 
			
		||||
            DecoderActivation.RELU,
 | 
			
		||||
            DecoderActivation.SOFTPLUS,
 | 
			
		||||
@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
 | 
			
		||||
    use_xavier_init: bool = True
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            last_activation = {
 | 
			
		||||
                DecoderActivation.RELU: torch.nn.ReLU(True),
 | 
			
		||||
@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
 | 
			
		||||
    network: MLPWithInputSkips
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__post_init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
 | 
			
		||||
@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    encoding_dim: int = 0
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
 | 
			
		||||
 | 
			
		||||
        self.embed_fn = None
 | 
			
		||||
 | 
			
		||||
@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # The harmonic embedding layer converts input 3D coordinates
 | 
			
		||||
        # to a representation that is more suitable for
 | 
			
		||||
        # processing with a deep neural network.
 | 
			
		||||
 | 
			
		||||
@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
 | 
			
		||||
    raymarch_function: Any = None
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._harmonic_embedding = HarmonicEmbedding(
 | 
			
		||||
            self.n_harmonic_functions, append_input=True
 | 
			
		||||
        )
 | 
			
		||||
@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
 | 
			
		||||
    ray_dir_in_camera_coords: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._harmonic_embedding = HarmonicEmbedding(
 | 
			
		||||
            self.n_harmonic_functions, append_input=True
 | 
			
		||||
        )
 | 
			
		||||
@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
 | 
			
		||||
    xyz_in_camera_coords: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        raymarch_input_embedding_dim = (
 | 
			
		||||
            HarmonicEmbedding.get_output_dim_static(
 | 
			
		||||
                self.in_features,
 | 
			
		||||
@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    pixel_generator: SRNPixelGenerator
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def create_raymarch_function(self) -> None:
 | 
			
		||||
@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    pixel_generator: SRNPixelGenerator
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def create_hypernet(self) -> None:
 | 
			
		||||
 | 
			
		||||
@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        if 0 not in self.resolution_changes:
 | 
			
		||||
            raise ValueError("There has to be key `0` in `resolution_changes`.")
 | 
			
		||||
 | 
			
		||||
@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
    param_groups: Dict[str, str] = field(default_factory=lambda: {})
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
        n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
 | 
			
		||||
        shapes = self.voxel_grid.get_shapes(epoch=0)
 | 
			
		||||
 | 
			
		||||
@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    volume_cropping_epochs: Tuple[int, ...] = ()
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
 | 
			
		||||
 | 
			
		||||
@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    depend on the model's parameters.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self, model: Any, keys_prefix: str = "loss_", **kwargs
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    `forward()` method produces losses and other metrics.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        raymarched: RendererOutput,
 | 
			
		||||
 | 
			
		||||
@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
 | 
			
		||||
    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
 | 
			
		||||
    max_points: int = -1
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        *,  # force keyword-only arguments
 | 
			
		||||
 | 
			
		||||
@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
 | 
			
		||||
    Base class for all Renderer implementations.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def requires_object_mask(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Whether `forward` needs the object_mask.
 | 
			
		||||
 | 
			
		||||
@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
    verbose: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._lstm = torch.nn.LSTMCell(
 | 
			
		||||
            input_size=self.n_feature_channels,
 | 
			
		||||
            hidden_size=self.hidden_size,
 | 
			
		||||
 | 
			
		||||
@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
			
		||||
    return_weights: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self._refiners = {
 | 
			
		||||
            EvaluationMode.TRAINING: RayPointRefiner(
 | 
			
		||||
                n_pts_per_ray=self.n_pts_per_ray_fine_training,
 | 
			
		||||
 | 
			
		||||
@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
 | 
			
		||||
    random_sampling: bool
 | 
			
		||||
    add_input_samples: bool = True
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ray_bundle: ImplicitronRayBundle,
 | 
			
		||||
 | 
			
		||||
@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
 | 
			
		||||
    Base class for ray samplers.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        cameras: CamerasBase,
 | 
			
		||||
@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
			
		||||
    stratified_point_sampling_evaluation: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        if (self.n_rays_per_image_sampled_from_mask is not None) and (
 | 
			
		||||
            self.n_rays_total_training is not None
 | 
			
		||||
        ):
 | 
			
		||||
 | 
			
		||||
@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
 | 
			
		||||
    n_steps: int = 100
 | 
			
		||||
    n_secant_steps: int = 8
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        sdf: Callable[[torch.Tensor], torch.Tensor],
 | 
			
		||||
 | 
			
		||||
@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
 | 
			
		||||
    and marching along them in order to generate a feature render.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        rays_densities: torch.Tensor,
 | 
			
		||||
@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
 | 
			
		||||
            surface_thickness: Denotes the overlap between the absorption
 | 
			
		||||
                function and the density function.
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        bg_color = torch.tensor(self.bg_color)
 | 
			
		||||
        if bg_color.ndim != 1:
 | 
			
		||||
            raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
 | 
			
		||||
 | 
			
		||||
@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
 | 
			
		||||
    def __post_init__(
 | 
			
		||||
        self,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        render_features_dimensions = self.render_features_dimensions
 | 
			
		||||
        if len(self.bg_color) not in [1, render_features_dimensions]:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
 | 
			
		||||
@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
 | 
			
		||||
    the outputs.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def get_aggregated_feature_dim(
 | 
			
		||||
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
 | 
			
		||||
    ):
 | 
			
		||||
@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
 | 
			
		||||
        ReductionFunction.STD,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def get_aggregated_feature_dim(
 | 
			
		||||
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
 | 
			
		||||
    ):
 | 
			
		||||
@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
 | 
			
		||||
    weight_by_ray_angle_gamma: float = 1.0
 | 
			
		||||
    min_ray_angle_weight: float = 0.1
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def get_aggregated_feature_dim(
 | 
			
		||||
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
 | 
			
		||||
    ):
 | 
			
		||||
@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
 | 
			
		||||
    weight_by_ray_angle_gamma: float = 1.0
 | 
			
		||||
    min_ray_angle_weight: float = 0.1
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def get_aggregated_feature_dim(
 | 
			
		||||
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
 | 
			
		||||
    feature_aggregator: FeatureAggregatorBase
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
 | 
			
		||||
 | 
			
		||||
@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
 | 
			
		||||
    masked_sampling: bool = False
 | 
			
		||||
    sampling_mode: str = "bilinear"
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        *,  # force kw args
 | 
			
		||||
 | 
			
		||||
@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled"
 | 
			
		||||
CREATE_PREFIX: str = "create_"
 | 
			
		||||
IMPL_SUFFIX: str = "_impl"
 | 
			
		||||
TWEAK_SUFFIX: str = "_tweak_args"
 | 
			
		||||
_DATACLASS_INIT: str = "__dataclass_own_init__"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReplaceableBase:
 | 
			
		||||
@ -834,6 +835,9 @@ def expand_args_fields(
 | 
			
		||||
    then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
 | 
			
		||||
    the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
 | 
			
		||||
 | 
			
		||||
    In addition, if the class inherits torch.nn.Module, the generated __init__ will
 | 
			
		||||
    call torch.nn.Module's __init__ before doing anything else.
 | 
			
		||||
 | 
			
		||||
    Note that although the *_args members are intended to have type DictConfig, they
 | 
			
		||||
    are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
 | 
			
		||||
    in place of a dict, but not vice-versa. Allowing dict lets a class user specify
 | 
			
		||||
@ -912,9 +916,40 @@ def expand_args_fields(
 | 
			
		||||
    some_class._known_implementations = known_implementations
 | 
			
		||||
 | 
			
		||||
    dataclasses.dataclass(eq=False)(some_class)
 | 
			
		||||
    _fixup_class_init(some_class)
 | 
			
		||||
    return some_class
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _fixup_class_init(some_class) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    In-place modification of the some_class class which happens
 | 
			
		||||
    after dataclass processing.
 | 
			
		||||
 | 
			
		||||
    If the dataclass some_class inherits torch.nn.Module, then
 | 
			
		||||
    makes torch.nn.Module's __init__ be called before anything else
 | 
			
		||||
    on instantiation of some_class.
 | 
			
		||||
    This is a bit like attr's __pre_init__.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    assert _is_actually_dataclass(some_class)
 | 
			
		||||
    try:
 | 
			
		||||
        import torch
 | 
			
		||||
    except ModuleNotFoundError:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    if not issubclass(some_class, torch.nn.Module):
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    def init(self, *args, **kwargs) -> None:
 | 
			
		||||
        torch.nn.Module.__init__(self)
 | 
			
		||||
        getattr(self, _DATACLASS_INIT)(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    assert not hasattr(some_class, _DATACLASS_INIT)
 | 
			
		||||
 | 
			
		||||
    setattr(some_class, _DATACLASS_INIT, some_class.__init__)
 | 
			
		||||
    some_class.__init__ = init
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_default_args_field(
 | 
			
		||||
    C,
 | 
			
		||||
    *,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user