From 9ec9d057cc70f2d951de3da8914d69280b9deade Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 18 May 2022 08:50:18 -0700 Subject: [PATCH] Make feature extractor pluggable Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase. Reviewed By: davnov134 Differential Revision: D35433098 fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a --- projects/implicitron_trainer/README.md | 3 +- .../configs/repro_base.yaml | 2 +- .../configs/repro_feat_extractor_normed.yaml | 4 +- .../repro_feat_extractor_transformer.yaml | 4 +- .../repro_feat_extractor_unnormed.yaml | 4 +- .../models/feature_extractor/__init__.py | 7 +++ .../feature_extractor/feature_extractor.py | 44 ++++++++++++++ .../resnet_feature_extractor.py | 35 ++++++----- pytorch3d/implicitron/models/generic_model.py | 60 ++++++++++--------- tests/implicitron/data/overrides.yaml | 4 +- tests/implicitron/test_config_use.py | 8 ++- tests/implicitron/test_forward_pass.py | 33 +++++++++- 12 files changed, 151 insertions(+), 57 deletions(-) create mode 100644 pytorch3d/implicitron/models/feature_extractor/__init__.py create mode 100644 pytorch3d/implicitron/models/feature_extractor/feature_extractor.py rename pytorch3d/implicitron/models/{ => feature_extractor}/resnet_feature_extractor.py (92%) diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md index be12fc58..7b4eb72b 100644 --- a/projects/implicitron_trainer/README.md +++ b/projects/implicitron_trainer/README.md @@ -224,7 +224,8 @@ generic_model_args: GenericModel └-- hypernet_args: SRNRaymarchHyperNet └-- pixel_generator_args: SRNPixelGenerator ╘== IdrFeatureField -└-- image_feature_extractor_args: ResNetFeatureExtractor +└-- image_feature_extractor_*_args: FeatureExtractorBase + ╘== ResNetFeatureExtractor └-- view_sampler_args: ViewSampler └-- feature_aggregator_*_args: FeatureAggregatorBase ╘== IdentityFeatureAggregator diff --git 
a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index 92c12ac1..595e69e5 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -64,7 +64,7 @@ generic_model_args: view_pooler_args: view_sampler_args: masked_sampling: false - image_feature_extractor_args: + image_feature_extractor_ResNetFeatureExtractor_args: stages: - 1 - 2 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml index 88cfead0..1ea74b23 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml @@ -1,6 +1,6 @@ generic_model_args: - image_feature_extractor_enabled: true - image_feature_extractor_args: + image_feature_extractor_class_type: ResNetFeatureExtractor + image_feature_extractor_ResNetFeatureExtractor_args: add_images: true add_masks: true first_max_pool: true diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml index c45c65a9..734ab43e 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml @@ -1,6 +1,6 @@ generic_model_args: - image_feature_extractor_enabled: true - image_feature_extractor_args: + image_feature_extractor_class_type: ResNetFeatureExtractor + image_feature_extractor_ResNetFeatureExtractor_args: add_images: true add_masks: true first_max_pool: false diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml index 8039086c..bc7bc37e 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml +++ 
b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml @@ -1,6 +1,6 @@ generic_model_args: - image_feature_extractor_enabled: true - image_feature_extractor_args: + image_feature_extractor_class_type: ResNetFeatureExtractor + image_feature_extractor_ResNetFeatureExtractor_args: stages: - 1 - 2 diff --git a/pytorch3d/implicitron/models/feature_extractor/__init__.py b/pytorch3d/implicitron/models/feature_extractor/__init__.py new file mode 100644 index 00000000..9141562c --- /dev/null +++ b/pytorch3d/implicitron/models/feature_extractor/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .feature_extractor import FeatureExtractorBase diff --git a/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py b/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py new file mode 100644 index 00000000..2d8206c4 --- /dev/null +++ b/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional + +import torch +from pytorch3d.implicitron.tools.config import ReplaceableBase + + +class FeatureExtractorBase(ReplaceableBase, torch.nn.Module): + """ + Base class for an extractor of a set of features from images. + """ + + def __init__(self): + super().__init__() + + def get_feat_dims(self) -> int: + """ + Returns: + total number of feature dimensions of the output. + (i.e. 
sum_i(dim_i)) + """ + raise NotImplementedError + + def forward( + self, + imgs: Optional[torch.Tensor], + masks: Optional[torch.Tensor] = None, + **kwargs + ) -> Dict[Any, torch.Tensor]: + """ + Args: + imgs: A batch of input images of shape `(B, 3, H, W)`. + masks: A batch of input masks of shape `(B, 3, H, W)`. + + Returns: + out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i` + and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`. + """ + raise NotImplementedError diff --git a/pytorch3d/implicitron/models/resnet_feature_extractor.py b/pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py similarity index 92% rename from pytorch3d/implicitron/models/resnet_feature_extractor.py rename to pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py index 27c7e4ec..8167f2cd 100644 --- a/pytorch3d/implicitron/models/resnet_feature_extractor.py +++ b/pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import logging import math from typing import Any, Dict, Optional, Tuple @@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn.functional as Fu import torchvision -from pytorch3d.implicitron.tools.config import Configurable +from pytorch3d.implicitron.tools.config import registry + +from . import FeatureExtractorBase logger = logging.getLogger(__name__) @@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] -class ResNetFeatureExtractor(Configurable, torch.nn.Module): +@registry.register +class ResNetFeatureExtractor(FeatureExtractorBase): """ Implements an image feature extractor. 
Depending on the settings allows to extract: @@ -142,14 +144,14 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module): return (img - self._resnet_mean) / self._resnet_std def get_feat_dims(self) -> int: - return ( - sum(self._feat_dim.values()) # pyre-fixme[29] - if len(self._feat_dim) > 0 # pyre-fixme[6] - else 0 - ) + # pyre-fixme[29] + return sum(self._feat_dim.values()) def forward( - self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None + self, + imgs: Optional[torch.Tensor], + masks: Optional[torch.Tensor] = None, + **kwargs, ) -> Dict[Any, torch.Tensor]: """ Args: @@ -164,7 +166,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module): out_feats = {} imgs_input = imgs - if self.image_rescale != 1.0: + if self.image_rescale != 1.0 and imgs_input is not None: imgs_resized = Fu.interpolate( imgs_input, # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but @@ -175,12 +177,13 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module): else: imgs_resized = imgs_input - if self.normalize_image: - imgs_normed = self._resnet_normalize_image(imgs_resized) - else: - imgs_normed = imgs_resized - if len(self.stages) > 0: + assert imgs_resized is not None + + if self.normalize_image: + imgs_normed = self._resnet_normalize_image(imgs_resized) + else: + imgs_normed = imgs_resized # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]` # is not a function. 
feats = self.stem(imgs_normed) @@ -207,7 +210,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module): out_feats[MASK_FEATURE_NAME] = masks if self.add_images: - assert imgs_input is not None + assert imgs_resized is not None out_feats[IMAGE_FEATURE_NAME] = imgs_resized if self.feature_rescale != 1.0: diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index ffd767d5..bd433e69 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. +# Note: The #noqa comments below are for unused imports of pluggable implementations +# which are part of implicitron. They ensure that the registry is prepopulated. + import logging import math import warnings @@ -27,6 +30,8 @@ from visdom import Visdom from .autodecoder import Autodecoder from .base_model import ImplicitronModelBase, ImplicitronRender +from .feature_extractor import FeatureExtractorBase +from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa from .implicit_function.base import ImplicitFunctionBase from .implicit_function.idr_feature_field import IdrFeatureField # noqa from .implicit_function.neural_radiance_field import ( # noqa @@ -49,7 +54,6 @@ from .renderer.lstm_renderer import LSTMRenderer # noqa from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa from .renderer.ray_sampler import RaySamplerBase from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa -from .resnet_feature_extractor import ResNetFeatureExtractor from .view_pooler.view_pooler import ViewPooler @@ -139,9 +143,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by splatting onto an image grid. Default: False. bg_color: RGB values for the background color. 
Default (0.0, 0.0, 0.0) - view_pool: If True, features are sampled from the source image(s) - at the projected 2d locations of the sampled 3d ray points from the target - view(s), i.e. this activates step (3) above. num_passes: The specified implicit_function is initialized num_passes times and run sequentially. chunk_size_grid: The total number of points which can be rendered @@ -169,10 +170,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 registry. renderer: A renderer class which inherits from BaseRenderer. This is used to generate the images from the target view(s). - image_feature_extractor_enabled: If `True`, constructs and enables - the `image_feature_extractor` object. + image_feature_extractor_class_type: If a str, constructs and enables + the `image_feature_extractor` object of this type; if None, none is created. image_feature_extractor: A module for extrating features from an input image. view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object. + This means features are sampled from the source image(s) + at the projected 2d locations of the sampled 3d ray points from the target + view(s), i.e. this activates step (3) above. view_pooler: An instance of ViewPooler which is used for sampling of image-based features at the 2D projections of a set of 3D points and aggregating the sampled features. 
@@ -215,8 +219,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 renderer: BaseRenderer # ---- image feature extractor settings - image_feature_extractor_enabled: bool = False - image_feature_extractor: Optional[ResNetFeatureExtractor] + # (This is only created if view_pooler is enabled) + image_feature_extractor: Optional[FeatureExtractorBase] + image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor" # ---- view pooler settings view_pooler_enabled: bool = False view_pooler: Optional[ViewPooler] @@ -266,6 +271,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 super().__init__() self.view_metrics = ViewMetrics() + if self.view_pooler_enabled: + if self.image_feature_extractor_class_type is None: + raise ValueError( + "image_feature_extractor must be present for view pooling." + ) + else: + self.image_feature_extractor_class_type = None run_auto_creation(self) self._implicit_functions = self._construct_implicit_functions() @@ -349,20 +361,16 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 # custom_args hold additional arguments to the implicit function. custom_args = {} - if self.image_feature_extractor_enabled: + if self.image_feature_extractor is not None: # (2) Extract features for the image - img_feats = self.image_feature_extractor( # pyre-fixme[29] - image_rgb, fg_probability - ) + img_feats = self.image_feature_extractor(image_rgb, fg_probability) + else: + img_feats = None if self.view_pooler_enabled: if sequence_name is None: raise ValueError("sequence_name must be provided for view pooling") - if not self.image_feature_extractor_enabled: - raise ValueError( - "image_feature_extractor has to be enabled for for view pooling" - + " (I.e. set self.image_feature_extractor_enabled=True)." - ) + assert img_feats is not None # (3-4) Sample features and masks at the ray points. # Aggregate features from multiple views. 
@@ -555,13 +563,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 **kwargs, ) - def _get_viewpooled_feature_dim(self): - return ( - self.view_pooler.get_aggregated_feature_dim( - self.image_feature_extractor.get_feat_dims() - ) - if self.view_pooler_enabled - else 0 + def _get_viewpooled_feature_dim(self) -> int: + if self.view_pooler is None: + return 0 + assert self.image_feature_extractor is not None + return self.view_pooler.get_aggregated_feature_dim( + self.image_feature_extractor.get_feat_dims() ) def create_raysampler(self): @@ -617,11 +624,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 that image_feature_extractor is enabled when view_pooler is enabled. """ if self.view_pooler_enabled: - if not self.image_feature_extractor_enabled: - raise ValueError( - "image_feature_extractor has to be enabled for view pooling" - + " (I.e. set self.image_feature_extractor_enabled=True)." - ) self.view_pooler = ViewPooler(**self.view_pooler_args) else: self.view_pooler = None diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml index 5d65d661..eda265d0 100644 --- a/tests/implicitron/data/overrides.yaml +++ b/tests/implicitron/data/overrides.yaml @@ -17,7 +17,7 @@ sampling_mode_training: mask_sample sampling_mode_evaluation: full_grid raysampler_class_type: AdaptiveRaySampler renderer_class_type: LSTMRenderer -image_feature_extractor_enabled: true +image_feature_extractor_class_type: ResNetFeatureExtractor view_pooler_enabled: true implicit_function_class_type: IdrFeatureField loss_weights: @@ -73,7 +73,7 @@ renderer_LSTMRenderer_args: hidden_size: 16 n_feature_channels: 256 verbose: false -image_feature_extractor_args: +image_feature_extractor_ResNetFeatureExtractor_args: name: resnet34 pretrained: true stages: diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py index 22cfa353..00d5c6fd 100644 --- 
a/tests/implicitron/test_config_use.py +++ b/tests/implicitron/test_config_use.py @@ -9,6 +9,9 @@ import unittest from omegaconf import OmegaConf from pytorch3d.implicitron.models.autodecoder import Autodecoder +from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( + ResNetFeatureExtractor, +) from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( IdrFeatureField, @@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase): provide_resnet34() args = get_default_args(GenericModel) args.view_pooler_enabled = True - args.image_feature_extractor_enabled = True args.view_pooler_args.feature_aggregator_class_type = ( "AngleWeightedIdentityFeatureAggregator" ) @@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase): ) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) + self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor) self.assertFalse(hasattr(gm, "implicit_function")) instance_args = OmegaConf.structured(gm) + if DEBUG: + full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) + (DATA_DIR / "overrides_full.yaml").write_text(full_yaml) remove_unused_components(instance_args) yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) if DEBUG: diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index b2028dac..6755f04a 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase): # Simple test of a forward and backward pass of the default GenericModel. 
device = torch.device("cuda:1") expand_args_fields(GenericModel) - model = GenericModel() + model = GenericModel(render_image_height=80, render_image_width=80) model.to(device) self._one_model_test(model, device) @@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase): ) self.assertGreater(train_preds["objective"].item(), 0) + def test_viewpool(self): + device = torch.device("cuda:1") + args = get_default_args(GenericModel) + args.view_pooler_enabled = True + args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False + model = GenericModel(**args) + model.to(device) + + n_train_cameras = 2 + R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360) + cameras = PerspectiveCameras(R=R, T=T, device=device) + + defaulted_args = { + "fg_probability": None, + "depth_map": None, + "mask_crop": None, + } + + target_image_rgb = torch.rand( + (n_train_cameras, 3, model.render_image_height, model.render_image_width), + device=device, + ) + train_preds = model( + camera=cameras, + evaluation_mode=EvaluationMode.TRAINING, + image_rgb=target_image_rgb, + sequence_name=["a"] * n_train_cameras, + **defaulted_args, + ) + self.assertGreater(train_preds["objective"].item(), 0) + def _random_input_tensor( N: int,