mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Make Module.__init__ automatic
Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else. Reviewed By: shapovalov Differential Revision: D42760349 fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
This commit is contained in:
parent
97f8f9bf47
commit
9540c29023
@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
|
|||||||
class XRayRenderer(BaseRenderer, torch.nn.Module):
|
class XRayRenderer(BaseRenderer, torch.nn.Module):
|
||||||
n_pts_per_ray: int = 64
|
n_pts_per_ray: int = 64
|
||||||
|
|
||||||
# if there are other base classes, make sure to call `super().__init__()` explicitly
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
# custom initialization
|
# custom initialization
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
|
|||||||
raise ValueError("Image size should be set in the dataset")
|
raise ValueError("Image size should be set in the dataset")
|
||||||
|
|
||||||
# init the simple DBIR model
|
# init the simple DBIR model
|
||||||
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden
|
model = ModelDBIR(
|
||||||
render_image_width=image_size,
|
render_image_width=image_size,
|
||||||
render_image_height=image_size,
|
render_image_height=image_size,
|
||||||
bg_color=bg_color,
|
bg_color=bg_color,
|
||||||
|
@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
|
|||||||
# the training loop.
|
# the training loop.
|
||||||
log_vars: List[str] = field(default_factory=lambda: ["objective"])
|
log_vars: List[str] = field(default_factory=lambda: ["objective"])
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*, # force keyword-only arguments
|
*, # force keyword-only arguments
|
||||||
|
@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
|
|||||||
Base class for an extractor of a set of features from images.
|
Base class for an extractor of a set of features from images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_feat_dims(self) -> int:
|
def get_feat_dims(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
feature_rescale: float = 1.0
|
feature_rescale: float = 1.0
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
if self.normalize_image:
|
if self.normalize_image:
|
||||||
# register buffers needed to normalize the image
|
# register buffers needed to normalize the image
|
||||||
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
|
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
|
||||||
|
@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if self.view_pooler_enabled:
|
if self.view_pooler_enabled:
|
||||||
if self.image_feature_extractor_class_type is None:
|
if self.image_feature_extractor_class_type is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
ignore_input: bool = False
|
ignore_input: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
raise ValueError(f"Invalid n_instances {self.n_instances}")
|
raise ValueError(f"Invalid n_instances {self.n_instances}")
|
||||||
|
|
||||||
|
@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
|
|||||||
(`SequenceAutodecoder`).
|
(`SequenceAutodecoder`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_encoding_dim(self):
|
def get_encoding_dim(self):
|
||||||
"""
|
"""
|
||||||
Returns the dimensionality of the returned encoding.
|
Returns the dimensionality of the returned encoding.
|
||||||
@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
|
|||||||
autodecoder: Autodecoder
|
autodecoder: Autodecoder
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def get_encoding_dim(self):
|
def get_encoding_dim(self):
|
||||||
@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
|||||||
time_divisor: float = 1.0
|
time_divisor: float = 1.0
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
self._harmonic_embedding = HarmonicEmbedding(
|
self._harmonic_embedding = HarmonicEmbedding(
|
||||||
n_harmonic_functions=self.n_harmonic_functions,
|
n_harmonic_functions=self.n_harmonic_functions,
|
||||||
append_input=self.append_input,
|
append_input=self.append_input,
|
||||||
|
@ -14,9 +14,6 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
|
|
||||||
|
|
||||||
class ImplicitFunctionBase(ABC, ReplaceableBase):
|
class ImplicitFunctionBase(ABC, ReplaceableBase):
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
|
|||||||
space and transforms it into the required quantity (for example density and color).
|
space and transforms it into the required quantity (for example density and color).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
|
self, features: torch.Tensor, z: Optional[torch.Tensor] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
|
|||||||
operation: DecoderActivation = DecoderActivation.IDENTITY
|
operation: DecoderActivation = DecoderActivation.IDENTITY
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
|
||||||
if self.operation not in [
|
if self.operation not in [
|
||||||
DecoderActivation.RELU,
|
DecoderActivation.RELU,
|
||||||
DecoderActivation.SOFTPLUS,
|
DecoderActivation.SOFTPLUS,
|
||||||
@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
use_xavier_init: bool = True
|
use_xavier_init: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
last_activation = {
|
last_activation = {
|
||||||
DecoderActivation.RELU: torch.nn.ReLU(True),
|
DecoderActivation.RELU: torch.nn.ReLU(True),
|
||||||
@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
|
|||||||
network: MLPWithInputSkips
|
network: MLPWithInputSkips
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
encoding_dim: int = 0
|
encoding_dim: int = 0
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
||||||
|
|
||||||
self.embed_fn = None
|
self.embed_fn = None
|
||||||
|
@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
# The harmonic embedding layer converts input 3D coordinates
|
# The harmonic embedding layer converts input 3D coordinates
|
||||||
# to a representation that is more suitable for
|
# to a representation that is more suitable for
|
||||||
# processing with a deep neural network.
|
# processing with a deep neural network.
|
||||||
|
@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
|||||||
raymarch_function: Any = None
|
raymarch_function: Any = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
self._harmonic_embedding = HarmonicEmbedding(
|
self._harmonic_embedding = HarmonicEmbedding(
|
||||||
self.n_harmonic_functions, append_input=True
|
self.n_harmonic_functions, append_input=True
|
||||||
)
|
)
|
||||||
@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
ray_dir_in_camera_coords: bool = False
|
ray_dir_in_camera_coords: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
self._harmonic_embedding = HarmonicEmbedding(
|
self._harmonic_embedding = HarmonicEmbedding(
|
||||||
self.n_harmonic_functions, append_input=True
|
self.n_harmonic_functions, append_input=True
|
||||||
)
|
)
|
||||||
@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
xyz_in_camera_coords: bool = False
|
xyz_in_camera_coords: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
raymarch_input_embedding_dim = (
|
raymarch_input_embedding_dim = (
|
||||||
HarmonicEmbedding.get_output_dim_static(
|
HarmonicEmbedding.get_output_dim_static(
|
||||||
self.in_features,
|
self.in_features,
|
||||||
@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def create_raymarch_function(self) -> None:
|
def create_raymarch_function(self) -> None:
|
||||||
@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def create_hypernet(self) -> None:
|
def create_hypernet(self) -> None:
|
||||||
|
@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
if 0 not in self.resolution_changes:
|
if 0 not in self.resolution_changes:
|
||||||
raise ValueError("There has to be key `0` in `resolution_changes`.")
|
raise ValueError("There has to be key `0` in `resolution_changes`.")
|
||||||
|
|
||||||
@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
|
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
|
||||||
shapes = self.voxel_grid.get_shapes(epoch=0)
|
shapes = self.voxel_grid.get_shapes(epoch=0)
|
||||||
|
@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
volume_cropping_epochs: Tuple[int, ...] = ()
|
volume_cropping_epochs: Tuple[int, ...] = ()
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
||||||
|
@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
|
|||||||
depend on the model's parameters.
|
depend on the model's parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
|||||||
`forward()` method produces losses and other metrics.
|
`forward()` method produces losses and other metrics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
raymarched: RendererOutput,
|
raymarched: RendererOutput,
|
||||||
|
@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
|
|||||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
max_points: int = -1
|
max_points: int = -1
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*, # force keyword-only arguments
|
*, # force keyword-only arguments
|
||||||
|
@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
|
|||||||
Base class for all Renderer implementations.
|
Base class for all Renderer implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def requires_object_mask(self) -> bool:
|
def requires_object_mask(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Whether `forward` needs the object_mask.
|
Whether `forward` needs the object_mask.
|
||||||
|
@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
self._lstm = torch.nn.LSTMCell(
|
self._lstm = torch.nn.LSTMCell(
|
||||||
input_size=self.n_feature_channels,
|
input_size=self.n_feature_channels,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
|
@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
return_weights: bool = False
|
return_weights: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
self._refiners = {
|
self._refiners = {
|
||||||
EvaluationMode.TRAINING: RayPointRefiner(
|
EvaluationMode.TRAINING: RayPointRefiner(
|
||||||
n_pts_per_ray=self.n_pts_per_ray_fine_training,
|
n_pts_per_ray=self.n_pts_per_ray_fine_training,
|
||||||
|
@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
|||||||
random_sampling: bool
|
random_sampling: bool
|
||||||
add_input_samples: bool = True
|
add_input_samples: bool = True
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ray_bundle: ImplicitronRayBundle,
|
input_ray_bundle: ImplicitronRayBundle,
|
||||||
|
@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
|
|||||||
Base class for ray samplers.
|
Base class for ray samplers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
cameras: CamerasBase,
|
cameras: CamerasBase,
|
||||||
@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
stratified_point_sampling_evaluation: bool = False
|
stratified_point_sampling_evaluation: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if (self.n_rays_per_image_sampled_from_mask is not None) and (
|
if (self.n_rays_per_image_sampled_from_mask is not None) and (
|
||||||
self.n_rays_total_training is not None
|
self.n_rays_total_training is not None
|
||||||
):
|
):
|
||||||
|
@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
n_steps: int = 100
|
n_steps: int = 100
|
||||||
n_secant_steps: int = 8
|
n_secant_steps: int = 8
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
sdf: Callable[[torch.Tensor], torch.Tensor],
|
sdf: Callable[[torch.Tensor], torch.Tensor],
|
||||||
|
@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
|
|||||||
and marching along them in order to generate a feature render.
|
and marching along them in order to generate a feature render.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
rays_densities: torch.Tensor,
|
rays_densities: torch.Tensor,
|
||||||
@ -98,8 +95,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
|||||||
surface_thickness: Denotes the overlap between the absorption
|
surface_thickness: Denotes the overlap between the absorption
|
||||||
function and the density function.
|
function and the density function.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
bg_color = torch.tensor(self.bg_color)
|
bg_color = torch.tensor(self.bg_color)
|
||||||
if bg_color.ndim != 1:
|
if bg_color.ndim != 1:
|
||||||
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
|
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
|
||||||
|
@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
def __post_init__(
|
def __post_init__(
|
||||||
self,
|
self,
|
||||||
):
|
):
|
||||||
super().__init__()
|
|
||||||
render_features_dimensions = self.render_features_dimensions
|
render_features_dimensions = self.render_features_dimensions
|
||||||
if len(self.bg_color) not in [1, render_features_dimensions]:
|
if len(self.bg_color) not in [1, render_features_dimensions]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
the outputs.
|
the outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_aggregated_feature_dim(
|
def get_aggregated_feature_dim(
|
||||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
):
|
):
|
||||||
@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
ReductionFunction.STD,
|
ReductionFunction.STD,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_aggregated_feature_dim(
|
def get_aggregated_feature_dim(
|
||||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
):
|
):
|
||||||
@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
|||||||
weight_by_ray_angle_gamma: float = 1.0
|
weight_by_ray_angle_gamma: float = 1.0
|
||||||
min_ray_angle_weight: float = 0.1
|
min_ray_angle_weight: float = 0.1
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_aggregated_feature_dim(
|
def get_aggregated_feature_dim(
|
||||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
):
|
):
|
||||||
@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
|
|||||||
weight_by_ray_angle_gamma: float = 1.0
|
weight_by_ray_angle_gamma: float = 1.0
|
||||||
min_ray_angle_weight: float = 0.1
|
min_ray_angle_weight: float = 0.1
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_aggregated_feature_dim(
|
def get_aggregated_feature_dim(
|
||||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
):
|
):
|
||||||
|
@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
|
|||||||
feature_aggregator: FeatureAggregatorBase
|
feature_aggregator: FeatureAggregatorBase
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||||
|
@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
|
|||||||
masked_sampling: bool = False
|
masked_sampling: bool = False
|
||||||
sampling_mode: str = "bilinear"
|
sampling_mode: str = "bilinear"
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*, # force kw args
|
*, # force kw args
|
||||||
|
@ -184,6 +184,7 @@ ENABLED_SUFFIX: str = "_enabled"
|
|||||||
CREATE_PREFIX: str = "create_"
|
CREATE_PREFIX: str = "create_"
|
||||||
IMPL_SUFFIX: str = "_impl"
|
IMPL_SUFFIX: str = "_impl"
|
||||||
TWEAK_SUFFIX: str = "_tweak_args"
|
TWEAK_SUFFIX: str = "_tweak_args"
|
||||||
|
_DATACLASS_INIT: str = "__dataclass_own_init__"
|
||||||
|
|
||||||
|
|
||||||
class ReplaceableBase:
|
class ReplaceableBase:
|
||||||
@ -834,6 +835,9 @@ def expand_args_fields(
|
|||||||
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
||||||
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
||||||
|
|
||||||
|
In addition, if the class inherits torch.nn.Module, the generated __init__ will
|
||||||
|
call torch.nn.Module's __init__ before doing anything else.
|
||||||
|
|
||||||
Note that although the *_args members are intended to have type DictConfig, they
|
Note that although the *_args members are intended to have type DictConfig, they
|
||||||
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
|
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
|
||||||
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
|
||||||
@ -912,9 +916,40 @@ def expand_args_fields(
|
|||||||
some_class._known_implementations = known_implementations
|
some_class._known_implementations = known_implementations
|
||||||
|
|
||||||
dataclasses.dataclass(eq=False)(some_class)
|
dataclasses.dataclass(eq=False)(some_class)
|
||||||
|
_fixup_class_init(some_class)
|
||||||
return some_class
|
return some_class
|
||||||
|
|
||||||
|
|
||||||
|
def _fixup_class_init(some_class) -> None:
|
||||||
|
"""
|
||||||
|
In-place modification of the some_class class which happens
|
||||||
|
after dataclass processing.
|
||||||
|
|
||||||
|
If the dataclass some_class inherits torch.nn.Module, then
|
||||||
|
makes torch.nn.Module's __init__ be called before anything else
|
||||||
|
on instantiation of some_class.
|
||||||
|
This is a bit like attr's __pre_init__.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert _is_actually_dataclass(some_class)
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not issubclass(some_class, torch.nn.Module):
|
||||||
|
return
|
||||||
|
|
||||||
|
def init(self, *args, **kwargs) -> None:
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
getattr(self, _DATACLASS_INIT)(*args, **kwargs)
|
||||||
|
|
||||||
|
assert not hasattr(some_class, _DATACLASS_INIT)
|
||||||
|
|
||||||
|
setattr(some_class, _DATACLASS_INIT, some_class.__init__)
|
||||||
|
some_class.__init__ = init
|
||||||
|
|
||||||
|
|
||||||
def get_default_args_field(
|
def get_default_args_field(
|
||||||
C,
|
C,
|
||||||
*,
|
*,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user