diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index 2f0e6e3c..a65401f7 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -63,8 +63,9 @@ generic_model_args: n_pts_per_ray_fine_evaluation: 64 append_coarse_samples_to_fine: true density_noise_std_train: 1.0 - view_sampler_args: - masked_sampling: false + view_pooler_args: + view_sampler_args: + masked_sampling: false image_feature_extractor_args: stages: - 1 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml index 9e00bb12..88cfead0 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml @@ -1,4 +1,5 @@ generic_model_args: + image_feature_extractor_enabled: true image_feature_extractor_args: add_images: true add_masks: true diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml index 017be45e..c45c65a9 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml @@ -1,4 +1,5 @@ generic_model_args: + image_feature_extractor_enabled: true image_feature_extractor_args: add_images: true add_masks: true diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml index d1c43458..8039086c 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml @@ -1,4 +1,5 @@ generic_model_args: + image_feature_extractor_enabled: true image_feature_extractor_args: stages: - 1 @@ -11,6 +12,7 @@ generic_model_args: name: resnet34 normalize_image: true pretrained: true - feature_aggregator_AngleWeightedReductionFeatureAggregator_args: - reduction_functions: - - AVG + view_pooler_args: + feature_aggregator_AngleWeightedReductionFeatureAggregator_args: + reduction_functions: + - AVG diff --git a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml index 0f6c5933..3be35876 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml @@ -11,7 +11,6 @@ generic_model_args: num_passes: 1 output_rasterized_mc: true sampling_mode_training: mask_sample - view_pool: false sequence_autodecoder_args: n_instances: 20000 init_scale: 1.0 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml index f5ce474a..f9a978f5 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml @@ -3,7 +3,7 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - view_pool: false + view_pooler_enabled: false sequence_autodecoder_args: n_instances: 20000 encoding_dim: 256 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml index 19e5ab0b..2fa03575 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml @@ -5,6 +5,6 @@ defaults: clip_grad: 1.0 generic_model_args: chunk_size_grid: 16000 - view_pool: true + view_pooler_enabled: true raysampler_args: n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml index e2a57e96..213e1124 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml @@ -4,7 +4,6 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - view_pool: true raysampler_args: n_rays_per_image_sampled_from_mask: 800 n_pts_per_ray_training: 32 @@ -13,4 +12,6 @@ generic_model_args: n_pts_per_ray_fine_training: 16 n_pts_per_ray_fine_evaluation: 16 implicit_function_class_type: NeRFormerImplicitFunction - feature_aggregator_class_type: IdentityFeatureAggregator + view_pooler_enabled: true + view_pooler_args: + feature_aggregator_class_type: IdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml index f28b5961..982c5eaa 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml @@ -4,7 +4,6 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - view_pool: true raysampler_args: n_rays_per_image_sampled_from_mask: 800 n_pts_per_ray_training: 32 @@ -13,4 +12,6 @@ generic_model_args: n_pts_per_ray_fine_training: 16 n_pts_per_ray_fine_evaluation: 16 implicit_function_class_type: NeRFormerImplicitFunction - feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator + view_pooler_enabled: true + view_pooler_args: + feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml index a4ff2030..e719f646 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml @@ -3,7 +3,7 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - view_pool: false + view_pooler_enabled: false n_train_target_views: -1 num_passes: 1 loss_weights: diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml index f59662ea..02425a6b 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml @@ -4,7 +4,6 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 32000 - view_pool: true num_passes: 1 n_train_target_views: -1 loss_weights: @@ -25,6 +24,7 @@ generic_model_args: stratified_point_sampling_evaluation: false renderer_class_type: LSTMRenderer implicit_function_class_type: SRNImplicitFunction + view_pooler_enabled: true solver_args: breed: adam lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml index 28553693..1baa2523 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml @@ -9,7 +9,7 @@ generic_model_args: loss_eikonal: 0.1 chunk_size_grid: 65536 num_passes: 1 - view_pool: false + view_pooler_enabled: false implicit_function_IdrFeatureField_args: n_harmonic_functions_xyz: 6 bias: 0.6 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml index 93f3ff5c..31ee863e 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml @@ -4,6 +4,6 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - view_pool: true + view_pooler_enabled: true raysampler_args: n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml index 215a6477..2fad73c6 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml @@ -4,7 +4,7 @@ defaults: - _self_ generic_model_args: chunk_size_grid: 16000 - view_pool: true + view_pooler_enabled: true implicit_function_class_type: NeRFormerImplicitFunction raysampler_args: n_rays_per_image_sampled_from_mask: 800 @@ -13,4 +13,5 @@ generic_model_args: renderer_MultiPassEmissionAbsorptionRenderer_args: n_pts_per_ray_fine_training: 16 n_pts_per_ray_fine_evaluation: 16 - feature_aggregator_class_type: IdentityFeatureAggregator + view_pooler_args: + feature_aggregator_class_type: IdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml index 98575daf..e6a84489 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml @@ -4,7 +4,7 @@ defaults: generic_model_args: num_passes: 1 chunk_size_grid: 32000 - view_pool: false + view_pooler_enabled: false loss_weights: loss_rgb_mse: 200.0 loss_prev_stage_rgb_mse: 0.0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml index 57e10183..c66b21c2 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml @@ -5,7 +5,7 @@ defaults: generic_model_args: num_passes: 1 chunk_size_grid: 32000 - view_pool: true + view_pooler_enabled: true loss_weights: loss_rgb_mse: 200.0 loss_prev_stage_rgb_mse: 0.0 diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 1683d017..80483644 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -49,8 +49,7 @@ from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa from .renderer.ray_sampler import RaySampler from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa from .resnet_feature_extractor import ResNetFeatureExtractor -from .view_pooling.feature_aggregation import FeatureAggregatorBase -from .view_pooling.view_sampling import ViewSampler +from .view_pooler.view_pooler import ViewPooler STD_LOG_VARS = ["objective", "epoch", "sec/it"] @@ -167,16 +166,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 registry. renderer: A renderer class which inherits from BaseRenderer. This is used to generate the images from the target view(s). + image_feature_extractor_enabled: If `True`, constructs and enables + the `image_feature_extractor` object. image_feature_extractor: A module for extrating features from an input image. - view_sampler: An instance of ViewSampler which is used for sampling of + view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object. + view_pooler: An instance of ViewPooler which is used for sampling of image-based features at the 2D projections of a set - of 3D points. - feature_aggregator_class_type: The name of the feature aggregator class which - is available in the global registry. - feature_aggregator: A feature aggregator class which inherits from - FeatureAggregatorBase. Typically, the aggregated features and their - masks are output by a `ViewSampler` which samples feature tensors extracted - from a set of source images. FeatureAggregator executes step (4) above. + of 3D points and aggregating the sampled features. implicit_function_class_type: The type of implicit function to use which is available in the global registry. implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions @@ -195,7 +191,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 mask_threshold: float = 0.5 output_rasterized_mc: bool = False bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) - view_pool: bool = False num_passes: int = 1 chunk_size_grid: int = 4096 render_features_dimensions: int = 3 @@ -215,13 +210,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer" renderer: BaseRenderer - # ---- view sampling settings - used if view_pool=True - # (This is only created if view_pool is False) - image_feature_extractor: ResNetFeatureExtractor - view_sampler: ViewSampler - # ---- ---- view sampling feature aggregator settings - feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator" - feature_aggregator: FeatureAggregatorBase + # ---- image feature extractor settings + image_feature_extractor_enabled: bool = False + image_feature_extractor: Optional[ResNetFeatureExtractor] + # ---- view pooler settings + view_pooler_enabled: bool = False + view_pooler: Optional[ViewPooler] # ---- implicit function settings implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction" @@ -356,32 +350,34 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 # custom_args hold additional arguments to the implicit function. custom_args = {} - if self.view_pool: + if self.image_feature_extractor_enabled: + # (2) Extract features for the image + img_feats = self.image_feature_extractor( # pyre-fixme[29] + image_rgb, fg_probability + ) + + if self.view_pooler_enabled: if sequence_name is None: raise ValueError("sequence_name must be provided for view pooling") - # (2) Extract features for the image - img_feats = self.image_feature_extractor(image_rgb, fg_probability) + if not self.image_feature_extractor_enabled: + raise ValueError( + "image_feature_extractor has to be enabled for for view pooling" + + " (I.e. set self.image_feature_extractor_enabled=True)." + ) - # (3) Sample features and masks at the ray points - curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731 - pts=pts, - seq_id_pts=sequence_name[:n_targets], - camera=camera, - seq_id_camera=sequence_name, - feats=img_feats, - masks=mask_crop, - ) # returns feats_sampled, masks_sampled + # (3-4) Sample features and masks at the ray points. + # Aggregate features from multiple views. + def curried_viewpooler(pts): + return self.view_pooler( + pts=pts, + seq_id_pts=sequence_name[:n_targets], + camera=camera, + seq_id_camera=sequence_name, + feats=img_feats, + masks=mask_crop, + ) - # (4) Aggregate features from multiple views - # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. - curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731 - *curried_view_sampler(pts=pts), - pts=pts, - camera=camera, - ) # TODO: do we need to pass a callback rather than compute here? - # precomputing will be faster for 2 passes - # -> but this is important for non-nerf - custom_args["fun_viewpool"] = curried_view_pool + custom_args["fun_viewpool"] = curried_viewpooler global_code = None if self.sequence_autodecoder.n_instances > 0: @@ -562,10 +558,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 def _get_viewpooled_feature_dim(self): return ( - self.feature_aggregator.get_aggregated_feature_dim( + self.view_pooler.get_aggregated_feature_dim( self.image_feature_extractor.get_feat_dims() ) - if self.view_pool + if self.view_pooler_enabled else 0 ) @@ -583,15 +579,20 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 "object_bounding_sphere" ] = self.raysampler_args["scene_extent"] - def create_image_feature_extractor(self): + def create_view_pooler(self): """ - Custom creation function called by run_auto_creation so that the - image_feature_extractor is not created if it is not be needed. + Custom creation function called by run_auto_creation checking + that image_feature_extractor is enabled when view_pooler is enabled. """ - if self.view_pool: - self.image_feature_extractor = ResNetFeatureExtractor( - **self.image_feature_extractor_args - ) + if self.view_pooler_enabled: + if not self.image_feature_extractor_enabled: + raise ValueError( + "image_feature_extractor has to be enabled for view pooling" + + " (I.e. set self.image_feature_extractor_enabled=True)." + ) + self.view_pooler = ViewPooler(**self.view_pooler_args) + else: + self.view_pooler = None def create_implicit_function(self) -> None: """ @@ -652,10 +653,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 ) if implicit_function_type.requires_pooling_without_aggregation(): - has_aggregation = hasattr(self.feature_aggregator, "reduction_functions") - if not self.view_pool or has_aggregation: + if self.view_pooler_enabled and self.view_pooler.has_aggregation(): raise ValueError( - "Chosen implicit function requires view pooling without aggregation." + "The chosen implicit function requires view pooling without aggregation." ) config_name = f"implicit_function_{self.implicit_function_class_type}_args" config = getattr(self, config_name, None) diff --git a/pytorch3d/implicitron/models/resnet_feature_extractor.py b/pytorch3d/implicitron/models/resnet_feature_extractor.py index 5fc20469..27c7e4ec 100644 --- a/pytorch3d/implicitron/models/resnet_feature_extractor.py +++ b/pytorch3d/implicitron/models/resnet_feature_extractor.py @@ -141,11 +141,12 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module): def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor: return (img - self._resnet_mean) / self._resnet_std - def get_feat_dims(self, size_dict: bool = False): - if size_dict: - return copy.deepcopy(self._feat_dim) - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na... - return sum(self._feat_dim.values()) + def get_feat_dims(self) -> int: + return ( + sum(self._feat_dim.values()) # pyre-fixme[29] + if len(self._feat_dim) > 0 # pyre-fixme[6] + else 0 + ) def forward( self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None diff --git a/pytorch3d/implicitron/models/view_pooling/__init__.py b/pytorch3d/implicitron/models/view_pooler/__init__.py similarity index 100% rename from pytorch3d/implicitron/models/view_pooling/__init__.py rename to pytorch3d/implicitron/models/view_pooler/__init__.py diff --git a/pytorch3d/implicitron/models/view_pooling/feature_aggregation.py b/pytorch3d/implicitron/models/view_pooler/feature_aggregator.py similarity index 92% rename from pytorch3d/implicitron/models/view_pooling/feature_aggregation.py rename to pytorch3d/implicitron/models/view_pooler/feature_aggregator.py index 9cdde134..26b8b7db 100644 --- a/pytorch3d/implicitron/models/view_pooling/feature_aggregation.py +++ b/pytorch3d/implicitron/models/view_pooler/feature_aggregator.py @@ -10,7 +10,7 @@ from typing import Dict, Optional, Sequence, Union import torch import torch.nn.functional as F -from pytorch3d.implicitron.models.view_pooling.view_sampling import ( +from pytorch3d.implicitron.models.view_pooler.view_sampler import ( cameras_points_cartesian_product, ) from pytorch3d.implicitron.tools.config import registry, ReplaceableBase @@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase): """ raise NotImplementedError() + @abstractmethod + def get_aggregated_feature_dim( + self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] + ): + """ + Returns the final dimensionality of the output aggregated features. + + Args: + feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}` corresponding + to the `feats_sampled` argument of `forward`, + or an `int` representing the sum of dimensionalities of each `t_i`. + + Returns: + aggregated_feature_dim: The final dimensionality of the output + aggregated features. + """ + raise NotImplementedError() + + def has_aggregation(self) -> bool: + """ + Specifies whether the aggregator reduces the output `reduce_dim` dimension to 1. + + Returns: + has_aggregation: `True` if `reduce_dim==1`, else `False`. + """ + return hasattr(self, "reduction_functions") + @registry.register class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): @@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): def __post_init__(self): super().__init__() - def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): - return _get_reduction_aggregator_feature_dim(feats, []) + def get_aggregated_feature_dim( + self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] + ): + return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, []) def forward( self, @@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): def __post_init__(self): super().__init__() - def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): - return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions) + def get_aggregated_feature_dim( + self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] + ): + return _get_reduction_aggregator_feature_dim( + feats_or_feats_dim, self.reduction_functions + ) def forward( self, @@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator def __post_init__(self): super().__init__() - def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): - return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions) + def get_aggregated_feature_dim( + self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] + ): + return _get_reduction_aggregator_feature_dim( + feats_or_feats_dim, self.reduction_functions + ) def forward( self, @@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB def __post_init__(self): super().__init__() - def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): - return _get_reduction_aggregator_feature_dim(feats, []) + def get_aggregated_feature_dim( + self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] + ): + return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, []) def forward( self, diff --git a/pytorch3d/implicitron/models/view_pooler/view_pooler.py b/pytorch3d/implicitron/models/view_pooler/view_pooler.py new file mode 100644 index 00000000..eca64b30 --- /dev/null +++ b/pytorch3d/implicitron/models/view_pooler/view_pooler.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Union + +import torch +from pytorch3d.implicitron.tools.config import Configurable, run_auto_creation +from pytorch3d.renderer.cameras import CamerasBase + +from .feature_aggregator import FeatureAggregatorBase +from .view_sampler import ViewSampler + + +# pyre-ignore: 13 +class ViewPooler(Configurable, torch.nn.Module): + """ + Implements sampling of image-based features at the 2d projections of a set + of 3D points, and a subsequent aggregation of the resulting set of features + per-point. + + Args: + view_sampler: An instance of ViewSampler which is used for sampling of + image-based features at the 2D projections of a set + of 3D points. + feature_aggregator_class_type: The name of the feature aggregator class which + is available in the global registry. + feature_aggregator: A feature aggregator class which inherits from + FeatureAggregatorBase. Typically, the aggregated features and their + masks are output by a `ViewSampler` which samples feature tensors extracted + from a set of source images. FeatureAggregator executes step (4) above. + """ + + view_sampler: ViewSampler + feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator" + feature_aggregator: FeatureAggregatorBase + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + + def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): + """ + Returns the final dimensionality of the output aggregated features. + + Args: + feats: Either a `dict` of sampled features `{f_i: t_i}` corresponding + to the `feats_sampled` argument of `feature_aggregator,forward`, + or an `int` representing the sum of dimensionalities of each `t_i`. + + Returns: + aggregated_feature_dim: The final dimensionality of the output + aggregated features. + """ + return self.feature_aggregator.get_aggregated_feature_dim(feats) + + def has_aggregation(self): + """ + Specifies whether the `feature_aggregator` reduces the output `reduce_dim` + dimension to 1. + + Returns: + has_aggregation: `True` if `reduce_dim==1`, else `False`. + """ + return self.feature_aggregator.has_aggregation() + + def forward( + self, + *, # force kw args + pts: torch.Tensor, + seq_id_pts: Union[List[int], List[str], torch.LongTensor], + camera: CamerasBase, + seq_id_camera: Union[List[int], List[str], torch.LongTensor], + feats: Dict[str, torch.Tensor], + masks: Optional[torch.Tensor], + **kwargs, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Project each point cloud from a batch of point clouds to corresponding + input cameras, sample features at the 2D projection locations in a batch + of source images, and aggregate the pointwise sampled features. + + Args: + pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords. + seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes + from which `pts` were extracted, or a list of string names. + camera: 'n_cameras' cameras, each coresponding to a batch element of `feats`. + seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes + corresponding to cameras in `camera`, or a list of string names. + feats: a dict of tensors of per-image features `{feat_i: T_i}`. + Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`. + masks: `[n_cameras x 1 x H x W]`, define valid image regions + for sampling `feats`. + Returns: + feats_aggregated: If `feature_aggregator.concatenate_output==True`, a tensor + of shape `(pts_batch, reduce_dim, n_pts, sum(dim_1, ... dim_N))` + containing the aggregated features. `reduce_dim` depends on + the specific feature aggregator implementation and typically + equals 1 or `n_cameras`. + If `feature_aggregator.concatenate_output==False`, the aggregator + does not concatenate the aggregated features and returns a dictionary + of per-feature aggregations `{f_i: t_i_aggregated}` instead. + Each `t_i_aggregated` is of shape + `(pts_batch, reduce_dim, n_pts, aggr_dim_i)`. + """ + + # (1) Sample features and masks at the ray points + sampled_feats, sampled_masks = self.view_sampler( + pts=pts, + seq_id_pts=seq_id_pts, + camera=camera, + seq_id_camera=seq_id_camera, + feats=feats, + masks=masks, + ) + + # (2) Aggregate features from multiple views + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + feats_aggregated = self.feature_aggregator( # noqa: E731 + sampled_feats, + sampled_masks, + pts=pts, + camera=camera, + ) # TODO: do we need to pass a callback rather than compute here? + + return feats_aggregated diff --git a/pytorch3d/implicitron/models/view_pooling/view_sampling.py b/pytorch3d/implicitron/models/view_pooler/view_sampler.py similarity index 100% rename from pytorch3d/implicitron/models/view_pooling/view_sampling.py rename to pytorch3d/implicitron/models/view_pooler/view_sampler.py diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index 1f748d3b..19420e23 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -8,7 +8,6 @@ bg_color: - 0.0 - 0.0 - 0.0 -view_pool: false num_passes: 1 chunk_size_grid: 4096 render_features_dimensions: 3 @@ -17,7 +16,8 @@ n_train_target_views: 1 sampling_mode_training: mask_sample sampling_mode_evaluation: full_grid renderer_class_type: LSTMRenderer -feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator +image_feature_extractor_enabled: true +view_pooler_enabled: true implicit_function_class_type: IdrFeatureField loss_weights: loss_rgb_mse: 1.0 @@ -91,15 +91,17 @@ image_feature_extractor_args: add_images: true global_average_pool: false feature_rescale: 1.0 -view_sampler_args: - masked_sampling: false - sampling_mode: bilinear -feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - weight_by_ray_angle_gamma: 1.0 - min_ray_angle_weight: 0.1 +view_pooler_args: + feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator + view_sampler_args: + masked_sampling: false + sampling_mode: bilinear + feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + weight_by_ray_angle_gamma: 1.0 + min_ray_angle_weight: 0.1 implicit_function_IdrFeatureField_args: feature_vector_size: 3 d_in: 3 diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index a4f94343..6dc0a75a 100644 --- a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -20,9 +20,8 @@ from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer from pytorch3d.implicitron.models.renderer.multipass_ea import ( MultiPassEmissionAbsorptionRenderer, ) -from pytorch3d.implicitron.models.view_pooling.feature_aggregation import ( +from pytorch3d.implicitron.models.view_pooler.feature_aggregator import ( AngleWeightedIdentityFeatureAggregator, - AngleWeightedReductionFeatureAggregator, ) from pytorch3d.implicitron.tools.config import ( get_default_args, @@ -32,7 +31,10 @@ from pytorch3d.implicitron.tools.config import ( if os.environ.get("FB_TEST", False): from common_testing import get_tests_dir + + from .common_resources import provide_lpips_vgg else: + from common_resources import provide_lpips_vgg # noqa from tests.common_testing import get_tests_dir DATA_DIR = get_tests_dir() / "implicitron/data" @@ -46,28 +48,33 @@ class TestGenericModel(unittest.TestCase): self.maxDiff = None def test_create_gm(self): + provide_lpips_vgg() args = get_default_args(GenericModel) gm = GenericModel(**args) self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer) - self.assertIsInstance( - gm.feature_aggregator, AngleWeightedReductionFeatureAggregator - ) self.assertIsInstance( gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction ) self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) self.assertFalse(hasattr(gm, "implicit_function")) - self.assertFalse(hasattr(gm, "image_feature_extractor")) + self.assertIsNone(gm.view_pooler) + self.assertIsNone(gm.image_feature_extractor) - def test_create_gm_overrides(self): + def _test_create_gm_overrides(self): + provide_lpips_vgg() args = get_default_args(GenericModel) - args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator" + args.view_pooler_enabled = True + args.image_feature_extractor_enabled = True + args.view_pooler_args.feature_aggregator_class_type = ( + "AngleWeightedIdentityFeatureAggregator" + ) args.implicit_function_class_type = "IdrFeatureField" args.renderer_class_type = "LSTMRenderer" gm = GenericModel(**args) self.assertIsInstance(gm.renderer, LSTMRenderer) self.assertIsInstance( - gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator + gm.view_pooler.feature_aggregator, + AngleWeightedIdentityFeatureAggregator, ) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index 9829ea9b..b2028dac 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase): cfg = _load_model_config_from_yaml(str(config_file)) model = GenericModel(**cfg) model.to(device) - self._one_model_test(model, device, eval_test=True) + self._one_model_test( + model, + device, + eval_test=True, + bw_test=True, + ) def _one_model_test( self, @@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase): device, n_train_cameras: int = 5, eval_test: bool = True, + bw_test: bool = True, ): R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360) @@ -86,8 +92,12 @@ class TestGenericModel(unittest.TestCase): **random_args, evaluation_mode=EvaluationMode.TRAINING, ) - self.assertGreater(train_preds["objective"].item(), 0) - train_preds["objective"].backward() + self.assertTrue( + train_preds["objective"].isfinite().item() + ) # check finiteness of the objective + + if bw_test: + train_preds["objective"].backward() if eval_test: model.eval() diff --git a/tests/implicitron/test_viewsampling.py b/tests/implicitron/test_viewsampling.py index dd438eb9..4094bf4a 100644 --- a/tests/implicitron/test_viewsampling.py +++ b/tests/implicitron/test_viewsampling.py @@ -9,7 +9,7 @@ import unittest import pytorch3d as pt3d import torch -from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler +from pytorch3d.implicitron.models.view_pooler.view_sampler import ViewSampler from pytorch3d.implicitron.tools.config import expand_args_fields