Make feature extractor pluggable

Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase.

Reviewed By: davnov134

Differential Revision: D35433098

fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a
This commit is contained in:
Jeremy Reizenstein 2022-05-18 08:50:18 -07:00 committed by Facebook GitHub Bot
parent cd7b885169
commit 9ec9d057cc
12 changed files with 151 additions and 57 deletions

View File

@ -224,7 +224,8 @@ generic_model_args: GenericModel
└-- hypernet_args: SRNRaymarchHyperNet
└-- pixel_generator_args: SRNPixelGenerator
╘== IdrFeatureField
└-- image_feature_extractor_args: ResNetFeatureExtractor
└-- image_feature_extractor_*_args: FeatureExtractorBase
╘== ResNetFeatureExtractor
└-- view_sampler_args: ViewSampler
└-- feature_aggregator_*_args: FeatureAggregatorBase
╘== IdentityFeatureAggregator

View File

@ -64,7 +64,7 @@ generic_model_args:
view_pooler_args:
view_sampler_args:
masked_sampling: false
image_feature_extractor_args:
image_feature_extractor_ResNetFeatureExtractor_args:
stages:
- 1
- 2

View File

@ -1,6 +1,6 @@
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
image_feature_extractor_class_type: ResNetFeatureExtractor
image_feature_extractor_ResNetFeatureExtractor_args:
add_images: true
add_masks: true
first_max_pool: true

View File

@ -1,6 +1,6 @@
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
image_feature_extractor_class_type: ResNetFeatureExtractor
image_feature_extractor_ResNetFeatureExtractor_args:
add_images: true
add_masks: true
first_max_pool: false

View File

@ -1,6 +1,6 @@
generic_model_args:
image_feature_extractor_enabled: true
image_feature_extractor_args:
image_feature_extractor_class_type: ResNetFeatureExtractor
image_feature_extractor_ResNetFeatureExtractor_args:
stages:
- 1
- 2

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .feature_extractor import FeatureExtractorBase

View File

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
    """
    Base class for an extractor of a set of features from images.

    Concrete implementations are registered with the implicitron config
    registry (via ReplaceableBase) so they can be selected by class type
    in configuration files.
    """

    def __init__(self):
        # Initialize torch.nn.Module machinery (parameter/buffer tracking).
        super().__init__()

    def get_feat_dims(self) -> int:
        """
        Returns:
            total number of feature dimensions of the output.
            (i.e. sum_i(dim_i))
        """
        raise NotImplementedError

    def forward(
        self,
        imgs: Optional[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[Any, torch.Tensor]:
        """
        Extract features from a batch of images (and optional masks).

        Args:
            imgs: A batch of input images of shape `(B, 3, H, W)`.
            masks: A batch of input masks of shape `(B, 3, H, W)`.

        Returns:
            out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
                and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
        """
        raise NotImplementedError

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import math
from typing import Any, Dict, Optional, Tuple
@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as Fu
import torchvision
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.implicitron.tools.config import registry
from . import FeatureExtractorBase
logger = logging.getLogger(__name__)
@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
@registry.register
class ResNetFeatureExtractor(FeatureExtractorBase):
"""
Implements an image feature extractor. Depending on the settings allows
to extract:
@ -142,14 +144,14 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self) -> int:
return (
sum(self._feat_dim.values()) # pyre-fixme[29]
if len(self._feat_dim) > 0 # pyre-fixme[6]
else 0
)
# pyre-fixme[29]
return sum(self._feat_dim.values())
def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
self,
imgs: Optional[torch.Tensor],
masks: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[Any, torch.Tensor]:
"""
Args:
@ -164,7 +166,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
out_feats = {}
imgs_input = imgs
if self.image_rescale != 1.0:
if self.image_rescale != 1.0 and imgs_input is not None:
imgs_resized = Fu.interpolate(
imgs_input,
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
@ -175,12 +177,13 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
else:
imgs_resized = imgs_input
if self.normalize_image:
imgs_normed = self._resnet_normalize_image(imgs_resized)
else:
imgs_normed = imgs_resized
if len(self.stages) > 0:
assert imgs_resized is not None
if self.normalize_image:
imgs_normed = self._resnet_normalize_image(imgs_resized)
else:
imgs_normed = imgs_resized
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
# is not a function.
feats = self.stem(imgs_normed)
@ -207,7 +210,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
out_feats[MASK_FEATURE_NAME] = masks
if self.add_images:
assert imgs_input is not None
assert imgs_resized is not None
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
if self.feature_rescale != 1.0:

View File

@ -5,6 +5,9 @@
# LICENSE file in the root directory of this source tree.
# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.
import logging
import math
import warnings
@ -27,6 +30,8 @@ from visdom import Visdom
from .autodecoder import Autodecoder
from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa
@ -49,7 +54,6 @@ from .renderer.lstm_renderer import LSTMRenderer # noqa
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySamplerBase
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooler.view_pooler import ViewPooler
@ -139,9 +143,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
splatting onto an image grid. Default: False.
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
view_pool: If True, features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
num_passes: The specified implicit_function is initialized num_passes
times and run sequentially.
chunk_size_grid: The total number of points which can be rendered
@ -169,10 +170,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s).
image_feature_extractor_enabled: If `True`, constructs and enables
the `image_feature_extractor` object.
image_feature_extractor_class_type: If a str, constructs and enables
the `image_feature_extractor` object of this type. Or None if not needed.
image_feature_extractor: A module for extrating features from an input image.
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
This means features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
view_pooler: An instance of ViewPooler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points and aggregating the sampled features.
@ -215,8 +219,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
renderer: BaseRenderer
# ---- image feature extractor settings
image_feature_extractor_enabled: bool = False
image_feature_extractor: Optional[ResNetFeatureExtractor]
# (This is only created if view_pooler is enabled)
image_feature_extractor: Optional[FeatureExtractorBase]
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
# ---- view pooler settings
view_pooler_enabled: bool = False
view_pooler: Optional[ViewPooler]
@ -266,6 +271,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
super().__init__()
self.view_metrics = ViewMetrics()
if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None:
raise ValueError(
"image_feature_extractor must be present for view pooling."
)
else:
self.image_feature_extractor_class_type = None
run_auto_creation(self)
self._implicit_functions = self._construct_implicit_functions()
@ -349,20 +361,16 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
# custom_args hold additional arguments to the implicit function.
custom_args = {}
if self.image_feature_extractor_enabled:
if self.image_feature_extractor is not None:
# (2) Extract features for the image
img_feats = self.image_feature_extractor( # pyre-fixme[29]
image_rgb, fg_probability
)
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
else:
img_feats = None
if self.view_pooler_enabled:
if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling")
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
assert img_feats is not None
# (3-4) Sample features and masks at the ray points.
# Aggregate features from multiple views.
@ -555,13 +563,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
**kwargs,
)
def _get_viewpooled_feature_dim(self):
return (
self.view_pooler.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims()
)
if self.view_pooler_enabled
else 0
def _get_viewpooled_feature_dim(self) -> int:
if self.view_pooler is None:
return 0
assert self.image_feature_extractor is not None
return self.view_pooler.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims()
)
def create_raysampler(self):
@ -617,11 +624,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
that image_feature_extractor is enabled when view_pooler is enabled.
"""
if self.view_pooler_enabled:
if not self.image_feature_extractor_enabled:
raise ValueError(
"image_feature_extractor has to be enabled for view pooling"
+ " (I.e. set self.image_feature_extractor_enabled=True)."
)
self.view_pooler = ViewPooler(**self.view_pooler_args)
else:
self.view_pooler = None

View File

@ -17,7 +17,7 @@ sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: LSTMRenderer
image_feature_extractor_enabled: true
image_feature_extractor_class_type: ResNetFeatureExtractor
view_pooler_enabled: true
implicit_function_class_type: IdrFeatureField
loss_weights:
@ -73,7 +73,7 @@ renderer_LSTMRenderer_args:
hidden_size: 16
n_feature_channels: 256
verbose: false
image_feature_extractor_args:
image_feature_extractor_ResNetFeatureExtractor_args:
name: resnet34
pretrained: true
stages:

View File

@ -9,6 +9,9 @@ import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase):
provide_resnet34()
args = get_default_args(GenericModel)
args.view_pooler_enabled = True
args.image_feature_extractor_enabled = True
args.view_pooler_args.feature_aggregator_class_type = (
"AngleWeightedIdentityFeatureAggregator"
)
@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase):
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function"))
instance_args = OmegaConf.structured(gm)
if DEBUG:
full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
(DATA_DIR / "overrides_full.yaml").write_text(full_yaml)
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:

View File

@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
# Simple test of a forward and backward pass of the default GenericModel.
device = torch.device("cuda:1")
expand_args_fields(GenericModel)
model = GenericModel()
model = GenericModel(render_image_height=80, render_image_width=80)
model.to(device)
self._one_model_test(model, device)
@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
)
self.assertGreater(train_preds["objective"].item(), 0)
def test_viewpool(self):
device = torch.device("cuda:1")
args = get_default_args(GenericModel)
args.view_pooler_enabled = True
args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
model = GenericModel(**args)
model.to(device)
n_train_cameras = 2
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
cameras = PerspectiveCameras(R=R, T=T, device=device)
defaulted_args = {
"fg_probability": None,
"depth_map": None,
"mask_crop": None,
}
target_image_rgb = torch.rand(
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
device=device,
)
train_preds = model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
image_rgb=target_image_rgb,
sequence_name=["a"] * n_train_cameras,
**defaulted_args,
)
self.assertGreater(train_preds["objective"].item(), 0)
def _random_input_tensor(
N: int,