mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
ViewPooler class
Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator. Reviewed By: shapovalov Differential Revision: D35852367 fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
This commit is contained in:
parent
bef959c755
commit
47d06c8924
@ -63,6 +63,7 @@ generic_model_args:
|
|||||||
n_pts_per_ray_fine_evaluation: 64
|
n_pts_per_ray_fine_evaluation: 64
|
||||||
append_coarse_samples_to_fine: true
|
append_coarse_samples_to_fine: true
|
||||||
density_noise_std_train: 1.0
|
density_noise_std_train: 1.0
|
||||||
|
view_pooler_args:
|
||||||
view_sampler_args:
|
view_sampler_args:
|
||||||
masked_sampling: false
|
masked_sampling: false
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_args:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
|
image_feature_extractor_enabled: true
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
add_masks: true
|
add_masks: true
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
|
image_feature_extractor_enabled: true
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
add_masks: true
|
add_masks: true
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
|
image_feature_extractor_enabled: true
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_args:
|
||||||
stages:
|
stages:
|
||||||
- 1
|
- 1
|
||||||
@ -11,6 +12,7 @@ generic_model_args:
|
|||||||
name: resnet34
|
name: resnet34
|
||||||
normalize_image: true
|
normalize_image: true
|
||||||
pretrained: true
|
pretrained: true
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||||
reduction_functions:
|
reduction_functions:
|
||||||
- AVG
|
- AVG
|
||||||
|
@ -11,7 +11,6 @@ generic_model_args:
|
|||||||
num_passes: 1
|
num_passes: 1
|
||||||
output_rasterized_mc: true
|
output_rasterized_mc: true
|
||||||
sampling_mode_training: mask_sample
|
sampling_mode_training: mask_sample
|
||||||
view_pool: false
|
|
||||||
sequence_autodecoder_args:
|
sequence_autodecoder_args:
|
||||||
n_instances: 20000
|
n_instances: 20000
|
||||||
init_scale: 1.0
|
init_scale: 1.0
|
||||||
|
@ -3,7 +3,7 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
sequence_autodecoder_args:
|
sequence_autodecoder_args:
|
||||||
n_instances: 20000
|
n_instances: 20000
|
||||||
encoding_dim: 256
|
encoding_dim: 256
|
||||||
|
@ -5,6 +5,6 @@ defaults:
|
|||||||
clip_grad: 1.0
|
clip_grad: 1.0
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
raysampler_args:
|
raysampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 850
|
n_rays_per_image_sampled_from_mask: 850
|
||||||
|
@ -4,7 +4,6 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
|
||||||
raysampler_args:
|
raysampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
n_rays_per_image_sampled_from_mask: 800
|
||||||
n_pts_per_ray_training: 32
|
n_pts_per_ray_training: 32
|
||||||
@ -13,4 +12,6 @@ generic_model_args:
|
|||||||
n_pts_per_ray_fine_training: 16
|
n_pts_per_ray_fine_training: 16
|
||||||
n_pts_per_ray_fine_evaluation: 16
|
n_pts_per_ray_fine_evaluation: 16
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
implicit_function_class_type: NeRFormerImplicitFunction
|
||||||
|
view_pooler_enabled: true
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||||
|
@ -4,7 +4,6 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
|
||||||
raysampler_args:
|
raysampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
n_rays_per_image_sampled_from_mask: 800
|
||||||
n_pts_per_ray_training: 32
|
n_pts_per_ray_training: 32
|
||||||
@ -13,4 +12,6 @@ generic_model_args:
|
|||||||
n_pts_per_ray_fine_training: 16
|
n_pts_per_ray_fine_training: 16
|
||||||
n_pts_per_ray_fine_evaluation: 16
|
n_pts_per_ray_fine_evaluation: 16
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
implicit_function_class_type: NeRFormerImplicitFunction
|
||||||
|
view_pooler_enabled: true
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||||
|
@ -3,7 +3,7 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
n_train_target_views: -1
|
n_train_target_views: -1
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
loss_weights:
|
loss_weights:
|
||||||
|
@ -4,7 +4,6 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pool: true
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
n_train_target_views: -1
|
n_train_target_views: -1
|
||||||
loss_weights:
|
loss_weights:
|
||||||
@ -25,6 +24,7 @@ generic_model_args:
|
|||||||
stratified_point_sampling_evaluation: false
|
stratified_point_sampling_evaluation: false
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
implicit_function_class_type: SRNImplicitFunction
|
implicit_function_class_type: SRNImplicitFunction
|
||||||
|
view_pooler_enabled: true
|
||||||
solver_args:
|
solver_args:
|
||||||
breed: adam
|
breed: adam
|
||||||
lr: 5.0e-05
|
lr: 5.0e-05
|
||||||
|
@ -9,7 +9,7 @@ generic_model_args:
|
|||||||
loss_eikonal: 0.1
|
loss_eikonal: 0.1
|
||||||
chunk_size_grid: 65536
|
chunk_size_grid: 65536
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
implicit_function_IdrFeatureField_args:
|
implicit_function_IdrFeatureField_args:
|
||||||
n_harmonic_functions_xyz: 6
|
n_harmonic_functions_xyz: 6
|
||||||
bias: 0.6
|
bias: 0.6
|
||||||
|
@ -4,6 +4,6 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
raysampler_args:
|
raysampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 850
|
n_rays_per_image_sampled_from_mask: 850
|
||||||
|
@ -4,7 +4,7 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
implicit_function_class_type: NeRFormerImplicitFunction
|
||||||
raysampler_args:
|
raysampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
n_rays_per_image_sampled_from_mask: 800
|
||||||
@ -13,4 +13,5 @@ generic_model_args:
|
|||||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||||
n_pts_per_ray_fine_training: 16
|
n_pts_per_ray_fine_training: 16
|
||||||
n_pts_per_ray_fine_evaluation: 16
|
n_pts_per_ray_fine_evaluation: 16
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||||
|
@ -4,7 +4,7 @@ defaults:
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 200.0
|
loss_rgb_mse: 200.0
|
||||||
loss_prev_stage_rgb_mse: 0.0
|
loss_prev_stage_rgb_mse: 0.0
|
||||||
|
@ -5,7 +5,7 @@ defaults:
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 200.0
|
loss_rgb_mse: 200.0
|
||||||
loss_prev_stage_rgb_mse: 0.0
|
loss_prev_stage_rgb_mse: 0.0
|
||||||
|
@ -49,8 +49,7 @@ from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
|||||||
from .renderer.ray_sampler import RaySampler
|
from .renderer.ray_sampler import RaySampler
|
||||||
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
||||||
from .resnet_feature_extractor import ResNetFeatureExtractor
|
from .resnet_feature_extractor import ResNetFeatureExtractor
|
||||||
from .view_pooling.feature_aggregation import FeatureAggregatorBase
|
from .view_pooler.view_pooler import ViewPooler
|
||||||
from .view_pooling.view_sampling import ViewSampler
|
|
||||||
|
|
||||||
|
|
||||||
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
|
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
|
||||||
@ -167,16 +166,13 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
registry.
|
registry.
|
||||||
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
||||||
generate the images from the target view(s).
|
generate the images from the target view(s).
|
||||||
|
image_feature_extractor_enabled: If `True`, constructs and enables
|
||||||
|
the `image_feature_extractor` object.
|
||||||
image_feature_extractor: A module for extracting features from an input image.
|
image_feature_extractor: A module for extracting features from an input image.
|
||||||
view_sampler: An instance of ViewSampler which is used for sampling of
|
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
|
||||||
|
view_pooler: An instance of ViewPooler which is used for sampling of
|
||||||
image-based features at the 2D projections of a set
|
image-based features at the 2D projections of a set
|
||||||
of 3D points.
|
of 3D points and aggregating the sampled features.
|
||||||
feature_aggregator_class_type: The name of the feature aggregator class which
|
|
||||||
is available in the global registry.
|
|
||||||
feature_aggregator: A feature aggregator class which inherits from
|
|
||||||
FeatureAggregatorBase. Typically, the aggregated features and their
|
|
||||||
masks are output by a `ViewSampler` which samples feature tensors extracted
|
|
||||||
from a set of source images. FeatureAggregator executes step (4) above.
|
|
||||||
implicit_function_class_type: The type of implicit function to use which
|
implicit_function_class_type: The type of implicit function to use which
|
||||||
is available in the global registry.
|
is available in the global registry.
|
||||||
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
||||||
@ -195,7 +191,6 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
mask_threshold: float = 0.5
|
mask_threshold: float = 0.5
|
||||||
output_rasterized_mc: bool = False
|
output_rasterized_mc: bool = False
|
||||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
view_pool: bool = False
|
|
||||||
num_passes: int = 1
|
num_passes: int = 1
|
||||||
chunk_size_grid: int = 4096
|
chunk_size_grid: int = 4096
|
||||||
render_features_dimensions: int = 3
|
render_features_dimensions: int = 3
|
||||||
@ -215,13 +210,12 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
||||||
renderer: BaseRenderer
|
renderer: BaseRenderer
|
||||||
|
|
||||||
# ---- view sampling settings - used if view_pool=True
|
# ---- image feature extractor settings
|
||||||
# (This is only created if view_pool is False)
|
image_feature_extractor_enabled: bool = False
|
||||||
image_feature_extractor: ResNetFeatureExtractor
|
image_feature_extractor: Optional[ResNetFeatureExtractor]
|
||||||
view_sampler: ViewSampler
|
# ---- view pooler settings
|
||||||
# ---- ---- view sampling feature aggregator settings
|
view_pooler_enabled: bool = False
|
||||||
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
view_pooler: Optional[ViewPooler]
|
||||||
feature_aggregator: FeatureAggregatorBase
|
|
||||||
|
|
||||||
# ---- implicit function settings
|
# ---- implicit function settings
|
||||||
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
||||||
@ -356,32 +350,34 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
# custom_args hold additional arguments to the implicit function.
|
# custom_args hold additional arguments to the implicit function.
|
||||||
custom_args = {}
|
custom_args = {}
|
||||||
|
|
||||||
if self.view_pool:
|
if self.image_feature_extractor_enabled:
|
||||||
|
# (2) Extract features for the image
|
||||||
|
img_feats = self.image_feature_extractor( # pyre-fixme[29]
|
||||||
|
image_rgb, fg_probability
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.view_pooler_enabled:
|
||||||
if sequence_name is None:
|
if sequence_name is None:
|
||||||
raise ValueError("sequence_name must be provided for view pooling")
|
raise ValueError("sequence_name must be provided for view pooling")
|
||||||
# (2) Extract features for the image
|
if not self.image_feature_extractor_enabled:
|
||||||
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
|
raise ValueError(
|
||||||
|
"image_feature_extractor has to be enabled for for view pooling"
|
||||||
|
+ " (I.e. set self.image_feature_extractor_enabled=True)."
|
||||||
|
)
|
||||||
|
|
||||||
# (3) Sample features and masks at the ray points
|
# (3-4) Sample features and masks at the ray points.
|
||||||
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
|
# Aggregate features from multiple views.
|
||||||
|
def curried_viewpooler(pts):
|
||||||
|
return self.view_pooler(
|
||||||
pts=pts,
|
pts=pts,
|
||||||
seq_id_pts=sequence_name[:n_targets],
|
seq_id_pts=sequence_name[:n_targets],
|
||||||
camera=camera,
|
camera=camera,
|
||||||
seq_id_camera=sequence_name,
|
seq_id_camera=sequence_name,
|
||||||
feats=img_feats,
|
feats=img_feats,
|
||||||
masks=mask_crop,
|
masks=mask_crop,
|
||||||
) # returns feats_sampled, masks_sampled
|
)
|
||||||
|
|
||||||
# (4) Aggregate features from multiple views
|
custom_args["fun_viewpool"] = curried_viewpooler
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
|
||||||
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
|
|
||||||
*curried_view_sampler(pts=pts),
|
|
||||||
pts=pts,
|
|
||||||
camera=camera,
|
|
||||||
) # TODO: do we need to pass a callback rather than compute here?
|
|
||||||
# precomputing will be faster for 2 passes
|
|
||||||
# -> but this is important for non-nerf
|
|
||||||
custom_args["fun_viewpool"] = curried_view_pool
|
|
||||||
|
|
||||||
global_code = None
|
global_code = None
|
||||||
if self.sequence_autodecoder.n_instances > 0:
|
if self.sequence_autodecoder.n_instances > 0:
|
||||||
@ -562,10 +558,10 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
|
|
||||||
def _get_viewpooled_feature_dim(self):
|
def _get_viewpooled_feature_dim(self):
|
||||||
return (
|
return (
|
||||||
self.feature_aggregator.get_aggregated_feature_dim(
|
self.view_pooler.get_aggregated_feature_dim(
|
||||||
self.image_feature_extractor.get_feat_dims()
|
self.image_feature_extractor.get_feat_dims()
|
||||||
)
|
)
|
||||||
if self.view_pool
|
if self.view_pooler_enabled
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -583,15 +579,20 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
"object_bounding_sphere"
|
"object_bounding_sphere"
|
||||||
] = self.raysampler_args["scene_extent"]
|
] = self.raysampler_args["scene_extent"]
|
||||||
|
|
||||||
def create_image_feature_extractor(self):
|
def create_view_pooler(self):
|
||||||
"""
|
"""
|
||||||
Custom creation function called by run_auto_creation so that the
|
Custom creation function called by run_auto_creation checking
|
||||||
image_feature_extractor is not created if it is not be needed.
|
that image_feature_extractor is enabled when view_pooler is enabled.
|
||||||
"""
|
"""
|
||||||
if self.view_pool:
|
if self.view_pooler_enabled:
|
||||||
self.image_feature_extractor = ResNetFeatureExtractor(
|
if not self.image_feature_extractor_enabled:
|
||||||
**self.image_feature_extractor_args
|
raise ValueError(
|
||||||
|
"image_feature_extractor has to be enabled for view pooling"
|
||||||
|
+ " (I.e. set self.image_feature_extractor_enabled=True)."
|
||||||
)
|
)
|
||||||
|
self.view_pooler = ViewPooler(**self.view_pooler_args)
|
||||||
|
else:
|
||||||
|
self.view_pooler = None
|
||||||
|
|
||||||
def create_implicit_function(self) -> None:
|
def create_implicit_function(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -652,10 +653,9 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
|||||||
)
|
)
|
||||||
|
|
||||||
if implicit_function_type.requires_pooling_without_aggregation():
|
if implicit_function_type.requires_pooling_without_aggregation():
|
||||||
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
|
if self.view_pooler_enabled and self.view_pooler.has_aggregation():
|
||||||
if not self.view_pool or has_aggregation:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Chosen implicit function requires view pooling without aggregation."
|
"The chosen implicit function requires view pooling without aggregation."
|
||||||
)
|
)
|
||||||
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
|
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
|
||||||
config = getattr(self, config_name, None)
|
config = getattr(self, config_name, None)
|
||||||
|
@ -141,11 +141,12 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||||
return (img - self._resnet_mean) / self._resnet_std
|
return (img - self._resnet_mean) / self._resnet_std
|
||||||
|
|
||||||
def get_feat_dims(self, size_dict: bool = False):
|
def get_feat_dims(self) -> int:
|
||||||
if size_dict:
|
return (
|
||||||
return copy.deepcopy(self._feat_dim)
|
sum(self._feat_dim.values()) # pyre-fixme[29]
|
||||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na...
|
if len(self._feat_dim) > 0 # pyre-fixme[6]
|
||||||
return sum(self._feat_dim.values())
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
|
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
|
||||||
|
@ -10,7 +10,7 @@ from typing import Dict, Optional, Sequence, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
|
from pytorch3d.implicitron.models.view_pooler.view_sampler import (
|
||||||
cameras_points_cartesian_product,
|
cameras_points_cartesian_product,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_aggregated_feature_dim(
|
||||||
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns the final dimensionality of the output aggregated features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}` corresponding
|
||||||
|
to the `feats_sampled` argument of `forward`,
|
||||||
|
or an `int` representing the sum of dimensionalities of each `t_i`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
aggregated_feature_dim: The final dimensionality of the output
|
||||||
|
aggregated features.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def has_aggregation(self) -> bool:
|
||||||
|
"""
|
||||||
|
Specifies whether the aggregator reduces the output `reduce_dim` dimension to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
has_aggregation: `True` if `reduce_dim==1`, else `False`.
|
||||||
|
"""
|
||||||
|
return hasattr(self, "reduction_functions")
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
||||||
@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, [])
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(
|
||||||
|
feats_or_feats_dim, self.reduction_functions
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(
|
||||||
|
feats_or_feats_dim, self.reduction_functions
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, [])
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
128
pytorch3d/implicitron/models/view_pooler/view_pooler.py
Normal file
128
pytorch3d/implicitron/models/view_pooler/view_pooler.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import Configurable, run_auto_creation
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .feature_aggregator import FeatureAggregatorBase
|
||||||
|
from .view_sampler import ViewSampler
|
||||||
|
|
||||||
|
|
||||||
|
# pyre-ignore: 13
|
||||||
|
class ViewPooler(Configurable, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implements sampling of image-based features at the 2d projections of a set
|
||||||
|
of 3D points, and a subsequent aggregation of the resulting set of features
|
||||||
|
per-point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
view_sampler: An instance of ViewSampler which is used for sampling of
|
||||||
|
image-based features at the 2D projections of a set
|
||||||
|
of 3D points.
|
||||||
|
feature_aggregator_class_type: The name of the feature aggregator class which
|
||||||
|
is available in the global registry.
|
||||||
|
feature_aggregator: A feature aggregator class which inherits from
|
||||||
|
FeatureAggregatorBase. Typically, the aggregated features and their
|
||||||
|
masks are output by a `ViewSampler` which samples feature tensors extracted
|
||||||
|
from a set of source images. FeatureAggregator executes step (4) above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
view_sampler: ViewSampler
|
||||||
|
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
||||||
|
feature_aggregator: FeatureAggregatorBase
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||||
|
"""
|
||||||
|
Returns the final dimensionality of the output aggregated features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feats: Either a `dict` of sampled features `{f_i: t_i}` corresponding
|
||||||
|
to the `feats_sampled` argument of `feature_aggregator.forward`,
|
||||||
|
or an `int` representing the sum of dimensionalities of each `t_i`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
aggregated_feature_dim: The final dimensionality of the output
|
||||||
|
aggregated features.
|
||||||
|
"""
|
||||||
|
return self.feature_aggregator.get_aggregated_feature_dim(feats)
|
||||||
|
|
||||||
|
def has_aggregation(self):
|
||||||
|
"""
|
||||||
|
Specifies whether the `feature_aggregator` reduces the output `reduce_dim`
|
||||||
|
dimension to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
has_aggregation: `True` if `reduce_dim==1`, else `False`.
|
||||||
|
"""
|
||||||
|
return self.feature_aggregator.has_aggregation()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
*, # force kw args
|
||||||
|
pts: torch.Tensor,
|
||||||
|
seq_id_pts: Union[List[int], List[str], torch.LongTensor],
|
||||||
|
camera: CamerasBase,
|
||||||
|
seq_id_camera: Union[List[int], List[str], torch.LongTensor],
|
||||||
|
feats: Dict[str, torch.Tensor],
|
||||||
|
masks: Optional[torch.Tensor],
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Project each point cloud from a batch of point clouds to corresponding
|
||||||
|
input cameras, sample features at the 2D projection locations in a batch
|
||||||
|
of source images, and aggregate the pointwise sampled features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
|
||||||
|
seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
|
||||||
|
from which `pts` were extracted, or a list of string names.
|
||||||
|
camera: 'n_cameras' cameras, each corresponding to a batch element of `feats`.
|
||||||
|
seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
|
||||||
|
corresponding to cameras in `camera`, or a list of string names.
|
||||||
|
feats: a dict of tensors of per-image features `{feat_i: T_i}`.
|
||||||
|
Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
|
||||||
|
masks: `[n_cameras x 1 x H x W]`, define valid image regions
|
||||||
|
for sampling `feats`.
|
||||||
|
Returns:
|
||||||
|
feats_aggregated: If `feature_aggregator.concatenate_output==True`, a tensor
|
||||||
|
of shape `(pts_batch, reduce_dim, n_pts, sum(dim_1, ... dim_N))`
|
||||||
|
containing the aggregated features. `reduce_dim` depends on
|
||||||
|
the specific feature aggregator implementation and typically
|
||||||
|
equals 1 or `n_cameras`.
|
||||||
|
If `feature_aggregator.concatenate_output==False`, the aggregator
|
||||||
|
does not concatenate the aggregated features and returns a dictionary
|
||||||
|
of per-feature aggregations `{f_i: t_i_aggregated}` instead.
|
||||||
|
Each `t_i_aggregated` is of shape
|
||||||
|
`(pts_batch, reduce_dim, n_pts, aggr_dim_i)`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# (1) Sample features and masks at the ray points
|
||||||
|
sampled_feats, sampled_masks = self.view_sampler(
|
||||||
|
pts=pts,
|
||||||
|
seq_id_pts=seq_id_pts,
|
||||||
|
camera=camera,
|
||||||
|
seq_id_camera=seq_id_camera,
|
||||||
|
feats=feats,
|
||||||
|
masks=masks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# (2) Aggregate features from multiple views
|
||||||
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
|
feats_aggregated = self.feature_aggregator( # noqa: E731
|
||||||
|
sampled_feats,
|
||||||
|
sampled_masks,
|
||||||
|
pts=pts,
|
||||||
|
camera=camera,
|
||||||
|
) # TODO: do we need to pass a callback rather than compute here?
|
||||||
|
|
||||||
|
return feats_aggregated
|
@ -8,7 +8,6 @@ bg_color:
|
|||||||
- 0.0
|
- 0.0
|
||||||
- 0.0
|
- 0.0
|
||||||
- 0.0
|
- 0.0
|
||||||
view_pool: false
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 4096
|
chunk_size_grid: 4096
|
||||||
render_features_dimensions: 3
|
render_features_dimensions: 3
|
||||||
@ -17,7 +16,8 @@ n_train_target_views: 1
|
|||||||
sampling_mode_training: mask_sample
|
sampling_mode_training: mask_sample
|
||||||
sampling_mode_evaluation: full_grid
|
sampling_mode_evaluation: full_grid
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
image_feature_extractor_enabled: true
|
||||||
|
view_pooler_enabled: true
|
||||||
implicit_function_class_type: IdrFeatureField
|
implicit_function_class_type: IdrFeatureField
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 1.0
|
loss_rgb_mse: 1.0
|
||||||
@ -91,10 +91,12 @@ image_feature_extractor_args:
|
|||||||
add_images: true
|
add_images: true
|
||||||
global_average_pool: false
|
global_average_pool: false
|
||||||
feature_rescale: 1.0
|
feature_rescale: 1.0
|
||||||
view_sampler_args:
|
view_pooler_args:
|
||||||
|
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||||
|
view_sampler_args:
|
||||||
masked_sampling: false
|
masked_sampling: false
|
||||||
sampling_mode: bilinear
|
sampling_mode: bilinear
|
||||||
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||||
exclude_target_view: true
|
exclude_target_view: true
|
||||||
exclude_target_view_mask_features: true
|
exclude_target_view_mask_features: true
|
||||||
concatenate_output: true
|
concatenate_output: true
|
||||||
|
@ -20,9 +20,8 @@ from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
|
|||||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||||
MultiPassEmissionAbsorptionRenderer,
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
|
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import (
|
||||||
AngleWeightedIdentityFeatureAggregator,
|
AngleWeightedIdentityFeatureAggregator,
|
||||||
AngleWeightedReductionFeatureAggregator,
|
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
get_default_args,
|
get_default_args,
|
||||||
@ -32,7 +31,10 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
|
|
||||||
if os.environ.get("FB_TEST", False):
|
if os.environ.get("FB_TEST", False):
|
||||||
from common_testing import get_tests_dir
|
from common_testing import get_tests_dir
|
||||||
|
|
||||||
|
from .common_resources import provide_lpips_vgg
|
||||||
else:
|
else:
|
||||||
|
from common_resources import provide_lpips_vgg # noqa
|
||||||
from tests.common_testing import get_tests_dir
|
from tests.common_testing import get_tests_dir
|
||||||
|
|
||||||
DATA_DIR = get_tests_dir() / "implicitron/data"
|
DATA_DIR = get_tests_dir() / "implicitron/data"
|
||||||
@ -46,28 +48,33 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
|
||||||
def test_create_gm(self):
|
def test_create_gm(self):
|
||||||
|
provide_lpips_vgg()
|
||||||
args = get_default_args(GenericModel)
|
args = get_default_args(GenericModel)
|
||||||
gm = GenericModel(**args)
|
gm = GenericModel(**args)
|
||||||
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
|
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
|
||||||
self.assertIsInstance(
|
|
||||||
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
|
|
||||||
)
|
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
||||||
)
|
)
|
||||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||||
self.assertFalse(hasattr(gm, "image_feature_extractor"))
|
self.assertIsNone(gm.view_pooler)
|
||||||
|
self.assertIsNone(gm.image_feature_extractor)
|
||||||
|
|
||||||
def test_create_gm_overrides(self):
|
def _test_create_gm_overrides(self):
|
||||||
|
provide_lpips_vgg()
|
||||||
args = get_default_args(GenericModel)
|
args = get_default_args(GenericModel)
|
||||||
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
|
args.view_pooler_enabled = True
|
||||||
|
args.image_feature_extractor_enabled = True
|
||||||
|
args.view_pooler_args.feature_aggregator_class_type = (
|
||||||
|
"AngleWeightedIdentityFeatureAggregator"
|
||||||
|
)
|
||||||
args.implicit_function_class_type = "IdrFeatureField"
|
args.implicit_function_class_type = "IdrFeatureField"
|
||||||
args.renderer_class_type = "LSTMRenderer"
|
args.renderer_class_type = "LSTMRenderer"
|
||||||
gm = GenericModel(**args)
|
gm = GenericModel(**args)
|
||||||
self.assertIsInstance(gm.renderer, LSTMRenderer)
|
self.assertIsInstance(gm.renderer, LSTMRenderer)
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
|
gm.view_pooler.feature_aggregator,
|
||||||
|
AngleWeightedIdentityFeatureAggregator,
|
||||||
)
|
)
|
||||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||||
|
@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
cfg = _load_model_config_from_yaml(str(config_file))
|
cfg = _load_model_config_from_yaml(str(config_file))
|
||||||
model = GenericModel(**cfg)
|
model = GenericModel(**cfg)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
self._one_model_test(model, device, eval_test=True)
|
self._one_model_test(
|
||||||
|
model,
|
||||||
|
device,
|
||||||
|
eval_test=True,
|
||||||
|
bw_test=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _one_model_test(
|
def _one_model_test(
|
||||||
self,
|
self,
|
||||||
@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
device,
|
device,
|
||||||
n_train_cameras: int = 5,
|
n_train_cameras: int = 5,
|
||||||
eval_test: bool = True,
|
eval_test: bool = True,
|
||||||
|
bw_test: bool = True,
|
||||||
):
|
):
|
||||||
|
|
||||||
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||||
@ -86,7 +92,11 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
**random_args,
|
**random_args,
|
||||||
evaluation_mode=EvaluationMode.TRAINING,
|
evaluation_mode=EvaluationMode.TRAINING,
|
||||||
)
|
)
|
||||||
self.assertGreater(train_preds["objective"].item(), 0)
|
self.assertTrue(
|
||||||
|
train_preds["objective"].isfinite().item()
|
||||||
|
) # check finiteness of the objective
|
||||||
|
|
||||||
|
if bw_test:
|
||||||
train_preds["objective"].backward()
|
train_preds["objective"].backward()
|
||||||
|
|
||||||
if eval_test:
|
if eval_test:
|
||||||
|
@ -9,7 +9,7 @@ import unittest
|
|||||||
|
|
||||||
import pytorch3d as pt3d
|
import pytorch3d as pt3d
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
|
from pytorch3d.implicitron.models.view_pooler.view_sampler import ViewSampler
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user