mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Make feature extractor pluggable
Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase. Reviewed By: davnov134 Differential Revision: D35433098 fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a
This commit is contained in:
parent
cd7b885169
commit
9ec9d057cc
@ -224,7 +224,8 @@ generic_model_args: GenericModel
|
||||
└-- hypernet_args: SRNRaymarchHyperNet
|
||||
└-- pixel_generator_args: SRNPixelGenerator
|
||||
╘== IdrFeatureField
|
||||
└-- image_feature_extractor_args: ResNetFeatureExtractor
|
||||
└-- image_feature_extractor_*_args: FeatureExtractorBase
|
||||
╘== ResNetFeatureExtractor
|
||||
└-- view_sampler_args: ViewSampler
|
||||
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
||||
╘== IdentityFeatureAggregator
|
||||
|
@ -64,7 +64,7 @@ generic_model_args:
|
||||
view_pooler_args:
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
stages:
|
||||
- 1
|
||||
- 2
|
||||
|
@ -1,6 +1,6 @@
|
||||
generic_model_args:
|
||||
image_feature_extractor_enabled: true
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
add_images: true
|
||||
add_masks: true
|
||||
first_max_pool: true
|
||||
|
@ -1,6 +1,6 @@
|
||||
generic_model_args:
|
||||
image_feature_extractor_enabled: true
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
add_images: true
|
||||
add_masks: true
|
||||
first_max_pool: false
|
||||
|
@ -1,6 +1,6 @@
|
||||
generic_model_args:
|
||||
image_feature_extractor_enabled: true
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
stages:
|
||||
- 1
|
||||
- 2
|
||||
|
@ -0,0 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .feature_extractor import FeatureExtractorBase
|
@ -0,0 +1,44 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||
|
||||
|
||||
class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Base class for an extractor of a set of features from images.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_feat_dims(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
total number of feature dimensions of the output.
|
||||
(i.e. sum_i(dim_i))
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
imgs: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> Dict[Any, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
imgs: A batch of input images of shape `(B, 3, H, W)`.
|
||||
masks: A batch of input masks of shape `(B, 3, H, W)`.
|
||||
|
||||
Returns:
|
||||
out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
|
||||
and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
|
||||
"""
|
||||
raise NotImplementedError
|
@ -4,7 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as Fu
|
||||
import torchvision
|
||||
from pytorch3d.implicitron.tools.config import Configurable
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
|
||||
from . import FeatureExtractorBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
|
||||
_RESNET_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
@registry.register
|
||||
class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
"""
|
||||
Implements an image feature extractor. Depending on the settings allows
|
||||
to extract:
|
||||
@ -142,14 +144,14 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
return (img - self._resnet_mean) / self._resnet_std
|
||||
|
||||
def get_feat_dims(self) -> int:
|
||||
return (
|
||||
sum(self._feat_dim.values()) # pyre-fixme[29]
|
||||
if len(self._feat_dim) > 0 # pyre-fixme[6]
|
||||
else 0
|
||||
)
|
||||
# pyre-fixme[29]
|
||||
return sum(self._feat_dim.values())
|
||||
|
||||
def forward(
|
||||
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
|
||||
self,
|
||||
imgs: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Dict[Any, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -164,7 +166,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
out_feats = {}
|
||||
|
||||
imgs_input = imgs
|
||||
if self.image_rescale != 1.0:
|
||||
if self.image_rescale != 1.0 and imgs_input is not None:
|
||||
imgs_resized = Fu.interpolate(
|
||||
imgs_input,
|
||||
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
|
||||
@ -175,12 +177,13 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
else:
|
||||
imgs_resized = imgs_input
|
||||
|
||||
if self.normalize_image:
|
||||
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
||||
else:
|
||||
imgs_normed = imgs_resized
|
||||
|
||||
if len(self.stages) > 0:
|
||||
assert imgs_resized is not None
|
||||
|
||||
if self.normalize_image:
|
||||
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
||||
else:
|
||||
imgs_normed = imgs_resized
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
|
||||
# is not a function.
|
||||
feats = self.stem(imgs_normed)
|
||||
@ -207,7 +210,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
out_feats[MASK_FEATURE_NAME] = masks
|
||||
|
||||
if self.add_images:
|
||||
assert imgs_input is not None
|
||||
assert imgs_resized is not None
|
||||
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
|
||||
|
||||
if self.feature_rescale != 1.0:
|
@ -5,6 +5,9 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Note: The #noqa comments below are for unused imports of pluggable implementations
|
||||
# which are part of implicitron. They ensure that the registry is prepopulated.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
@ -27,6 +30,8 @@ from visdom import Visdom
|
||||
|
||||
from .autodecoder import Autodecoder
|
||||
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||
from .feature_extractor import FeatureExtractorBase
|
||||
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
|
||||
from .implicit_function.base import ImplicitFunctionBase
|
||||
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
||||
from .implicit_function.neural_radiance_field import ( # noqa
|
||||
@ -49,7 +54,6 @@ from .renderer.lstm_renderer import LSTMRenderer # noqa
|
||||
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
||||
from .renderer.ray_sampler import RaySamplerBase
|
||||
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
||||
from .resnet_feature_extractor import ResNetFeatureExtractor
|
||||
from .view_pooler.view_pooler import ViewPooler
|
||||
|
||||
|
||||
@ -139,9 +143,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
||||
splatting onto an image grid. Default: False.
|
||||
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
||||
view_pool: If True, features are sampled from the source image(s)
|
||||
at the projected 2d locations of the sampled 3d ray points from the target
|
||||
view(s), i.e. this activates step (3) above.
|
||||
num_passes: The specified implicit_function is initialized num_passes
|
||||
times and run sequentially.
|
||||
chunk_size_grid: The total number of points which can be rendered
|
||||
@ -169,10 +170,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
registry.
|
||||
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
||||
generate the images from the target view(s).
|
||||
image_feature_extractor_enabled: If `True`, constructs and enables
|
||||
the `image_feature_extractor` object.
|
||||
image_feature_extractor_class_type: If a str, constructs and enables
|
||||
the `image_feature_extractor` object of this type. Or None if not needed.
|
||||
image_feature_extractor: A module for extrating features from an input image.
|
||||
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
|
||||
This means features are sampled from the source image(s)
|
||||
at the projected 2d locations of the sampled 3d ray points from the target
|
||||
view(s), i.e. this activates step (3) above.
|
||||
view_pooler: An instance of ViewPooler which is used for sampling of
|
||||
image-based features at the 2D projections of a set
|
||||
of 3D points and aggregating the sampled features.
|
||||
@ -215,8 +219,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
renderer: BaseRenderer
|
||||
|
||||
# ---- image feature extractor settings
|
||||
image_feature_extractor_enabled: bool = False
|
||||
image_feature_extractor: Optional[ResNetFeatureExtractor]
|
||||
# (This is only created if view_pooler is enabled)
|
||||
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||
image_feature_extractor_class_type: Optional[str] = "ResNetFeatureExtractor"
|
||||
# ---- view pooler settings
|
||||
view_pooler_enabled: bool = False
|
||||
view_pooler: Optional[ViewPooler]
|
||||
@ -266,6 +271,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
super().__init__()
|
||||
self.view_metrics = ViewMetrics()
|
||||
|
||||
if self.view_pooler_enabled:
|
||||
if self.image_feature_extractor_class_type is None:
|
||||
raise ValueError(
|
||||
"image_feature_extractor must be present for view pooling."
|
||||
)
|
||||
else:
|
||||
self.image_feature_extractor_class_type = None
|
||||
run_auto_creation(self)
|
||||
|
||||
self._implicit_functions = self._construct_implicit_functions()
|
||||
@ -349,20 +361,16 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
# custom_args hold additional arguments to the implicit function.
|
||||
custom_args = {}
|
||||
|
||||
if self.image_feature_extractor_enabled:
|
||||
if self.image_feature_extractor is not None:
|
||||
# (2) Extract features for the image
|
||||
img_feats = self.image_feature_extractor( # pyre-fixme[29]
|
||||
image_rgb, fg_probability
|
||||
)
|
||||
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
|
||||
else:
|
||||
img_feats = None
|
||||
|
||||
if self.view_pooler_enabled:
|
||||
if sequence_name is None:
|
||||
raise ValueError("sequence_name must be provided for view pooling")
|
||||
if not self.image_feature_extractor_enabled:
|
||||
raise ValueError(
|
||||
"image_feature_extractor has to be enabled for for view pooling"
|
||||
+ " (I.e. set self.image_feature_extractor_enabled=True)."
|
||||
)
|
||||
assert img_feats is not None
|
||||
|
||||
# (3-4) Sample features and masks at the ray points.
|
||||
# Aggregate features from multiple views.
|
||||
@ -555,13 +563,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_viewpooled_feature_dim(self):
|
||||
return (
|
||||
self.view_pooler.get_aggregated_feature_dim(
|
||||
self.image_feature_extractor.get_feat_dims()
|
||||
)
|
||||
if self.view_pooler_enabled
|
||||
else 0
|
||||
def _get_viewpooled_feature_dim(self) -> int:
|
||||
if self.view_pooler is None:
|
||||
return 0
|
||||
assert self.image_feature_extractor is not None
|
||||
return self.view_pooler.get_aggregated_feature_dim(
|
||||
self.image_feature_extractor.get_feat_dims()
|
||||
)
|
||||
|
||||
def create_raysampler(self):
|
||||
@ -617,11 +624,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
that image_feature_extractor is enabled when view_pooler is enabled.
|
||||
"""
|
||||
if self.view_pooler_enabled:
|
||||
if not self.image_feature_extractor_enabled:
|
||||
raise ValueError(
|
||||
"image_feature_extractor has to be enabled for view pooling"
|
||||
+ " (I.e. set self.image_feature_extractor_enabled=True)."
|
||||
)
|
||||
self.view_pooler = ViewPooler(**self.view_pooler_args)
|
||||
else:
|
||||
self.view_pooler = None
|
||||
|
@ -17,7 +17,7 @@ sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
raysampler_class_type: AdaptiveRaySampler
|
||||
renderer_class_type: LSTMRenderer
|
||||
image_feature_extractor_enabled: true
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
view_pooler_enabled: true
|
||||
implicit_function_class_type: IdrFeatureField
|
||||
loss_weights:
|
||||
@ -73,7 +73,7 @@ renderer_LSTMRenderer_args:
|
||||
hidden_size: 16
|
||||
n_feature_channels: 256
|
||||
verbose: false
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
name: resnet34
|
||||
pretrained: true
|
||||
stages:
|
||||
|
@ -9,6 +9,9 @@ import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
||||
ResNetFeatureExtractor,
|
||||
)
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||
IdrFeatureField,
|
||||
@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase):
|
||||
provide_resnet34()
|
||||
args = get_default_args(GenericModel)
|
||||
args.view_pooler_enabled = True
|
||||
args.image_feature_extractor_enabled = True
|
||||
args.view_pooler_args.feature_aggregator_class_type = (
|
||||
"AngleWeightedIdentityFeatureAggregator"
|
||||
)
|
||||
@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
|
||||
instance_args = OmegaConf.structured(gm)
|
||||
if DEBUG:
|
||||
full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
(DATA_DIR / "overrides_full.yaml").write_text(full_yaml)
|
||||
remove_unused_components(instance_args)
|
||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
if DEBUG:
|
||||
|
@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
# Simple test of a forward and backward pass of the default GenericModel.
|
||||
device = torch.device("cuda:1")
|
||||
expand_args_fields(GenericModel)
|
||||
model = GenericModel()
|
||||
model = GenericModel(render_image_height=80, render_image_width=80)
|
||||
model.to(device)
|
||||
self._one_model_test(model, device)
|
||||
|
||||
@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
|
||||
def test_viewpool(self):
|
||||
device = torch.device("cuda:1")
|
||||
args = get_default_args(GenericModel)
|
||||
args.view_pooler_enabled = True
|
||||
args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
|
||||
model = GenericModel(**args)
|
||||
model.to(device)
|
||||
|
||||
n_train_cameras = 2
|
||||
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||
cameras = PerspectiveCameras(R=R, T=T, device=device)
|
||||
|
||||
defaulted_args = {
|
||||
"fg_probability": None,
|
||||
"depth_map": None,
|
||||
"mask_crop": None,
|
||||
}
|
||||
|
||||
target_image_rgb = torch.rand(
|
||||
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
|
||||
device=device,
|
||||
)
|
||||
train_preds = model(
|
||||
camera=cameras,
|
||||
evaluation_mode=EvaluationMode.TRAINING,
|
||||
image_rgb=target_image_rgb,
|
||||
sequence_name=["a"] * n_train_cameras,
|
||||
**defaulted_args,
|
||||
)
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
|
||||
|
||||
def _random_input_tensor(
|
||||
N: int,
|
||||
|
Loading…
x
Reference in New Issue
Block a user