Make feature extractor pluggable

Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase.

Reviewed By: davnov134

Differential Revision: D35433098

fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a
This commit is contained in:
Jeremy Reizenstein 2022-05-18 08:50:18 -07:00 committed by Facebook GitHub Bot
parent cd7b885169
commit 9ec9d057cc
12 changed files with 151 additions and 57 deletions

View File

@ -224,7 +224,8 @@ generic_model_args: GenericModel
└-- hypernet_args: SRNRaymarchHyperNet └-- hypernet_args: SRNRaymarchHyperNet
└-- pixel_generator_args: SRNPixelGenerator └-- pixel_generator_args: SRNPixelGenerator
╘== IdrFeatureField ╘== IdrFeatureField
└-- image_feature_extractor_args: ResNetFeatureExtractor └-- image_feature_extractor_*_args: FeatureExtractorBase
╘== ResNetFeatureExtractor
└-- view_sampler_args: ViewSampler └-- view_sampler_args: ViewSampler
└-- feature_aggregator_*_args: FeatureAggregatorBase └-- feature_aggregator_*_args: FeatureAggregatorBase
╘== IdentityFeatureAggregator ╘== IdentityFeatureAggregator

View File

@ -64,7 +64,7 @@ generic_model_args:
view_pooler_args: view_pooler_args:
view_sampler_args: view_sampler_args:
masked_sampling: false masked_sampling: false
image_feature_extractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
stages: stages:
- 1 - 1
- 2 - 2

View File

@ -1,6 +1,6 @@
generic_model_args: generic_model_args:
image_feature_extractor_enabled: true image_feature_extractor_class_type: ResNetFeatureExtractor
image_feature_extractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
add_images: true add_images: true
add_masks: true add_masks: true
first_max_pool: true first_max_pool: true

View File

@ -1,6 +1,6 @@
generic_model_args: generic_model_args:
image_feature_extractor_enabled: true image_feature_extractor_class_type: ResNetFeatureExtractor
image_feature_extractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
add_images: true add_images: true
add_masks: true add_masks: true
first_max_pool: false first_max_pool: false

View File

@ -1,6 +1,6 @@
generic_model_args: generic_model_args:
image_feature_extractor_enabled: true image_feature_extractor_class_type: ResNetFeatureExtractor
image_feature_extractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
stages: stages:
- 1 - 1
- 2 - 2

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .feature_extractor import FeatureExtractorBase

View File

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
    """
    Abstract interface for modules that turn a batch of images into a set
    of named feature tensors.

    Nothing is implemented here: both `get_feat_dims` and `forward` raise
    `NotImplementedError` and have to be overridden by a concrete extractor.
    """

    def __init__(self):
        # Only chains up to the parent initialisers; the base class keeps
        # no state of its own.
        super().__init__()

    def get_feat_dims(self) -> int:
        """
        Report the combined size of all features produced by `forward`.

        Returns:
            The total number of output feature dimensions, i.e. the sum of
            `dim_i` over every returned tensor.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def forward(
        self,
        imgs: Optional[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[Any, torch.Tensor]:
        """
        Extract features from the given images.

        Args:
            imgs: Batch of input images with shape `(B, 3, H, W)`; typed as
                optional, so `None` may be passed.
            masks: Optional batch of input masks, documented as shape
                `(B, 3, H, W)`.
                NOTE(review): confirm the expected channel count for masks.
            **kwargs: Additional keyword arguments; the base class does not
                interpret them.

        Returns:
            out_feats: A dict `{f_i: t_i}` mapping each predicted feature
                name `f_i` to its tensor `t_i` of shape
                `(B, dim_i, H_i, W_i)`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy
import logging import logging
import math import math
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
import torch import torch
import torch.nn.functional as Fu import torch.nn.functional as Fu
import torchvision import torchvision
from pytorch3d.implicitron.tools.config import Configurable from pytorch3d.implicitron.tools.config import registry
from . import FeatureExtractorBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225] _RESNET_STD = [0.229, 0.224, 0.225]
class ResNetFeatureExtractor(Configurable, torch.nn.Module): @registry.register
class ResNetFeatureExtractor(FeatureExtractorBase):
""" """
Implements an image feature extractor. Depending on the settings allows Implements an image feature extractor. Depending on the settings allows
to extract: to extract:
@ -142,14 +144,14 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
return (img - self._resnet_mean) / self._resnet_std return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self) -> int: def get_feat_dims(self) -> int:
return ( # pyre-fixme[29]
sum(self._feat_dim.values()) # pyre-fixme[29] return sum(self._feat_dim.values())
if len(self._feat_dim) > 0 # pyre-fixme[6]
else 0
)
def forward( def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None self,
imgs: Optional[torch.Tensor],
masks: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[Any, torch.Tensor]: ) -> Dict[Any, torch.Tensor]:
""" """
Args: Args:
@ -164,7 +166,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
out_feats = {} out_feats = {}
imgs_input = imgs imgs_input = imgs
if self.image_rescale != 1.0: if self.image_rescale != 1.0 and imgs_input is not None:
imgs_resized = Fu.interpolate( imgs_resized = Fu.interpolate(
imgs_input, imgs_input,
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
@ -175,12 +177,13 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
else: else:
imgs_resized = imgs_input imgs_resized = imgs_input
if self.normalize_image:
imgs_normed = self._resnet_normalize_image(imgs_resized)
else:
imgs_normed = imgs_resized
if len(self.stages) > 0: if len(self.stages) > 0:
assert imgs_resized is not None
if self.normalize_image:
imgs_normed = self._resnet_normalize_image(imgs_resized)
else:
imgs_normed = imgs_resized
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]` # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
# is not a function. # is not a function.
feats = self.stem(imgs_normed) feats = self.stem(imgs_normed)
@ -207,7 +210,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
out_feats[MASK_FEATURE_NAME] = masks out_feats[MASK_FEATURE_NAME] = masks
if self.add_images: if self.add_images:
assert imgs_input is not None assert imgs_resized is not None
out_feats[IMAGE_FEATURE_NAME] = imgs_resized out_feats[IMAGE_FEATURE_NAME] = imgs_resized
if self.feature_rescale != 1.0: if self.feature_rescale != 1.0:

View File

@ -5,6 +5,9 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.
import logging import logging
import math import math
import warnings import warnings
@ -27,6 +30,8 @@ from visdom import Visdom
from .autodecoder import Autodecoder from .autodecoder import Autodecoder
from .base_model import ImplicitronModelBase, ImplicitronRender from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
from .implicit_function.base import ImplicitFunctionBase from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa from .implicit_function.neural_radiance_field import ( # noqa
@ -49,7 +54,6 @@ from .renderer.lstm_renderer import LSTMRenderer # noqa
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySamplerBase from .renderer.ray_sampler import RaySamplerBase
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooler.view_pooler import ViewPooler from .view_pooler.view_pooler import ViewPooler
@ -139,9 +143,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
splatting onto an image grid. Default: False. splatting onto an image grid. Default: False.
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0) bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
view_pool: If True, features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
num_passes: The specified implicit_function is initialized num_passes num_passes: The specified implicit_function is initialized num_passes
times and run sequentially. times and run sequentially.
chunk_size_grid: The total number of points which can be rendered chunk_size_grid: The total number of points which can be rendered
@ -169,10 +170,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
registry. registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s). generate the images from the target view(s).
image_feature_extractor_enabled: If `True`, constructs and enables image_feature_extractor_class_type: If a str, constructs and enables
the `image_feature_extractor` object. the `image_feature_extractor` object of this type. Or None if not needed.
image_feature_extractor: A module for extracting features from an input image. image_feature_extractor: A module for extracting features from an input image.
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object. view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
This means features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
view_pooler: An instance of ViewPooler which is used for sampling of view_pooler: An instance of ViewPooler which is used for sampling of
image-based features at the 2D projections of a set image-based features at the 2D projections of a set
of 3D points and aggregating the sampled features. of 3D points and aggregating the sampled features.
@ -215,8 +219,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
renderer: BaseRenderer renderer: BaseRenderer
# ---- image feature extractor settings # ---- image feature extractor settings
image_feature_extractor_enabled: bool = False # (This is only created if view_pooler is enabled)
image_feature_extractor: Optional[ResNetFeatureExtractor] image_feature_extractor: Optional[FeatureExtractorBase]
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
# ---- view pooler settings # ---- view pooler settings
view_pooler_enabled: bool = False view_pooler_enabled: bool = False
view_pooler: Optional[ViewPooler] view_pooler: Optional[ViewPooler]
@ -266,6 +271,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
super().__init__() super().__init__()
self.view_metrics = ViewMetrics() self.view_metrics = ViewMetrics()
if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None:
raise ValueError(
"image_feature_extractor must be present for view pooling."
)
else:
self.image_feature_extractor_class_type = None
run_auto_creation(self) run_auto_creation(self)
self._implicit_functions = self._construct_implicit_functions() self._implicit_functions = self._construct_implicit_functions()
@ -349,20 +361,16 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# custom_args hold additional arguments to the implicit function. # custom_args hold additional arguments to the implicit function.
custom_args = {} custom_args = {}
if self.image_feature_extractor_enabled: if self.image_feature_extractor is not None:
# (2) Extract features for the image # (2) Extract features for the image
img_feats = self.image_feature_extractor( # pyre-fixme[29] img_feats = self.image_feature_extractor(image_rgb, fg_probability)
image_rgb, fg_probability else:
) img_feats = None
if self.view_pooler_enabled: if self.view_pooler_enabled:
if sequence_name is None: if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling") raise ValueError("sequence_name must be provided for view pooling")
if not self.image_feature_extractor_enabled: assert img_feats is not None
raise ValueError(
"image_feature_extractor has to be enabled for for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
# (3-4) Sample features and masks at the ray points. # (3-4) Sample features and masks at the ray points.
# Aggregate features from multiple views. # Aggregate features from multiple views.
@ -555,13 +563,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
**kwargs, **kwargs,
) )
def _get_viewpooled_feature_dim(self): def _get_viewpooled_feature_dim(self) -> int:
return ( if self.view_pooler is None:
self.view_pooler.get_aggregated_feature_dim( return 0
self.image_feature_extractor.get_feat_dims() assert self.image_feature_extractor is not None
) return self.view_pooler.get_aggregated_feature_dim(
if self.view_pooler_enabled self.image_feature_extractor.get_feat_dims()
else 0
) )
def create_raysampler(self): def create_raysampler(self):
@ -617,11 +624,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
that image_feature_extractor is enabled when view_pooler is enabled. that image_feature_extractor is enabled when view_pooler is enabled.
""" """
if self.view_pooler_enabled: if self.view_pooler_enabled:
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
self.view_pooler = ViewPooler(**self.view_pooler_args) self.view_pooler = ViewPooler(**self.view_pooler_args)
else: else:
self.view_pooler = None self.view_pooler = None

View File

@ -17,7 +17,7 @@ sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid sampling_mode_evaluation: full_grid
raysampler_class_type: AdaptiveRaySampler raysampler_class_type: AdaptiveRaySampler
renderer_class_type: LSTMRenderer renderer_class_type: LSTMRenderer
image_feature_extractor_enabled: true image_feature_extractor_class_type: ResNetFeatureExtractor
view_pooler_enabled: true view_pooler_enabled: true
implicit_function_class_type: IdrFeatureField implicit_function_class_type: IdrFeatureField
loss_weights: loss_weights:
@ -73,7 +73,7 @@ renderer_LSTMRenderer_args:
hidden_size: 16 hidden_size: 16
n_feature_channels: 256 n_feature_channels: 256
verbose: false verbose: false
image_feature_extractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
name: resnet34 name: resnet34
pretrained: true pretrained: true
stages: stages:

View File

@ -9,6 +9,9 @@ import unittest
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField, IdrFeatureField,
@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase):
provide_resnet34() provide_resnet34()
args = get_default_args(GenericModel) args = get_default_args(GenericModel)
args.view_pooler_enabled = True args.view_pooler_enabled = True
args.image_feature_extractor_enabled = True
args.view_pooler_args.feature_aggregator_class_type = ( args.view_pooler_args.feature_aggregator_class_type = (
"AngleWeightedIdentityFeatureAggregator" "AngleWeightedIdentityFeatureAggregator"
) )
@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase):
) )
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function")) self.assertFalse(hasattr(gm, "implicit_function"))
instance_args = OmegaConf.structured(gm) instance_args = OmegaConf.structured(gm)
if DEBUG:
full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
(DATA_DIR / "overrides_full.yaml").write_text(full_yaml)
remove_unused_components(instance_args) remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG: if DEBUG:

View File

@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
# Simple test of a forward and backward pass of the default GenericModel. # Simple test of a forward and backward pass of the default GenericModel.
device = torch.device("cuda:1") device = torch.device("cuda:1")
expand_args_fields(GenericModel) expand_args_fields(GenericModel)
model = GenericModel() model = GenericModel(render_image_height=80, render_image_width=80)
model.to(device) model.to(device)
self._one_model_test(model, device) self._one_model_test(model, device)
@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
) )
self.assertGreater(train_preds["objective"].item(), 0) self.assertGreater(train_preds["objective"].item(), 0)
def test_viewpool(self):
    """
    Run one training-mode forward pass of a GenericModel that has view
    pooling switched on, and check that the reported objective is positive.
    """
    device = torch.device("cuda:1")

    # Start from the default configuration, enable the view pooler and turn
    # off mask features in the ResNet feature extractor.
    cfg = get_default_args(GenericModel)
    cfg.view_pooler_enabled = True
    cfg.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
    model = GenericModel(**cfg)
    model.to(device)

    # Two cameras placed at random azimuth angles.
    n_train_cameras = 2
    azimuths = torch.rand(n_train_cameras) * 360
    R, T = look_at_view_transform(azim=azimuths)
    cameras = PerspectiveCameras(R=R, T=T, device=device)

    # Random RGB targets sized to the model's render resolution.
    image_shape = (
        n_train_cameras,
        3,
        model.render_image_height,
        model.render_image_width,
    )
    target_image_rgb = torch.rand(image_shape, device=device)

    # The optional inputs are passed explicitly as None.
    train_preds = model(
        camera=cameras,
        evaluation_mode=EvaluationMode.TRAINING,
        image_rgb=target_image_rgb,
        sequence_name=["a"] * n_train_cameras,
        fg_probability=None,
        depth_map=None,
        mask_crop=None,
    )
    self.assertGreater(train_preds["objective"].item(), 0)
def _random_input_tensor( def _random_input_tensor(
N: int, N: int,